From cff4ad64832a21693f51aec7b44ad855fd7f7e85 Mon Sep 17 00:00:00 2001 From: Jongwook Choi Date: Tue, 7 Nov 2023 00:59:30 -0500 Subject: [PATCH] fix: Experiment's df representation should use the names of Hypotheses --- expt/data.py | 8 +++++--- expt/data_test.py | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/expt/data.py b/expt/data.py index 1eaee6d..84beb24 100644 --- a/expt/data.py +++ b/expt/data.py @@ -866,11 +866,13 @@ def _replace(self, **kwargs) -> Experiment: @property def _df(self) -> pd.DataFrame: + hypotheses: List[Hypothesis] = list(self._hypotheses.values()) + df = pd.DataFrame({ - 'name': list(self._hypotheses.keys()), - 'hypothesis': list(self._hypotheses.values()), + 'name': [h.name for h in hypotheses], + 'hypothesis': hypotheses, **{ # config keys (will be index) - k: [(h.config or {}).get(k) for h in self._hypotheses.values()] + k: [(h.config or {}).get(k) for h in hypotheses] for k in self._config_keys }, }) diff --git a/expt/data_test.py b/expt/data_test.py index ee0bf6b..195d5ff 100644 --- a/expt/data_test.py +++ b/expt/data_test.py @@ -1,5 +1,6 @@ """Tests for expt.data""" # pylint: disable=redefined-outer-name +# pylint: disable=protected-access import itertools import re @@ -679,9 +680,21 @@ def _validate_ex_gridsearch(self, ex: Experiment): ] assert [ex.hypotheses[i].name for i in range(6)] == hypothesis_names + # Use custom name for some hypotheses + for h in ex.hypotheses: + assert h.config is not None + h.name = f"{h.config['algo']} ({h.config['env_id']})" + # Because df already has a multi-index, so should ex. assert ex._df.index.names == ['algo', 'env_id', 'name'] # Note the order - assert list(ex._df.index.get_level_values('name')) == hypothesis_names + assert list(ex._df.index.get_level_values('name')) == [ + 'ppo (halfcheetah)', + 'ppo (hopper)', + 'ppo (humanoid)', + 'sac (halfcheetah)', + 'sac (hopper)', + 'sac (humanoid)', + ] assert list(ex._df.index.get_level_values('algo')) == ( # ... ['ppo'] * 3 + ['sac'] * 3) assert list(ex._df.index.get_level_values('env_id')) == (