17 lines
472 B
Python
17 lines
472 B
Python
import torch
|
|
from torch import nn
|
|
|
|
|
|
class SiLogLoss(nn.Module):
|
|
def __init__(self, lambd=0.5):
|
|
super().__init__()
|
|
self.lambd = lambd
|
|
|
|
def forward(self, pred, target, valid_mask):
|
|
valid_mask = valid_mask.detach()
|
|
diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
|
|
loss = torch.sqrt(torch.pow(diff_log, 2).mean() -
|
|
self.lambd * torch.pow(diff_log.mean(), 2))
|
|
|
|
return loss
|