Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
Arseniy Belkov committed Mar 22, 2024
1 parent 06d4cdc commit f5a8e0b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
16 changes: 11 additions & 5 deletions tests/callbacks/test_metric_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,11 @@ def test_empty_identity():

from thunder.callbacks.metric_monitor import _identity

preproc1 = lambda y, x: (y * 2, x)
preproc2 = lambda y, x: (y, x * 2)
def preproc1(y, x):
return y * 2, x

def preproc2(y, x):
return y, x * 2

group_metrics = {preproc1: accuracy_score, preproc2: {
"accuracy2": accuracy_score,
Expand All @@ -481,11 +484,14 @@ def test_empty_identity():

metric_monitor = MetricMonitor(group_metrics=group_metrics)

assert sorted(metric_monitor.group_metrics.keys()) == sorted(["accuracy_score", "accuracy2", "accuracy3", "accuracy4"])
assert sorted(metric_monitor.group_metrics.keys()) == \
sorted(["accuracy_score", "accuracy2", "accuracy3", "accuracy4"])
assert list(metric_monitor.group_preprocess.keys()) == [preproc1, preproc2, _identity]

group_metrics.pop("accuracy4")
metric_monitor = MetricMonitor(group_metrics=group_metrics)

assert sorted(metric_monitor.group_metrics.keys()) == sorted(["accuracy_score", "accuracy2", "accuracy3"])
assert list(metric_monitor.group_preprocess.keys()) == [preproc1, preproc2], len(metric_monitor.group_preprocess.keys())
assert sorted(metric_monitor.group_metrics.keys()) == \
sorted(["accuracy_score", "accuracy2", "accuracy3"])
assert list(metric_monitor.group_preprocess.keys()) == \
[preproc1, preproc2], len(metric_monitor.group_preprocess.keys())
2 changes: 1 addition & 1 deletion thunder/callbacks/metric_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from toolz import compose, keymap, valmap

from ..torch.utils import to_np
from ..utils import collect, squeeze_first
from ..utils import squeeze_first


class MetricMonitor(Callback):
Expand Down

0 comments on commit f5a8e0b

Please sign in to comment.