From f8911f440ed8d15042dcdd8ebe2f4d8c2fab4ebb Mon Sep 17 00:00:00 2001 From: Mikhail Moskovchenko Date: Thu, 5 Jun 2025 17:53:05 +0400 Subject: [PATCH] Replaced view with replace to prevent fails on non-contiguous tensors --- segmentation_models_pytorch/losses/dice.py | 12 ++++++------ segmentation_models_pytorch/losses/focal.py | 4 ++-- segmentation_models_pytorch/losses/jaccard.py | 12 ++++++------ segmentation_models_pytorch/losses/lovasz.py | 10 +++++----- segmentation_models_pytorch/losses/mcc.py | 4 ++-- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/segmentation_models_pytorch/losses/dice.py b/segmentation_models_pytorch/losses/dice.py index b8baae98..e660b740 100644 --- a/segmentation_models_pytorch/losses/dice.py +++ b/segmentation_models_pytorch/losses/dice.py @@ -73,8 +73,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: dims = (0, 2) if self.mode == BINARY_MODE: - y_true = y_true.view(bs, 1, -1) - y_pred = y_pred.view(bs, 1, -1) + y_true = y_true.reshape(bs, 1, -1) + y_pred = y_pred.reshape(bs, 1, -1) if self.ignore_index is not None: mask = y_true != self.ignore_index @@ -82,8 +82,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_true = y_true * mask if self.mode == MULTICLASS_MODE: - y_true = y_true.view(bs, -1) - y_pred = y_pred.view(bs, num_classes, -1) + y_true = y_true.reshape(bs, -1) + y_pred = y_pred.reshape(bs, num_classes, -1) if self.ignore_index is not None: mask = y_true != self.ignore_index @@ -98,8 +98,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_true = y_true.permute(0, 2, 1) # N, C, H*W if self.mode == MULTILABEL_MODE: - y_true = y_true.view(bs, num_classes, -1) - y_pred = y_pred.view(bs, num_classes, -1) + y_true = y_true.reshape(bs, num_classes, -1) + y_pred = y_pred.reshape(bs, num_classes, -1) if self.ignore_index is not None: mask = y_true != self.ignore_index diff --git a/segmentation_models_pytorch/losses/focal.py b/segmentation_models_pytorch/losses/focal.py index d26acb52..3beb9f34 100644 --- a/segmentation_models_pytorch/losses/focal.py +++ b/segmentation_models_pytorch/losses/focal.py @@ -57,8 +57,8 @@ def __init__( def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if self.mode in {BINARY_MODE, MULTILABEL_MODE}: - y_true = y_true.view(-1) - y_pred = y_pred.view(-1) + y_true = y_true.reshape(-1) + y_pred = y_pred.reshape(-1) if self.ignore_index is not None: # Filter predictions with ignore label from loss computation diff --git a/segmentation_models_pytorch/losses/jaccard.py b/segmentation_models_pytorch/losses/jaccard.py index 35727f95..0b7748f0 100644 --- a/segmentation_models_pytorch/losses/jaccard.py +++ b/segmentation_models_pytorch/losses/jaccard.py @@ -73,8 +73,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: dims = (0, 2) if self.mode == BINARY_MODE: - y_true = y_true.view(bs, 1, -1) - y_pred = y_pred.view(bs, 1, -1) + y_true = y_true.reshape(bs, 1, -1) + y_pred = y_pred.reshape(bs, 1, -1) if self.ignore_index is not None: mask = y_true != self.ignore_index @@ -82,8 +82,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_true = y_true * mask if self.mode == MULTICLASS_MODE: - y_true = y_true.view(bs, -1) - y_pred = y_pred.view(bs, num_classes, -1) + y_true = y_true.reshape(bs, -1) + y_pred = y_pred.reshape(bs, num_classes, -1) if self.ignore_index is not None: mask = y_true != self.ignore_index @@ -98,8 +98,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_true = y_true.permute(0, 2, 1) # N, C, H*W if self.mode == MULTILABEL_MODE: - y_true = y_true.view(bs, num_classes, -1) - y_pred = y_pred.view(bs, num_classes, -1) + y_true = y_true.reshape(bs, num_classes, -1) + y_pred = y_pred.reshape(bs, num_classes, -1) if self.ignore_index is not None: mask = y_true != self.ignore_index diff --git a/segmentation_models_pytorch/losses/lovasz.py b/segmentation_models_pytorch/losses/lovasz.py index 8bc35967..6dff5858 100644 --- a/segmentation_models_pytorch/losses/lovasz.py +++ b/segmentation_models_pytorch/losses/lovasz.py @@ -77,8 +77,8 @@ def _flatten_binary_scores(scores, labels, ignore=None): """Flattens predictions in the batch (binary case) Remove labels equal to 'ignore' """ - scores = scores.view(-1) - labels = labels.view(-1) + scores = scores.reshape(-1) + labels = labels.reshape(-1) if ignore is None: return scores, labels valid = labels != ignore @@ -151,13 +151,13 @@ def _flatten_probas(probas, labels, ignore=None): if probas.dim() == 3: # assumes output of a sigmoid layer B, H, W = probas.size() - probas = probas.view(B, 1, H, W) + probas = probas.reshape(B, 1, H, W) C = probas.size(1) probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C] - probas = probas.contiguous().view(-1, C) # [P, C] + probas = probas.reshape(-1, C) # [P, C] - labels = labels.view(-1) + labels = labels.reshape(-1) if ignore is None: return probas, labels valid = labels != ignore diff --git a/segmentation_models_pytorch/losses/mcc.py b/segmentation_models_pytorch/losses/mcc.py index ebd7d669..65e47352 100644 --- a/segmentation_models_pytorch/losses/mcc.py +++ b/segmentation_models_pytorch/losses/mcc.py @@ -29,8 +29,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: bs = y_true.shape[0] - y_true = y_true.view(bs, 1, -1) - y_pred = y_pred.view(bs, 1, -1) + y_true = y_true.reshape(bs, 1, -1) + y_pred = y_pred.reshape(bs, 1, -1) tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps