Pytorch学习(六)——

6.1 自定义损失函数

PyTorch在torch.nn模块为我们提供了许多常用的损失函数,比如:MSELoss,L1Loss,BCELoss…… 但是随着深度学习的发展,出现了越来越多的非官方提供的Loss,比如DiceLoss,HuberLoss,SobolevLoss…… 这些Loss Function专门针对一些非通用的模型,PyTorch不能将他们全部添加到库中去,因此这些损失函数的实现则需要我们通过自定义损失函数来实现。另外,在科学研究中,我们往往会提出全新的损失函数来提升模型的表现,这时我们既无法使用PyTorch自带的损失函数,也没有相关的博客供参考,此时自己实现损失函数就显得更为重要了。

6.1.1 以函数方式定义

事实上,损失函数仅仅是一个函数而已,因此我们可以通过直接以函数定义的方式定义一个自己的函数,如下所示:

def my_loss(output, target):
    loss = torch.mean((output - target)**2)
    return loss
复制代码

6.1.2 以类方式定义

虽然以函数定义的方式很简单,但是以类方式定义更加常用,在以类方式定义损失函数时,我们如果看每一个损失函数的继承关系我们就可以发现Loss函数部分继承自_loss, 部分继承自_WeightedLoss, 而_WeightedLoss继承自_loss _loss继承自 nn.Module。我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类,在下面的例子中我们以DiceLoss为例向大家讲述。

Dice Loss是一种在分割领域常见的损失函数,定义如下:

DSC=2XYX+YDSC = \frac{2|X∩Y|}{|X|+|Y|}

© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享