Skip to content

Commit

Permalink
fix group metrics memory leak (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
Arseniy Belkov committed Mar 22, 2024
1 parent dce8c0f commit 06d4cdc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
28 changes: 28 additions & 0 deletions tests/callbacks/test_metric_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,31 @@ def test_log_table(tmpdir):
for p in Path(root_dir).glob("*/dataloader_*"):
assert str(p.relative_to(root_dir)) not in ["val/dataloader_0", "val/dataloader_1",
"test/dataloader_0", "test/dataloader_1"]


def test_empty_identity():
"""
If all group metrics are specified with prerprocessing,
there should be no identity preprocessing
"""

from thunder.callbacks.metric_monitor import _identity

preproc1 = lambda y, x: (y * 2, x)
preproc2 = lambda y, x: (y, x * 2)

group_metrics = {preproc1: accuracy_score, preproc2: {
"accuracy2": accuracy_score,
"accuracy3": accuracy_score,
}, "accuracy4": accuracy_score}

metric_monitor = MetricMonitor(group_metrics=group_metrics)

assert sorted(metric_monitor.group_metrics.keys()) == sorted(["accuracy_score", "accuracy2", "accuracy3", "accuracy4"])
assert list(metric_monitor.group_preprocess.keys()) == [preproc1, preproc2, _identity]

group_metrics.pop("accuracy4")
metric_monitor = MetricMonitor(group_metrics=group_metrics)

assert sorted(metric_monitor.group_metrics.keys()) == sorted(["accuracy_score", "accuracy2", "accuracy3"])
assert list(metric_monitor.group_preprocess.keys()) == [preproc1, preproc2], len(metric_monitor.group_preprocess.keys())
10 changes: 7 additions & 3 deletions thunder/callbacks/metric_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def __init__(
_single_metrics = dict(single_metrics or {})
_group_metrics = dict(group_metrics or {})

# metrics = {"metric_name": func}
# preprocess = {preprocess_func: ["metric_name"]} + {identity: ["metric_name"]}

single_metrics, single_preprocess = _process_metrics(_single_metrics)
group_metrics, group_preprocess = _process_metrics(_group_metrics)

Expand Down Expand Up @@ -252,9 +255,8 @@ def _identity(*args):
return squeeze_first(args)


@collect
def _recombine_batch(xs: Sequence) -> List:
yield from map(squeeze_first, zip_equal(*xs))
return [squeeze_first(x) for x in zip_equal(*xs)]


def _process_metrics(raw_metrics: Dict) -> Tuple[Dict[str, Callable], Dict[Callable, List[str]]]:
Expand Down Expand Up @@ -285,5 +287,7 @@ def _process_metrics(raw_metrics: Dict) -> Tuple[Dict[str, Callable], Dict[Calla
else:
raise TypeError(f"Metric keys should be of type str or Callable, got {type(k)}")

preprocess[_identity] = sorted(set(processed_metrics.keys()) - set(chain.from_iterable(preprocess.values())))
identity_preprocess_metrics = sorted(set(processed_metrics.keys()) - set(chain.from_iterable(preprocess.values())))
if identity_preprocess_metrics:
preprocess[_identity] = identity_preprocess_metrics
return processed_metrics, preprocess

0 comments on commit 06d4cdc

Please sign in to comment.