Skip to content

Commit

Permalink
Fix size attribute error for scalar outputs in precision/recall/f1 me…
Browse files Browse the repository at this point in the history
…trics
  • Loading branch information
Maxwell-Jia committed Dec 22, 2024
1 parent 55f1bc6 commit ab599cc
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion metrics/f1/f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ def _compute(self, predictions, references, labels=None, pos_label=1, average="b
score = f1_score(
references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight
)
return {"f1": float(score) if score.size == 1 else score}
return {"f1": score if getattr(score, 'size', 1) > 1 else float(score)}
2 changes: 1 addition & 1 deletion metrics/precision/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,4 @@ def _compute(
sample_weight=sample_weight,
zero_division=zero_division,
)
return {"precision": float(score) if score.size == 1 else score}
return {"precision": score if getattr(score, 'size', 1) > 1 else float(score)}
2 changes: 1 addition & 1 deletion metrics/recall/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@ def _compute(
sample_weight=sample_weight,
zero_division=zero_division,
)
return {"recall": float(score) if score.size == 1 else score}
return {"recall": score if getattr(score, 'size', 1) > 1 else float(score)}

0 comments on commit ab599cc

Please sign in to comment.