Skip to content

Commit

Permalink
fix: Experiment's df representation should use the names of Hypotheses
Browse files Browse the repository at this point in the history
  • Loading branch information
wookayin committed Nov 7, 2023
1 parent 9bca325 commit cff4ad6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
8 changes: 5 additions & 3 deletions expt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,11 +866,13 @@ def _replace(self, **kwargs) -> Experiment:

@property
def _df(self) -> pd.DataFrame:
hypotheses: List[Hypothesis] = list(self._hypotheses.values())

df = pd.DataFrame({
'name': list(self._hypotheses.keys()),
'hypothesis': list(self._hypotheses.values()),
'name': [h.name for h in hypotheses],
'hypothesis': hypotheses,
**{ # config keys (will be index)
k: [(h.config or {}).get(k) for h in self._hypotheses.values()]
k: [(h.config or {}).get(k) for h in hypotheses]
for k in self._config_keys
},
})
Expand Down
15 changes: 14 additions & 1 deletion expt/data_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for expt.data"""
# pylint: disable=redefined-outer-name
# pylint: disable=protected-access

import itertools
import re
Expand Down Expand Up @@ -679,9 +680,21 @@ def _validate_ex_gridsearch(self, ex: Experiment):
]
assert [ex.hypotheses[i].name for i in range(6)] == hypothesis_names

# Use custom name for some hypotheses
for h in ex.hypotheses:
assert h.config is not None
h.name = f"{h.config['algo']} ({h.config['env_id']})"

# Because df already has a multi-index, so should ex.
assert ex._df.index.names == ['algo', 'env_id', 'name'] # Note the order
assert list(ex._df.index.get_level_values('name')) == hypothesis_names
assert list(ex._df.index.get_level_values('name')) == [
'ppo (halfcheetah)',
'ppo (hopper)',
'ppo (humanoid)',
'sac (halfcheetah)',
'sac (hopper)',
'sac (humanoid)',
]
assert list(ex._df.index.get_level_values('algo')) == ( # ...
['ppo'] * 3 + ['sac'] * 3)
assert list(ex._df.index.get_level_values('env_id')) == (
Expand Down

0 comments on commit cff4ad6

Please sign in to comment.