Temp/Depth-Anything-V2-main/metric_depth/util/loss.py

17 lines
472 B
Python
Raw Normal View History

2025-07-03 13:48:55 +08:00
import torch
from torch import nn
class SiLogLoss(nn.Module):
    """Masked log-difference loss with a weighted mean-removal term.

    With ``d = log(target) - log(pred)`` taken over the elements selected by
    ``valid_mask``, the loss is::

        sqrt(mean(d ** 2) - lambd * mean(d) ** 2)

    NOTE(review): the name suggests the scale-invariant log ("SiLog") error
    used for depth estimation; confirm against the training script.

    Args:
        lambd: Weight of the squared-mean term subtracted inside the square
            root.
    """

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the loss over the elements selected by ``valid_mask``.

        Args:
            pred: Predicted values. Must be strictly positive wherever the
                mask selects, because the logarithm is taken.
            target: Reference values, indexed with the same mask as ``pred``.
                Must be strictly positive wherever the mask selects.
            valid_mask: Index used to select the elements that contribute to
                the loss (presumably a boolean tensor shaped like ``pred``;
                verify against the caller).

        Returns:
            A scalar tensor. If the mask selects no elements, the result is
            an exact zero that is still connected to the autograd graph.
        """
        # The mask only selects elements; it must never receive gradients.
        valid_mask = valid_mask.detach()
        diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])

        if diff_log.numel() == 0:
            # mean() of an empty tensor is NaN, which would corrupt every
            # parameter on backward. Summing the empty tensor gives 0 while
            # keeping the graph intact, so backward() yields zero gradients.
            return diff_log.sum()

        mean_of_squares = torch.pow(diff_log, 2).mean()
        square_of_mean = torch.pow(diff_log.mean(), 2)
        loss = torch.sqrt(mean_of_squares - self.lambd * square_of_mean)
        return loss