diff --git a/lmms_eval/logging_utils.py b/lmms_eval/logging_utils.py index e4e2947e9..800dfcd1c 100644 --- a/lmms_eval/logging_utils.py +++ b/lmms_eval/logging_utils.py @@ -192,9 +192,15 @@ def make_table(columns: List[str], key: str = "results"): se = dic[m + "_stderr" + "," + f] if se != "N/A": se = "%.4f" % se - table.add_data(*[model_name, model_args, k, version, f, n, m, str(v), str(se)]) + data = [model_name, model_args, k, version, f, n, m, str(v), str(se)] + if key == "groups": + data = [self.group_names] + data + table.add_data(*data) else: - table.add_data(*[model_name, model_args, k, version, f, n, m, str(v), ""]) + data = [model_name, model_args, k, version, f, n, m, str(v), ""] + if key == "groups": + data = [self.group_names] + data + table.add_data(*data) return table