From 3e9ccecf3d93f32e56ea0bad318681883ce3b0d1 Mon Sep 17 00:00:00 2001 From: Jongwook Choi Date: Thu, 9 Nov 2023 21:55:34 -0500 Subject: [PATCH] Experiment.summary() should show all columns when summary_columns=None --- expt/data.py | 14 ++++++++++---- expt/data_test.py | 6 ++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/expt/data.py b/expt/data.py index cc44e3f..9ad187e 100644 --- a/expt/data.py +++ b/expt/data.py @@ -886,18 +886,24 @@ def _df(self) -> pd.DataFrame: }, }) - if self._summary_columns: - # TODO: h.summary is expensive and slow, cache it + def _append_summary(summary_columns): + nonlocal df df = pd.concat([ df, pd.DataFrame({ k: [ - h.summary(columns=self._summary_columns).loc[0, k] + # TODO: h.summary is expensive and slow, cache it + h.summary(columns=summary_columns).loc[0, k] for h in self._hypotheses.values() - ] for k in self._summary_columns + ] for k in summary_columns }), ], axis=1) # yapf: disable + if self._summary_columns is not None: + _append_summary(summary_columns=self._summary_columns) + else: + _append_summary(summary_columns=self.columns) + # Need to sort index w.r.t the multi-index level hierarchy, because # the order of hypotheses being added is not guaranteed df = df.set_index([*self._config_keys, 'name']).sort_index() diff --git a/expt/data_test.py b/expt/data_test.py index 195d5ff..e655bab 100644 --- a/expt/data_test.py +++ b/expt/data_test.py @@ -566,6 +566,9 @@ def test_create_simple(self): assert len(ex.hypotheses) == 1 assert ex.name == "one_hypo" + assert ex._summary_columns is None + assert list(ex._df.columns) == ["hypothesis", "x", "y", "z"] + def test_create_from_dataframe_run(self, runs_gridsearch: RunList): """Tests Experiment.from_dataframe with the minimal defaults.""" @@ -601,6 +604,9 @@ def test_create_from_dataframe_run_multicolumns(self, assert list(ex._df.index.names) == ['algo', 'name'] # Note the order # yapf: enable + # All other columns than by = "algo" + assert ex._summary_columns == ("env_id", "seed") + # by: not exists? with pytest.raises(KeyError): # TODO: Improve exception ex = Experiment.from_dataframe(df, by="unknown", name="Exp.foobar")