Skip to content

Commit

Permalink
Refactor raw logits processing in DFine and RTDETR models by introduc…
Browse files Browse the repository at this point in the history
…ing a dedicated method for splitting and reshaping logits in explain mode.
  • Loading branch information
eugene123tw committed Jan 22, 2025
1 parent 01efc64 commit 285ecfe
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 31 deletions.
16 changes: 5 additions & 11 deletions src/otx/algo/detection/d_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,19 +347,13 @@ def _forward_explain_detection(
backbone_feats = self.encoder(self.backbone(entity.images))
predictions = self.decoder(backbone_feats, explain_mode=True)

feature_vector = self.feature_vector_fn(backbone_feats)

splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]

# Permute and split logits in one line
raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)

# Reshape each split in a list comprehension
raw_logits = [
logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
]
raw_logits = DETR.split_and_reshape_logits(
backbone_feats,
predictions["raw_logits"],
)

saliency_map = self.explain_fn(raw_logits)
feature_vector = self.feature_vector_fn(backbone_feats)
predictions.update(
{
"feature_vector": feature_vector,
Expand Down
33 changes: 24 additions & 9 deletions src/otx/algo/detection/detectors/detection_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,8 @@ def export(
deploy_mode=True,
)
if explain_mode:
raw_logits = self.split_and_reshape_logits(backbone_feats, predictions["raw_logits"])
feature_vector = self.feature_vector_fn(backbone_feats)
splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
# Permute and split logits in one line
raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)

# Reshape each split in a list comprehension
raw_logits = [
logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1])
for logits, f in zip(raw_logits, backbone_feats)
]
saliency_map = self.explain_fn(raw_logits)
xai_output = {
"feature_vector": feature_vector,
Expand All @@ -121,6 +113,29 @@ def export(
results.update(xai_output) # type: ignore[union-attr]
return results

@staticmethod
def split_and_reshape_logits(
backbone_feats: tuple[Tensor, ...],
raw_logits: Tensor,
) -> tuple[Tensor, ...]:
"""Splits and reshapes raw logits for explain mode.
Args:
backbone_feats (tuple[Tensor,...]): Tuple of backbone features.
raw_logits (Tensor): Raw logits.
Returns:
tuple[Tensor,...]: The reshaped logits.
"""
splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
# Permute and split logits in one line
raw_logits = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)

# Reshape each split in a list comprehension
return tuple(
logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
)

def postprocess(
self,
outputs: dict[str, Tensor],
Expand Down
16 changes: 5 additions & 11 deletions src/otx/algo/detection/rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,19 +312,13 @@ def _forward_explain_detection(
backbone_feats = self.encoder(self.backbone(entity.images))
predictions = self.decoder(backbone_feats, explain_mode=True)

feature_vector = self.feature_vector_fn(backbone_feats)

splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]

# Permute and split logits in one line
raw_logits = torch.split(predictions["raw_logits"].permute(0, 2, 1), splits, dim=-1)

# Reshape each split in a list comprehension
raw_logits = [
logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
]
raw_logits = DETR.split_and_reshape_logits(
backbone_feats,
predictions["raw_logits"],
)

saliency_map = self.explain_fn(raw_logits)
feature_vector = self.feature_vector_fn(backbone_feats)
predictions.update(
{
"feature_vector": feature_vector,
Expand Down

0 comments on commit 285ecfe

Please sign in to comment.