import torch
import torch.nn as nn
class BinaryFocalLossWithLogits(nn.Module):
    """Binary focal loss computed directly from raw logits.

    For each element: ``loss = a_t * (1 - p_t) ** gamma * BCE(logit, label)``,
    where ``BCE`` is ``BCEWithLogitsLoss`` with no reduction. For 0/1 labels,
    ``BCE == -log(p_t)``, so ``p_t`` is recovered as ``exp(-BCE)``.

    Args:
        gamma: Focusing exponent applied to ``(1 - p_t)``; 0 disables focusing.
        alpha_dominant: Weight for elements whose label equals 1; all other
            elements get ``1 - alpha_dominant``. ``None`` disables alpha
            weighting (every element gets weight 1).
        reduce: If True, ``forward`` returns the mean over all elements;
            otherwise it returns the element-wise loss, same shape as inputs.
    """

    def __init__(self, gamma=0, alpha_dominant=None, reduce=True):
        super(BinaryFocalLossWithLogits, self).__init__()
        self.gamma = gamma
        self.alpha = alpha_dominant
        self.reduce = reduce
        # The focal modulation and alpha weighting must be applied per element,
        # so the BCE term is never reduced here; any reduction happens at the
        # end of forward().
        self.BCELoss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, labels):
        """Compute the focal loss.

        Args:
            inputs: Raw logits.
            labels: Float targets with the same shape as ``inputs``
                (required by ``BCEWithLogitsLoss``).

        Returns:
            A scalar tensor (mean) if ``self.reduce`` is True, otherwise a
            tensor with the same shape as ``inputs``.
        """
        bce = self.BCELoss(inputs, labels)
        # For hard 0/1 labels, bce == -log(p_t), so exp(-bce) is p_t.
        pt = torch.exp(-bce)
        loss = bce * (1 - pt) ** self.gamma
        if self.alpha is not None:
            # Built element-wise so it matches the loss's shape, device and dtype.
            at = torch.where(
                labels == 1,
                torch.full_like(bce, self.alpha),
                torch.full_like(bce, 1 - self.alpha),
            )
            loss = loss * at
        if self.reduce:
            return torch.mean(loss)
        return loss
# Usage example: runs only when this file is executed as a script.
if __name__ == "__main__":
    criterion = BinaryFocalLossWithLogits(gamma=2, alpha_dominant=0.1, reduce=False)
    labels = torch.tensor([[1., 0., 0.]])
    output = torch.tensor([[-0.9, 0.9, 0.8]])
    # Logits first, then float targets of the same shape.
    loss = criterion(output, labels)
    print(loss)
# (Blog-page footer text "Comments" / "Write a comment" was scraped here; it is not part of the code.)