Hello, while reading your code I ran into the following question:
```python
def trans_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
    depth_loss_weights = kwargs.get("dlossw", None)

    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
    total_entropy = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)

    for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "stage" in k]:
        prob_volume = stage_inputs["prob_volume"]
        depth_values = stage_inputs["depth_values"]
        depth_gt = depth_gt_ms[stage_key]
        mask = mask_ms[stage_key]
        mask = mask > 0.5

        entropy_weight = 2.0
        entro_loss, depth_entropy = entropy_loss(prob_volume, depth_gt, mask, depth_values)
        entro_loss = entro_loss * entropy_weight
        depth_loss = F.smooth_l1_loss(depth_entropy[mask], depth_gt[mask], reduction='mean')
        total_entropy += entro_loss

        if depth_loss_weights is not None:
            stage_idx = int(stage_key.replace("stage", "")) - 1
            total_loss += depth_loss_weights[stage_idx] * entro_loss
        else:
            total_loss += entro_loss

    return total_loss, depth_loss, total_entropy, depth_entropy
```
If I understand entropy_loss to be a cross-entropy loss, then entropy_weight should act like the per-sample weight in Focal Loss, but the way it is combined here doesn't look like Focal Loss. Am I misunderstanding something? If you have a moment to see this comment, I would appreciate your help clearing up this confusion. Thank you very much.
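To make the distinction concrete, here is a minimal sketch (my own illustration, not the repository's code; the `gt_index` construction and tensor shapes are assumptions) contrasting the fixed scalar weight used above with a genuine Focal-Loss-style per-pixel modulation:

```python
import torch

def fixed_weight_ce(prob_volume, gt_index, mask, weight=2.0):
    # prob_volume: [B, D, H, W], softmax over D depth hypotheses
    # gt_index:    [B, H, W], index of the GT depth hypothesis (assumed)
    p_t = torch.gather(prob_volume, 1, gt_index.unsqueeze(1)).squeeze(1)
    ce = -torch.log(p_t.clamp(min=1e-8))
    # Constant scale on every pixel, analogous to entropy_weight = 2.0 above.
    return (weight * ce)[mask].mean()

def focal_style_ce(prob_volume, gt_index, mask, alpha=1.0, gamma=2.0):
    p_t = torch.gather(prob_volume, 1, gt_index.unsqueeze(1)).squeeze(1)
    ce = -torch.log(p_t.clamp(min=1e-8))
    # Focal Loss modulates each pixel by (1 - p_t)^gamma, down-weighting
    # easy (high-confidence) pixels; a fixed scalar cannot do this.
    return (alpha * (1.0 - p_t) ** gamma * ce)[mask].mean()
```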
OP, I have the same question: depth_loss_weights here should refer to weights for the L1 loss on depth values, yet in the end they are applied to the cross-entropy loss instead.
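In other words, if the intent were for `dlossw` to scale the smooth-L1 depth term, one would expect the per-stage accumulation to read something like the following hypothetical variant (an assumption about intent, not the repository's code):

```python
# Hypothetical variant: weight the smooth-L1 depth term per stage,
# instead of the cross-entropy term as in trans_mvsnet_loss above.
if depth_loss_weights is not None:
    stage_idx = int(stage_key.replace("stage", "")) - 1
    total_loss += depth_loss_weights[stage_idx] * depth_loss
else:
    total_loss += depth_loss
```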