Skip to content

Commit

Permalink
MNT Use regression data for check_sample_weight_invariance test on …
Browse files Browse the repository at this point in the history
…multioutput regression metrics (scikit-learn#30829)
  • Loading branch information
lucyleeow authored Feb 14, 2025
1 parent 2b97ac5 commit ebc1276
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,7 +1611,7 @@ def test_multiclass_sample_weight_invariance(name):
@pytest.mark.parametrize(
"name",
sorted(
(MULTILABELS_METRICS | THRESHOLDED_MULTILABEL_METRICS | MULTIOUTPUT_METRICS)
(MULTILABELS_METRICS | THRESHOLDED_MULTILABEL_METRICS)
- METRICS_WITHOUT_SAMPLE_WEIGHT
),
)
Expand All @@ -1638,6 +1638,19 @@ def test_multilabel_sample_weight_invariance(name):
check_sample_weight_invariance(name, metric, y_true, y_pred)


@pytest.mark.parametrize(
"name",
sorted(MULTIOUTPUT_METRICS - METRICS_WITHOUT_SAMPLE_WEIGHT),
)
def test_multioutput_sample_weight_invariance(name):
random_state = check_random_state(0)
y_true = random_state.uniform(0, 2, size=(20, 5))
y_pred = random_state.uniform(0, 2, size=(20, 5))

metric = ALL_METRICS[name]
check_sample_weight_invariance(name, metric, y_true, y_pred)


def test_no_averaging_labels():
# test labels argument when not using averaging
# in multi-class and multi-label cases
Expand Down

0 comments on commit ebc1276

Please sign in to comment.