-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_05_loss.py
More file actions
36 lines (30 loc) · 1.04 KB
/
_05_loss.py
File metadata and controls
36 lines (30 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn.functional as F
def masked_mse_loss(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the masked MSE loss between input and target.

    Both ``input`` and ``target`` are multiplied by ``mask`` before the
    squared error is summed, so positions where the mask is 0 contribute
    nothing. The summed error is then divided by ``mask.sum()``.

    Args:
        input: Predicted values.
        target: Ground-truth values, broadcast-compatible with ``input``.
        mask: Selection mask (bool or numeric). Intended to be 0/1; a
            non-binary mask scales the residual by ``mask``, so the squared
            error is effectively weighted by ``mask ** 2``.

    Returns:
        Scalar tensor with the mean squared error over the masked positions.
        If the mask selects nothing (``mask.sum() == 0``), returns 0 instead
        of NaN.
    """
    mask = mask.float()
    loss = F.mse_loss(input * mask, target * mask, reduction="sum")
    n_selected = mask.sum()
    # An all-zero mask would give 0 / 0 = NaN and corrupt gradients. The
    # numerator is already 0 in that case, so dividing by 1 yields 0 loss.
    # torch.where keeps this on-device (no host sync) and leaves every
    # non-zero denominator untouched.
    denom = torch.where(n_selected != 0, n_selected, torch.ones_like(n_selected))
    return loss / denom
def criterion_neg_log_bernoulli(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the masked negative log-likelihood of a Bernoulli distribution.

    ``input`` is used as the Bernoulli success probability. ``target`` is
    binarized as ``target > 0``, so any positive value counts as a success.
    Per-element log-probabilities are multiplied by ``mask``, summed, negated
    and divided by ``mask.sum()``.

    Args:
        input: Success probabilities; must lie in [0, 1].
        target: Observed values; positive entries are treated as 1, others as 0.
        mask: Selection/weight mask (bool or numeric).

    Returns:
        Scalar tensor with the mean negative log-likelihood over the masked
        positions. If the mask selects nothing (``mask.sum() == 0``), returns
        0 instead of NaN.
    """
    mask = mask.float()
    bernoulli = torch.distributions.Bernoulli(probs=input)
    masked_log_probs = bernoulli.log_prob((target > 0).float()) * mask
    n_selected = mask.sum()
    # Guard the 0 / 0 case for an all-zero mask. Non-zero denominators are
    # unchanged, and the numerator is already 0 when nothing is selected.
    denom = torch.where(n_selected != 0, n_selected, torch.ones_like(n_selected))
    return -masked_log_probs.sum() / denom
def masked_relative_error(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Compute the masked relative error between input and target.

    For every selected position the error is
    ``|input - target| / (target + 1e-6)``, and the result is the mean over
    the selected positions. The small epsilon avoids division by an exact
    zero target. Note that the denominator is not wrapped in ``abs``.

    Args:
        input: Predicted values.
        target: Ground-truth values, same shape as ``input``.
        mask: Selection mask. It may be bool or an integer/float 0/1 tensor;
            it is converted to bool so that indexing selects elements rather
            than gathering by integer index.

    Returns:
        Scalar tensor with the mean relative error over selected positions.

    Raises:
        ValueError: If ``mask`` selects no elements.
    """
    # A LongTensor used as an index gathers by position (rows 0/1), not by
    # selection; convert explicitly so a 0/1 long mask behaves like a bool one.
    mask = mask.bool()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not mask.any():
        raise ValueError("masked_relative_error: mask selects no elements")
    selected_target = target[mask]
    loss = torch.abs(input[mask] - selected_target) / (selected_target + 1e-6)
    return loss.mean()