Skip to content

Commit 7b0c907

Browse files
authored
Fixed a bug in calculating the metric for Binary labels
1 parent 7f2ca0a commit 7b0c907

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

src/anomalib/models/components/base/export_mixin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def val_fn(nncf_model: "CompiledModel", validation_data: Iterable) -> float:
393393
setattr(batch, name, torch.from_numpy(pred))
394394
if batch.gt_mask is not None:
395395
batch.gt_mask = batch.gt_mask.unsqueeze(dim=1)
396+
batch.pred_score = batch.pred_score.squeeze(dim=1) # Squeezing since it is binary. (B, 1) -> (B)
396397
metric.update(batch)
397398
return metric.compute()
398399

0 commit comments

Comments
 (0)