From 285ecfec802f64bf181ad5df0979f6db609da2c1 Mon Sep 17 00:00:00 2001
From: Eugene Liu
Date: Wed, 22 Jan 2025 16:16:09 +0000
Subject: [PATCH] Refactor raw logits processing in DFine and RTDETR models by
 introducing a dedicated method for splitting and reshaping logits in explain
 mode.

---
 src/otx/algo/detection/d_fine.py       | 16 +++------
 .../detectors/detection_transformer.py | 33 ++++++++++++++-----
 src/otx/algo/detection/rtdetr.py       | 16 +++------
 3 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/src/otx/algo/detection/d_fine.py b/src/otx/algo/detection/d_fine.py
index 1ff1d6bb53..717ea9d6b2 100644
--- a/src/otx/algo/detection/d_fine.py
+++ b/src/otx/algo/detection/d_fine.py
@@ -347,19 +347,13 @@ def _forward_explain_detection(
         backbone_feats = self.encoder(self.backbone(entity.images))
         predictions = self.decoder(backbone_feats, explain_mode=True)
 
-        feature_vector = self.feature_vector_fn(backbone_feats)
-
-        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
-
-        # Permute and split logits in one line
-        raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)
-
-        # Reshape each split in a list comprehension
-        raw_logits = [
-            logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
-        ]
+        raw_logits = DETR.split_and_reshape_logits(
+            backbone_feats,
+            predictions["raw_logits"],
+        )
         saliency_map = self.explain_fn(raw_logits)
+        feature_vector = self.feature_vector_fn(backbone_feats)
 
         predictions.update(
             {
                 "feature_vector": feature_vector,
diff --git a/src/otx/algo/detection/detectors/detection_transformer.py b/src/otx/algo/detection/detectors/detection_transformer.py
index d0b12da5e0..f3cda5b741 100644
--- a/src/otx/algo/detection/detectors/detection_transformer.py
+++ b/src/otx/algo/detection/detectors/detection_transformer.py
@@ -103,16 +103,8 @@ def export(
             deploy_mode=True,
         )
         if explain_mode:
+            raw_logits = self.split_and_reshape_logits(backbone_feats, predictions["raw_logits"])
             feature_vector = self.feature_vector_fn(backbone_feats)
-            splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
-            # Permute and split logits in one line
-            raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)
-
-            # Reshape each split in a list comprehension
-            raw_logits = [
-                logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1])
-                for logits, f in zip(raw_logits, backbone_feats)
-            ]
             saliency_map = self.explain_fn(raw_logits)
             xai_output = {
                 "feature_vector": feature_vector,
@@ -121,6 +113,29 @@ def export(
         results.update(xai_output)  # type: ignore[union-attr]
         return results
 
+    @staticmethod
+    def split_and_reshape_logits(
+        backbone_feats: tuple[Tensor, ...],
+        raw_logits: Tensor,
+    ) -> tuple[Tensor, ...]:
+        """Splits and reshapes raw logits for explain mode.
+
+        Args:
+            backbone_feats (tuple[Tensor, ...]): Tuple of backbone features.
+            raw_logits (Tensor): Raw logits.
+
+        Returns:
+            tuple[Tensor, ...]: The reshaped logits.
+        """
+        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
+        # Permute and split logits in one line
+        raw_logits = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)
+
+        # Reshape each split in a list comprehension
+        return tuple(
+            logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
+        )
+
     def postprocess(
         self,
         outputs: dict[str, Tensor],
diff --git a/src/otx/algo/detection/rtdetr.py b/src/otx/algo/detection/rtdetr.py
index 9319ca6db8..fcbf6330c2 100644
--- a/src/otx/algo/detection/rtdetr.py
+++ b/src/otx/algo/detection/rtdetr.py
@@ -312,19 +312,13 @@ def _forward_explain_detection(
         backbone_feats = self.encoder(self.backbone(entity.images))
         predictions = self.decoder(backbone_feats, explain_mode=True)
 
-        feature_vector = self.feature_vector_fn(backbone_feats)
-
-        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
-
-        # Permute and split logits in one line
-        raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)
-
-        # Reshape each split in a list comprehension
-        raw_logits = [
-            logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
-        ]
+        raw_logits = DETR.split_and_reshape_logits(
+            backbone_feats,
+            predictions["raw_logits"],
+        )
         saliency_map = self.explain_fn(raw_logits)
+        feature_vector = self.feature_vector_fn(backbone_feats)
 
         predictions.update(
             {
                 "feature_vector": feature_vector,