6.1 自定义损失函数
PyTorch在torch.nn模块为我们提供了许多常用的损失函数,比如:MSELoss,L1Loss,BCELoss…… 但是随着深度学习的发展,出现了越来越多的非官方提供的Loss,比如DiceLoss,HuberLoss,SobolevLoss…… 这些Loss Function专门针对一些非通用的模型,PyTorch不能将他们全部添加到库中去,因此这些损失函数的实现则需要我们通过自定义损失函数来实现。另外,在科学研究中,我们往往会提出全新的损失函数来提升模型的表现,这时我们既无法使用PyTorch自带的损失函数,也没有相关的博客供参考,此时自己实现损失函数就显得更为重要了。
6.1.1 以函数方式定义
事实上,损失函数仅仅是一个函数而已,因此我们可以通过直接以函数定义的方式定义一个自己的函数,如下所示:
def my_loss(output, target):
loss = torch.mean((output - target)**2)
return loss
复制代码
6.1.2 以类方式定义
虽然以函数定义的方式很简单,但是以类方式定义更加常用,在以类方式定义损失函数时,我们如果看每一个损失函数的继承关系我们就可以发现Loss
函数部分继承自_loss
, 部分继承自_WeightedLoss
, 而_WeightedLoss
继承自_loss
, _loss
继承自 nn.Module。我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类,在下面的例子中我们以DiceLoss为例向大家讲述。
Dice Loss是一种在分割领域常见的损失函数,定义如下:
© 版权声明
文章版权归作者所有,未经允许请勿转载。
THE END
喜欢就支持一下吧
相关推荐