
Commit

Merge branch 'neuro-ml:dev' into dev
kirpoly authored Mar 22, 2024
2 parents 9f7e1a1 + f5a8e0b commit 069ab05
Showing 9 changed files with 60 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -3,7 +3,7 @@
[![codecov](https://codecov.io/gh/neuro-ml/thunder/branch/master/graph/badge.svg)](https://codecov.io/gh/neuro-ml/thunder)
[![pypi](https://img.shields.io/pypi/v/thunder?logo=pypi&label=PyPi)](https://pypi.org/project/thunder/)

> You saw the lightning. Now it's time to hear the thunder 🌩️
> _You saw the lightning. Now it's time to hear the thunder_ 🌩️
# Thunder 🌩️

2 changes: 1 addition & 1 deletion docs/callbacks/metric_monitor.md
@@ -145,7 +145,7 @@ flag. Being set to `True` it forces the callback to store table of metrics in th
| batch_idxn_m | some_value | some_value |

For each set (e.g. `val`, `test`) and each `dataloader_idx`, MetricMonitor stores separate table.
By default aforementioned tables are saved to `default_root_dir` of lightning's Trainer, in the format of
By default aforementioned tables are saved to `trainer.log_dir` in the format of
`set_name/dataloader_idx.csv` (e.g. `val/dataloader_0.csv`).
If loggers you use have method `log_table` (e.g. `WandbLogger`),
then this method will receive key and each table in the format of `pd.DataFrame`.
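For context, a minimal sketch (not part of this diff) of how the table logging described in the docs change above might be enabled; the import path, metric choice, and the commented-out model/data loaders are assumptions:

```python
from lightning import Trainer
from sklearn.metrics import accuracy_score

from thunder.callbacks import MetricMonitor  # import path assumed

# log_individual_metrics=True asks the callback to store the per-case tables
# described above and to pass them to any logger exposing `log_table`.
monitor = MetricMonitor(
    single_metrics={"accuracy": accuracy_score},
    log_individual_metrics=True,
)

trainer = Trainer(callbacks=[monitor], max_epochs=1)
# trainer.fit(model, train_loader, val_loader)  # model and loaders defined elsewhere

# After evaluation, files such as val/dataloader_0.csv appear under trainer.log_dir.
```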
2 changes: 1 addition & 1 deletion docs/index.md
@@ -1,4 +1,4 @@
> You saw the lightning. Now it's time to hear the thunder 🌩️
> _You saw the lightning. Now it's time to hear the thunder_ 🌩️
# Thunder

3 changes: 2 additions & 1 deletion tests/assets/dumb.config
@@ -1,13 +1,14 @@
import torch
from lightning import Trainer
from lightning.pytorch.demos.boring_classes import RandomDataset
from torch import nn
from torch.utils.data import DataLoader

import torch
from thunder import ThunderModule
from thunder.layout import Single

train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=1)

layout = Single()

34 changes: 34 additions & 0 deletions tests/callbacks/test_metric_monitor.py
@@ -461,3 +461,37 @@ def test_log_table(tmpdir):
for p in Path(root_dir).glob("*/dataloader_*"):
assert str(p.relative_to(root_dir)) not in ["val/dataloader_0", "val/dataloader_1",
"test/dataloader_0", "test/dataloader_1"]


def test_empty_identity():
"""
    If all group metrics are specified with preprocessing,
there should be no identity preprocessing
"""

from thunder.callbacks.metric_monitor import _identity

def preproc1(y, x):
return y * 2, x

def preproc2(y, x):
return y, x * 2

group_metrics = {preproc1: accuracy_score, preproc2: {
"accuracy2": accuracy_score,
"accuracy3": accuracy_score,
}, "accuracy4": accuracy_score}

metric_monitor = MetricMonitor(group_metrics=group_metrics)

assert sorted(metric_monitor.group_metrics.keys()) == \
sorted(["accuracy_score", "accuracy2", "accuracy3", "accuracy4"])
assert list(metric_monitor.group_preprocess.keys()) == [preproc1, preproc2, _identity]

group_metrics.pop("accuracy4")
metric_monitor = MetricMonitor(group_metrics=group_metrics)

assert sorted(metric_monitor.group_metrics.keys()) == \
sorted(["accuracy_score", "accuracy2", "accuracy3"])
assert list(metric_monitor.group_preprocess.keys()) == \
[preproc1, preproc2], len(metric_monitor.group_preprocess.keys())
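For reference, the dictionary shape exercised by this test mirrors how `group_metrics` with preprocessing might be passed to the callback in practice; a hedged usage sketch (import path, argument order of the preprocessing function, and metric names are assumptions):

```python
from sklearn.metrics import accuracy_score, f1_score

from thunder.callbacks import MetricMonitor  # import path assumed


def binarize(preds, targets):
    # Hypothetical preprocessing; the argument order is an assumption.
    return preds > 0.5, targets


monitor = MetricMonitor(
    group_metrics={
        binarize: {"accuracy": accuracy_score, "f1": f1_score},  # metrics sharing a preprocess
        "raw_accuracy": accuracy_score,                          # no preprocessing -> _identity
    }
)
```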
22 changes: 13 additions & 9 deletions thunder/callbacks/metric_monitor.py
@@ -14,7 +14,7 @@
from toolz import compose, keymap, valmap

from ..torch.utils import to_np
from ..utils import collect, squeeze_first
from ..utils import squeeze_first


class MetricMonitor(Callback):
@@ -33,14 +33,17 @@ def __init__(
group_metrics: Dict
Metrics that are calculated on entire dataset.
aggregate_fn: Union[Dict[str, Callable], str, Callable, List[Union[str, Callable]]]
How to aggregate metrics. By default it computes mean value. If yoy specify something,
How to aggregate metrics. By default, it computes mean value. If you specify something,
then the callback will compute mean and the specified values.
log_individual_metrics: bool
If True, logs table for case-wise metrics (if logger has `log_table` method) and saves table to csv file.
"""
_single_metrics = dict(single_metrics or {})
_group_metrics = dict(group_metrics or {})

# metrics = {"metric_name": func}
# preprocess = {preprocess_func: ["metric_name"]} + {identity: ["metric_name"]}

single_metrics, single_preprocess = _process_metrics(_single_metrics)
group_metrics, group_preprocess = _process_metrics(_group_metrics)

@@ -89,11 +89,11 @@ def __init__(
elif isinstance(aggregate_fn, dict):
not_callable = dict(filter(lambda it: not callable(it[1]), aggregate_fn.items()))
if not_callable:
raise TypeError(f"All aggregators must be callable if you pass a dict, got uncallable {not_callable}")
raise TypeError(f"All aggregators must be callable if you pass a dict, got not callable {not_callable}")
self.aggregate_fn.update(aggregate_fn)
else:
if aggregate_fn is not None:
raise ValueError(f"Unknown type of aggrefate_fn: {type(aggregate_fn)}")
raise ValueError(f"Unknown type of aggregate_fn: {type(aggregate_fn)}")

def on_train_batch_end(
self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
@@ -128,7 +131,7 @@ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> No
for k, vs in group.items():
pl_module.log(f'train/{k}', np.mean(vs))

self._train_losses = []
self._train_losses.clear()

def on_validation_batch_end(
self,
@@ -209,7 +212,7 @@ def evaluate_epoch(self, trainer: Trainer, pl_module: LightningModule, key: str)

if self.log_individual_metrics:
dataframe = pd.DataFrame(metrics)
root_dir = Path(trainer.default_root_dir) / key # trainer.log_dir / key ?
root_dir = Path(trainer.log_dir) / key
root_dir.mkdir(exist_ok=True)
for logger in pl_module.loggers:
if hasattr(logger, "log_table"):
@@ -252,9 +255,8 @@ def _identity(*args):
return squeeze_first(args)


@collect
def _recombine_batch(xs: Sequence) -> List:
yield from map(squeeze_first, zip_equal(*xs))
return [squeeze_first(x) for x in zip_equal(*xs)]


def _process_metrics(raw_metrics: Dict) -> Tuple[Dict[str, Callable], Dict[Callable, List[str]]]:
@@ -285,5 +287,7 @@ def _process_metrics(raw_metrics: Dict) -> Tuple[Dict[str, Callable], Dict[Calla
else:
raise TypeError(f"Metric keys should be of type str or Callable, got {type(k)}")

preprocess[_identity] = sorted(set(processed_metrics.keys()) - set(chain.from_iterable(preprocess.values())))
identity_preprocess_metrics = sorted(set(processed_metrics.keys()) - set(chain.from_iterable(preprocess.values())))
if identity_preprocess_metrics:
preprocess[_identity] = identity_preprocess_metrics
return processed_metrics, preprocess
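The `aggregate_fn` parameter documented in the hunks above also accepts a dict of named callables computed in addition to the default mean; a brief sketch (import path assumed):

```python
import numpy as np
from sklearn.metrics import accuracy_score

from thunder.callbacks import MetricMonitor  # import path assumed

# The mean is always computed; these named aggregators are added on top of it.
monitor = MetricMonitor(
    single_metrics={"accuracy": accuracy_score},
    aggregate_fn={"std": np.std, "p90": lambda values: np.percentile(values, 90)},
)

# A non-callable value in the dict would raise the TypeError shown in the diff above.
```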
3 changes: 2 additions & 1 deletion thunder/callbacks/time_profiler.py
@@ -62,7 +62,8 @@ def log_batch_size(self, batch, key: str) -> None:
def compute_time_delta(self) -> Dict[str, float]:
deltas = {}
for key, time_stamps in self.time_stamps.items():
deltas[key] = [(t[1] - t[0]).total_seconds() for t in windowed(time_stamps, 2, step=2)]
deltas[key] = [(t[1] - t[0]).total_seconds() for t in windowed(time_stamps, 2, step=2,
fillvalue=time_stamps[-1])]
deltas[key] = sum(deltas[key]) / len(deltas[key])

if "train epoch" in deltas:
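For context, the `fillvalue` added above pads the final window when the number of recorded timestamps is odd, so the delta computation no longer receives `None`; a small illustration with made-up timestamps:

```python
from datetime import datetime, timedelta

from more_itertools import windowed

start = datetime(2024, 3, 22, 12, 0, 0)
time_stamps = [start, start + timedelta(seconds=3), start + timedelta(seconds=10)]  # odd length

# Without fillvalue, the last window would be (time_stamps[-1], None) and the subtraction would fail.
pairs = windowed(time_stamps, 2, step=2, fillvalue=time_stamps[-1])
deltas = [(b - a).total_seconds() for a, b in pairs]
print(deltas)  # [3.0, 0.0]
```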
6 changes: 5 additions & 1 deletion thunder/cli/main.py
@@ -157,7 +157,11 @@ def run(
if names is not None:
names = names.split(',')
backend, config = BackendCommand.get_backend(backend, kwargs)
backend.run(config, Path(experiment).absolute(), get_nodes(experiment, names))
experiment = Path(experiment).absolute()
if not experiment.exists():
raise ValueError(f"Trying to run experiment from folder {experiment}, "
"but it does not exist.")
backend.run(config, experiment, get_nodes(experiment, names))


@app.command(cls=BackendCommand)
2 changes: 1 addition & 1 deletion thunder/torch/core.py
@@ -97,7 +97,7 @@ def configure_optimizers(self) -> Tuple[List[Optimizer], List[LRScheduler]]:

_optimizers = list(collapse([self.optimizer]))
_lr_schedulers = list(collapse([self.lr_scheduler]))
max_len = max(map(len, (_optimizers, _lr_schedulers)))
max_len = max(len(_optimizers), len(_lr_schedulers))
_optimizers = list(padded(_optimizers, None, max_len))
_lr_schedulers = list(padded(_lr_schedulers, None, max_len))

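For reference, a small illustration (with placeholder strings standing in for optimizers and schedulers) of how `collapse` and `padded` from `more_itertools` align the two lists in the hunk above:

```python
from more_itertools import collapse, padded

optimizers = list(collapse([["opt_a", "opt_b"]]))  # flatten possible nesting -> ["opt_a", "opt_b"]
schedulers = list(collapse([None]))                # a single missing scheduler -> [None]

# Pad both lists to the same length so they can be zipped pairwise.
max_len = max(len(optimizers), len(schedulers))
optimizers = list(padded(optimizers, None, max_len))  # ["opt_a", "opt_b"]
schedulers = list(padded(schedulers, None, max_len))  # [None, None]
```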
