-
Notifications
You must be signed in to change notification settings - Fork 2
/
loss.py
82 lines (65 loc) · 2.96 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.nn as nn
import torch.nn.functional as F
class SPLC(nn.Module):
    r"""SPLC loss combined with the focal margin loss.

    Implements the loss described in the paper "Simple Loss Design for
    Multi-Label Learning with Missing Labels".

    .. math::
        L_{SPLC}^+ &= loss^+(p) \\
        L_{SPLC}^- &= \mathbb{I}(p \leq \tau) loss^-(p)
                      + (1 - \mathbb{I}(p \leq \tau)) loss^+(p)

    where :math:`\tau` is a threshold used to identify missing labels,
    :math:`\mathbb{I}(\cdot) \in \{0, 1\}` is the indicator function, and
    :math:`loss^+(\cdot)`, :math:`loss^-(\cdot)` are the loss functions for
    positives and negatives, respectively.

    .. note::
        SPLC can be combined with various multi-label loss functions.
        According to the paper it performs best combined with the focal
        margin loss, which is what this module implements.
        Since the first epoch can recall few missing labels with high
        precision, SPLC can be used after the first epoch.
        The sigmoid is applied inside the loss; pass raw logits.

    Args:
        tau (float): Threshold value. Default: 0.6
        change_epoch (int): First epoch at which the SPLC label correction
            is applied. Default: 1
        margin (float): Margin subtracted from the logits of positives.
            Default: 1
        gamma (float): Focusing (hard mining) exponent. Default: 2
        reduction (str, optional): Specifies the reduction to apply to the
            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no
            reduction will be applied, ``'mean'``: the sum of the output
            will be divided by the number of elements in the output,
            ``'sum'``: the output will be summed. Default: ``'sum'``

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(
        self,
        tau: float = 0.6,
        change_epoch: int = 1,
        margin: float = 1.0,
        gamma: float = 2.0,
        reduction: str = "sum",
    ) -> None:
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, "
                f"got {reduction!r}"
            )
        self.tau = tau
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor,
                epoch: int):
        """Compute the SPLC focal margin loss.

        Args:
            logits: The predicted logits before sigmoid with shape of
                :math:`(N, C)`.
            targets: Multi-label binarized vector with shape of
                :math:`(N, C)`.
            epoch: The epoch of current training.

        Returns:
            A ``(loss, targets)`` tuple: the loss reduced according to
            ``reduction``, and the targets actually used to compute it
            (i.e. after the missing-label correction, when it is active).
        """
        # Subtract margin for positive logits.
        logits = torch.where(targets == 1, logits - self.margin, logits)
        pred = torch.sigmoid(logits)

        # SPLC missing label correction: confident predictions are treated
        # as positives. ``ones_like`` keeps the device and dtype of
        # ``targets``, so this works on CPU and on any accelerator.
        if epoch >= self.change_epoch:
            targets = torch.where(
                pred > self.tau, torch.ones_like(targets), targets)

        # Focal weighting: down-weight well-classified elements.
        pt = (1 - pred) * targets + pred * (1 - targets)
        focal_weight = pt ** self.gamma

        # Binary cross-entropy terms computed from logits for stability.
        los_pos = targets * F.logsigmoid(logits)
        los_neg = (1 - targets) * F.logsigmoid(-logits)
        loss = -(los_pos + los_neg) * focal_weight

        if self.reduction == "sum":
            loss = loss.sum()
        elif self.reduction == "mean":
            loss = loss.mean()
        return loss, targets