Commit 2f0b54b: DETR XAI (#4184)

* Implement explainability features in DFine and RTDETR models

1 parent 4416ac4

File tree: 9 files changed, +205 -52 lines


CHANGELOG.md (2 additions, 0 deletions)

@@ -24,6 +24,8 @@ All notable changes to this project will be documented in this file.
   (<https://github.com/openvinotoolkit/training_extensions/pull/4017>)
 - Add D-Fine Detection Algorithm
   (<https://github.com/openvinotoolkit/training_extensions/pull/4142>)
+- Add DETR XAI Explain Mode
+  (<https://github.com/openvinotoolkit/training_extensions/pull/4184>)
 
 ### Enhancements
docs/source/guide/tutorials/base/explain.rst (3 additions, 1 deletion)

@@ -32,6 +32,7 @@ which are heatmaps with red-colored areas indicating focus. Here's an example how
 
     (otx) ...$ otx explain --work_dir otx-workspace \
                --dump True  # Whether to save saliency map images or not
+               --explain_config.postprocess True  # Resizes and applies colormap to the saliency map
 
 .. tab-item:: CLI (with config)

@@ -41,6 +42,7 @@
                --data_root data/wgisd \
                --checkpoint otx-workspace/20240312_051135/checkpoints/epoch_033.ckpt \
                --dump True  # Whether to save saliency map images or not
+               --explain_config.postprocess True  # Resizes and applies colormap to the saliency map
 
 .. tab-item:: API

@@ -49,7 +51,7 @@
     engine.explain(
         checkpoint="<checkpoint-path>",
         datamodule=OTXDataModule(...),  # The data module to use for predictions
-        explain_config=ExplainConfig(postprocess=True),
+        explain_config=ExplainConfig(postprocess=True),  # Resizes and applies colormap to the saliency map
         dump=True  # Whether to save saliency map images or not
     )
src/otx/algo/detection/d_fine.py (56 additions, 0 deletions)

@@ -157,6 +157,9 @@ def _customize_inputs(
             )
             targets.append({"boxes": scaled_bboxes, "labels": ll})
 
+        if self.explain_mode:
+            return {"entity": entity}
+
         return {
             "images": entity.images,
             "targets": targets,

@@ -185,6 +188,33 @@ def _customize_outputs(
         original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
         scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)
 
+        if self.explain_mode:
+            if not isinstance(outputs, dict):
+                msg = f"Model output should be a dict, but got {type(outputs)}."
+                raise ValueError(msg)
+
+            if "feature_vector" not in outputs:
+                msg = "No feature vector in the model output."
+                raise ValueError(msg)
+
+            if "saliency_map" not in outputs:
+                msg = "No saliency maps in the model output."
+                raise ValueError(msg)
+
+            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach().cpu().numpy()
+
+            return DetBatchPredEntity(
+                batch_size=len(outputs),
+                images=inputs.images,
+                imgs_info=inputs.imgs_info,
+                scores=scores,
+                bboxes=bboxes,
+                labels=labels,
+                feature_vector=feature_vector,
+                saliency_map=saliency_map,
+            )
+
         return DetBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,

@@ -306,3 +336,29 @@ def _optimization_config(self) -> dict[str, Any]:
                 },
             },
         }
+
+    @staticmethod
+    def _forward_explain_detection(
+        self,  # noqa: ANN001
+        entity: DetBatchDataEntity,
+        mode: str = "tensor",  # noqa: ARG004
+    ) -> dict[str, torch.Tensor]:
+        """Forward function for explainable detection model."""
+        backbone_feats = self.encoder(self.backbone(entity.images))
+        predictions = self.decoder(backbone_feats, explain_mode=True)
+
+        raw_logits = DETR.split_and_reshape_logits(
+            backbone_feats,
+            predictions["raw_logits"],
+        )
+
+        saliency_map = self.explain_fn(raw_logits)
+        feature_vector = self.feature_vector_fn(backbone_feats)
+        predictions.update(
+            {
+                "feature_vector": feature_vector,
+                "saliency_map": saliency_map,
+            },
+        )
+
+        return predictions
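Note the unusual signature: _forward_explain_detection is a @staticmethod whose first parameter is named self. This lets OTX bind the function onto the inner detector instance and swap it in for the regular forward when explain mode is turned on. A minimal sketch of that binding pattern follows (the enable_explain helper is hypothetical; OTX performs the equivalent wiring in its base model class):

    import types

    def enable_explain(detector, forward_explain):
        # Rebind forward so that `forward_explain` receives the detector instance as `self`.
        detector.forward = types.MethodType(forward_explain, detector)

    # Hypothetical usage: enable_explain(dfine_module, DFine._forward_explain_detection)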

src/otx/algo/detection/detectors/detection_transformer.py (32 additions, 8 deletions)

@@ -1,11 +1,10 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
 """Base DETR model implementations."""
 
 from __future__ import annotations
 
-import warnings
 from typing import Any
 
 import numpy as np

@@ -96,22 +95,47 @@ def export(
         explain_mode: bool = False,
     ) -> dict[str, Any] | tuple[list[Any], list[Any], list[Any]]:
         """Exports the model."""
+        backbone_feats = self.encoder(self.backbone(batch_inputs))
+        predictions = self.decoder(backbone_feats, explain_mode=True)
         results = self.postprocess(
-            self._forward_features(batch_inputs),
+            predictions,
             [meta["img_shape"] for meta in batch_img_metas],
             deploy_mode=True,
         )
-
         if explain_mode:
-            # TODO(Eugene): Implement explain mode for DETR model.
-            warnings.warn("Explain mode is not supported for DETR model. Return dummy values.", stacklevel=2)
+            raw_logits = self.split_and_reshape_logits(backbone_feats, predictions["raw_logits"])
+            feature_vector = self.feature_vector_fn(backbone_feats)
+            saliency_map = self.explain_fn(raw_logits)
             xai_output = {
-                "feature_vector": torch.zeros(1, 1),
-                "saliency_map": torch.zeros(1),
+                "feature_vector": feature_vector,
+                "saliency_map": saliency_map,
             }
             results.update(xai_output)  # type: ignore[union-attr]
         return results
 
+    @staticmethod
+    def split_and_reshape_logits(
+        backbone_feats: tuple[Tensor, ...],
+        raw_logits: Tensor,
+    ) -> tuple[Tensor, ...]:
+        """Splits and reshapes raw logits for explain mode.
+
+        Args:
+            backbone_feats (tuple[Tensor, ...]): Tuple of backbone features.
+            raw_logits (Tensor): Raw logits.
+
+        Returns:
+            tuple[Tensor, ...]: The reshaped logits.
+        """
+        splits = [f.shape[-2] * f.shape[-1] for f in backbone_feats]
+        # Permute and split logits in one line
+        raw_logits = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)
+
+        # Reshape each split in a list comprehension
+        return tuple(
+            logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1]) for logits, f in zip(raw_logits, backbone_feats)
+        )
+
     def postprocess(
         self,
         outputs: dict[str, Tensor],
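To make split_and_reshape_logits concrete, here is a small shape walk-through mirroring the method body above. The batch size, class count, and feature sizes are arbitrary illustrative choices:

    import torch

    # Two feature levels: (B, C, H, W) = (2, 256, 4, 4) and (2, 256, 2, 2),
    # so the encoder produces 4*4 + 2*2 = 20 tokens in total.
    feats = (torch.randn(2, 256, 4, 4), torch.randn(2, 256, 2, 2))
    raw_logits = torch.randn(2, 20, 80)  # (batch, tokens, classes)

    splits = [f.shape[-2] * f.shape[-1] for f in feats]  # [16, 4]
    per_level = torch.split(raw_logits.permute(0, 2, 1), splits, dim=-1)
    maps = tuple(
        logits.reshape(f.shape[0], -1, f.shape[-2], f.shape[-1])
        for logits, f in zip(per_level, feats)
    )
    print([m.shape for m in maps])  # [torch.Size([2, 80, 4, 4]), torch.Size([2, 80, 2, 2])]

Each class channel is thus laid back out on the spatial grid of its pyramid level, which is what the saliency function consumes.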

src/otx/algo/detection/heads/dfine_decoder.py (27 additions, 4 deletions)

@@ -723,7 +723,7 @@ def _get_decoder_input(
             enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1)
             content = torch.concat([denoising_logits, content], dim=1)
 
-        return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list
+        return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list, enc_outputs_logits
 
     def _select_topk(
         self,

@@ -762,8 +762,22 @@ def _select_topk(
 
         return topk_memory, topk_logits, topk_anchors
 
-    def forward(self, feats: Tensor, targets: list[dict[str, Tensor]] | None = None) -> dict[str, Tensor]:
-        """Forward pass of the DFine Transformer module."""
+    def forward(
+        self,
+        feats: Tensor,
+        targets: list[dict[str, Tensor]] | None = None,
+        explain_mode: bool = False,
+    ) -> dict[str, Tensor]:
+        """Forward function of the D-FINE Decoder Transformer Module.
+
+        Args:
+            feats (Tensor): Feature maps.
+            targets (list[dict[str, Tensor]] | None, optional): Target annotations. Defaults to None.
+            explain_mode (bool, optional): Whether to return raw logits for explanation. Defaults to False.
+
+        Returns:
+            dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
+        """
         # input projection and embedding
         memory, spatial_shapes = self._get_encoder_input(feats)

@@ -781,7 +795,13 @@
         else:
             denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
 
-        init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = self._get_decoder_input(
+        (
+            init_ref_contents,
+            init_ref_points_unact,
+            enc_topk_bboxes_list,
+            enc_topk_logits_list,
+            raw_logits,
+        ) = self._get_decoder_input(
             memory,
             spatial_shapes,
             denoising_logits,

@@ -858,6 +878,9 @@
             "pred_boxes": out_bboxes[-1],
         }
 
+        if explain_mode:
+            out["raw_logits"] = raw_logits
+
         return out
 
     @torch.jit.unused
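The new argument is backward compatible: callers that omit explain_mode get exactly the output dictionary they got before, and the explain path only adds one key. A quick sketch of the contract (shapes follow the conventions used elsewhere in this commit):

    out = decoder(feats)                     # {"pred_logits", "pred_boxes", ...} as before
    out = decoder(feats, explain_mode=True)  # same dict, plus "raw_logits"
    raw = out["raw_logits"]                  # (batch, num_tokens, num_classes): encoder logits before top-k selection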

src/otx/algo/detection/heads/rtdetr_decoder.py (26 additions, 9 deletions)

@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
 """RTDETR decoder, modified from https://github.com/lyuwenyu/RT-DETR."""

@@ -546,10 +546,10 @@ def _get_decoder_input(
 
         output_memory = self.enc_output(memory)
 
-        enc_outputs_class = self.enc_score_head(output_memory)
+        enc_outputs_logits = self.enc_score_head(output_memory)
         enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
 
-        _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
+        _, topk_ind = torch.topk(enc_outputs_logits.max(-1).values, self.num_queries, dim=1)
 
         reference_points_unact = enc_outputs_coord_unact.gather(
             dim=1,

@@ -560,9 +560,9 @@
         if denoising_bbox_unact is not None:
             reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
 
-        enc_topk_logits = enc_outputs_class.gather(
+        enc_topk_logits = enc_outputs_logits.gather(
             dim=1,
-            index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]),
+            index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_logits.shape[-1]),
         )
 
         # extract region features

@@ -575,10 +575,24 @@
         if denoising_class is not None:
             target = torch.concat([denoising_class, target], 1)
 
-        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
+        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits, enc_outputs_logits
 
-    def forward(self, feats: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None) -> torch.Tensor:
-        """Forward pass of the RTDETRTransformer module."""
+    def forward(
+        self,
+        feats: torch.Tensor,
+        targets: list[dict[str, torch.Tensor]] | None = None,
+        explain_mode: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        """Forward function of RTDETRTransformer.
+
+        Args:
+            feats (Tensor): Input features.
+            targets (list[dict[str, Tensor]] | None): List of target dictionaries.
+            explain_mode (bool): Whether to return raw logits for explanation.
+
+        Returns:
+            dict[str, Tensor]: Output dictionary containing predicted logits, losses and boxes.
+        """
         # input projection and embedding
         (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)

@@ -596,7 +610,7 @@
         else:
             denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
 
-        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = self._get_decoder_input(
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits, raw_logits = self._get_decoder_input(
             memory,
             spatial_shapes,
             denoising_class,

@@ -630,6 +644,9 @@
         out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
         out["dn_meta"] = dn_meta
 
+        if explain_mode:
+            out["raw_logits"] = raw_logits
+
         return out
 
     @torch.jit.unused
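The renamed enc_outputs_logits feeds the same top-k query selection as before; the rename only clarifies that the full per-token logits (now also returned as raw_logits) are pre-selection. In isolation, that topk/gather pattern works like this standalone toy example (not OTX code; sizes are arbitrary):

    import torch

    logits = torch.randn(2, 100, 80)  # (batch, encoder tokens, classes)
    scores = logits.max(-1).values    # best class score per token
    _, topk_ind = torch.topk(scores, k=10, dim=1)
    picked = logits.gather(
        dim=1,
        index=topk_ind.unsqueeze(-1).repeat(1, 1, logits.shape[-1]),
    )
    print(picked.shape)  # torch.Size([2, 10, 80]): logits of the 10 strongest queries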

src/otx/algo/detection/rtdetr.py (57 additions, 1 deletion)

@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
 """RTDetr model implementations."""

@@ -128,6 +128,9 @@ def _customize_inputs(
             )
             targets.append({"boxes": scaled_bboxes, "labels": ll})
 
+        if self.explain_mode:
+            return {"entity": entity}
+
         return {
             "images": entity.images,
             "targets": targets,

@@ -156,6 +159,33 @@ def _customize_outputs(
         original_sizes = [img_info.ori_shape for img_info in inputs.imgs_info]
         scores, bboxes, labels = self.model.postprocess(outputs, original_sizes)
 
+        if self.explain_mode:
+            if not isinstance(outputs, dict):
+                msg = f"Model output should be a dict, but got {type(outputs)}."
+                raise ValueError(msg)
+
+            if "feature_vector" not in outputs:
+                msg = "No feature vector in the model output."
+                raise ValueError(msg)
+
+            if "saliency_map" not in outputs:
+                msg = "No saliency maps in the model output."
+                raise ValueError(msg)
+
+            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach().cpu().numpy()
+
+            return DetBatchPredEntity(
+                batch_size=len(outputs),
+                images=inputs.images,
+                imgs_info=inputs.imgs_info,
+                scores=scores,
+                bboxes=bboxes,
+                labels=labels,
+                feature_vector=feature_vector,
+                saliency_map=saliency_map,
+            )
+
         return DetBatchPredEntity(
             batch_size=len(outputs),
             images=inputs.images,

@@ -271,3 +301,29 @@ def _exporter(self) -> OTXModelExporter:
     def _optimization_config(self) -> dict[str, Any]:
         """PTQ config for RT-DETR."""
         return {"model_type": "transformer"}
+
+    @staticmethod
+    def _forward_explain_detection(
+        self,  # noqa: ANN001
+        entity: DetBatchDataEntity,
+        mode: str = "tensor",  # noqa: ARG004
+    ) -> dict[str, torch.Tensor]:
+        """Forward function for explainable detection model."""
+        backbone_feats = self.encoder(self.backbone(entity.images))
+        predictions = self.decoder(backbone_feats, explain_mode=True)
+
+        raw_logits = DETR.split_and_reshape_logits(
+            backbone_feats,
+            predictions["raw_logits"],
+        )
+
+        saliency_map = self.explain_fn(raw_logits)
+        feature_vector = self.feature_vector_fn(backbone_feats)
+        predictions.update(
+            {
+                "feature_vector": feature_vector,
+                "saliency_map": saliency_map,
+            },
+        )
+
+        return predictions
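The saliency_map returned in DetBatchPredEntity is a raw low-resolution map; the postprocess=True option shown in the docs handles resizing and colormapping inside OTX. For consumers who take the raw map instead, an equivalent overlay step might look like this sketch (OpenCV-based; the helper is an assumption for illustration, not part of this commit):

    import cv2
    import numpy as np

    def overlay_saliency(saliency: np.ndarray, image: np.ndarray, alpha: float = 0.5) -> np.ndarray:
        """Blend a (h, w) saliency map onto an (H, W, 3) BGR image."""
        s = saliency.astype(np.float32)
        s = (255 * (s - s.min()) / (s.max() - s.min() + 1e-6)).astype(np.uint8)  # normalize to 0..255
        s = cv2.resize(s, (image.shape[1], image.shape[0]))                       # match image size
        heatmap = cv2.applyColorMap(s, cv2.COLORMAP_JET)                          # red = high focus
        return cv2.addWeighted(heatmap, alpha, image, 1.0 - alpha, 0)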
