# Implementation of Focal Loss using PyTorch
import torch
import torch.nn as nn


class BinaryFocalLossWithLogits(nn.Module):
    """Binary focal loss computed from raw logits.

    Each element's loss is

        alpha_t * (1 - p_t) ** gamma * BCE(logit, label)

    where ``BCE`` is the element-wise binary cross-entropy with logits and
    ``p_t = exp(-BCE)``; for hard 0/1 labels ``p_t`` is the predicted
    probability of the true class.

    Args:
        gamma: Focusing parameter. ``0`` reduces the loss to (optionally
            alpha-weighted) BCE.
        alpha_dominant: Weight applied to elements whose label equals 1;
            every other element is weighted by ``1 - alpha_dominant``.
            ``None`` disables class weighting (all weights are 1).
        reduce: If true, return the mean over all elements; otherwise return
            the element-wise loss with the same shape as ``inputs``.
    """

    def __init__(self, gamma=0, alpha_dominant=None, reduce=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha_dominant
        self.reduce = reduce
        # The focal term must be computed per element, so the BCE is never
        # reduced here; any reduction happens at the end of forward().
        self.BCELoss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, labels):
        """Compute the focal loss.

        Args:
            inputs: Logits tensor.
            labels: Target tensor with the same shape as ``inputs`` and a
                floating dtype (both required by ``BCEWithLogitsLoss``).

        Returns:
            The mean loss as a scalar tensor if ``self.reduce`` is true,
            otherwise the element-wise loss shaped like ``inputs``.
        """
        bce = self.BCELoss(inputs, labels)
        # exp(-BCE) recovers the model's probability assigned to the target.
        pt = torch.exp(-bce)
        loss = (1 - pt) ** self.gamma * bce
        if self.alpha is not None:
            # Vectorized per-element weight: alpha where label == 1, else
            # 1 - alpha. Matches the loss's shape, dtype and device.
            is_pos = (labels == 1).to(loss.dtype)
            at = self.alpha * is_pos + (1 - self.alpha) * (1 - is_pos)
            loss = loss * at
        if self.reduce:
            return loss.mean()
        return loss


# Usage (the original example is truncated in this file):
# criterion = BinaryFocalLossWithLogits(gamma=2, alpha_dominant=0.1, reduce=False)
# labels = torch.tensor([[1., 0., 0.]])
# output = torch.tensor([[-0...