Implementation of Focal Loss in PyTorch

import torch
import torch.nn as nn


class BinaryFocalLossWithLogits(nn.Module):
    """Binary focal loss computed from raw logits.

    The per-element loss is ``alpha_t * (1 - p_t) ** gamma * BCE``, where
    ``BCE`` is the element-wise binary cross-entropy with logits and
    ``p_t = exp(-BCE)`` is the predicted probability of the target class.

    Args:
        gamma: Focusing parameter; ``0`` leaves the (optionally weighted)
            BCE unmodulated.
        alpha_dominant: Weight for elements whose label equals 1; all other
            elements get ``1 - alpha_dominant``. ``None`` disables weighting.
        reduce: If True, return the mean over all elements; otherwise return
            the element-wise loss with the same shape as ``inputs``.
    """

    def __init__(self, gamma=0, alpha_dominant=None, reduce=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha_dominant
        self.reduce = reduce
        # Always compute the element-wise BCE: the focal modulation
        # (1 - p_t) ** gamma must be applied per element *before* any
        # averaging, otherwise p_t would be derived from the batch mean.
        self.BCELoss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, labels):
        """Return the focal loss of ``inputs`` (logits) against ``labels``.

        ``inputs`` and ``labels`` must have the same shape, as required by
        ``BCEWithLogitsLoss``.
        """
        bce = self.BCELoss(inputs, labels)
        pt = torch.exp(-bce)
        loss = bce * (1 - pt) ** self.gamma
        if self.alpha is not None:
            # Build the class weights with the loss's shape, dtype and device
            # so they work on GPU and never broadcast against a flattened view.
            at = torch.where(
                labels == 1,
                torch.full_like(loss, self.alpha),
                torch.full_like(loss, 1 - self.alpha),
            )
            loss = loss * at
        if self.reduce:
            return loss.mean()
        return loss
Usage:
criterion = BinaryFocalLossWithLogits(gamma=2, alpha_dominant=0.1, reduce=False)
labels = torch.tensor([[1., 0., 0.]])
output = torch.tensor([[-0.9, 0.9, 0.8]])
# Pass the logits first and the same-shaped targets second. With
# reduce=False the result is the per-element loss, here of shape (1, 3).
loss = criterion(output, labels)

댓글

이 블로그의 인기 게시물

sklearn tsne + matplotlib scatter