Source code for thelper.optim.losses

import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    """Focal loss for multi-class classification and dense (per-pixel) prediction.

    The focal term ``(1 - p_t) ** gamma`` is applied to each element's cross-entropy
    *before* reduction, so well-classified (easy) elements are down-weighted relative
    to hard ones. The per-element losses are then averaged over non-ignored elements.
    When class weights are given, each element is weighted by ``weight[target]`` and
    the result is normalised by the sum of those weights. This mirrors the ``'mean'``
    reduction of :class:`torch.nn.CrossEntropyLoss`, so with ``gamma=0`` and
    ``alpha=None`` this loss equals plain (weighted) cross-entropy.

    .. note::
        Contributed by Mario Beaulieu <mario.beaulieu@crim.ca>.

    .. seealso::
        | `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_,
          *Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár*, arXiv article.
    """

    def __init__(self, gamma=2, alpha=0.5, weight=None, ignore_index=255):
        """Configure the focal loss.

        Args:
            gamma: focusing exponent applied to ``(1 - p_t)``; ``0`` disables focusing.
            alpha: global scalar multiplier applied to the loss, or ``None`` to skip it.
                Unlike the per-class ``alpha_t`` of the paper, this is one scale factor.
            weight: optional per-class rescaling tensor, as accepted by
                :class:`torch.nn.CrossEntropyLoss`.
            ignore_index: target value whose elements contribute nothing to the loss.
        """
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.weight = weight
        self.ignore_index = ignore_index
        # Kept for backward compatibility. CrossEntropyLoss also registers ``weight``
        # as a buffer, so ``self.ce_fn.weight`` follows the module across ``.to(...)``
        # calls (the plain ``self.weight`` attribute does not).
        self.ce_fn = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index)

    def forward(self, preds, labels):
        """Compute the mean focal loss.

        Args:
            preds: unnormalised class scores, laid out as expected by
                :func:`torch.nn.functional.cross_entropy` (class dimension at index 1).
            labels: integer class targets matching ``preds`` without the class dim;
                entries equal to ``ignore_index`` are excluded.

        Returns:
            A scalar tensor. It is 0 when every element is ignored.
        """
        # Unweighted per-element -log(p_t). Ignored elements come out as 0 here.
        ce = nn.functional.cross_entropy(
            preds, labels, ignore_index=self.ignore_index, reduction="none")
        # p_t must come from the *unweighted* CE so that it is a true probability.
        pt = torch.exp(-ce)
        loss = (1 - pt) ** self.gamma * ce
        if self.alpha is not None:
            loss = loss * self.alpha

        valid = labels != self.ignore_index
        class_weight = self.ce_fn.weight
        if class_weight is None:
            elem_weight = valid.to(loss.dtype)
        else:
            # Ignored labels (e.g. 255) may be out of range for indexing. Point them at
            # class 0, then zero their weight through the validity mask.
            safe_labels = torch.where(valid, labels, torch.zeros_like(labels))
            elem_weight = class_weight[safe_labels] * valid.to(class_weight.dtype)

        total = (loss * elem_weight).sum()
        norm = elem_weight.sum()
        # The clamp keeps a fully-ignored batch at 0 instead of 0/0 = NaN.
        return total / norm.clamp(min=torch.finfo(total.dtype).eps)