Skip to content

Commit fefc964

Browse files
authored
[Wandb Logger] add model and args to wandb tables. (EvolvingLMMs-Lab#55)
* Refactor logging in lmms_eval package * Refactor variable names in lmms_eval package
1 parent 7155c41 commit fefc964

File tree

6 files changed

+21
-10
lines changed

6 files changed

+21
-10
lines changed

lmms_eval/__main__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
201201
results_list.append(results)
202202

203203
accelerator.wait_for_everyone()
204-
if is_main_process:
204+
if is_main_process and args.wandb_args:
205205
wandb_logger.post_init(results)
206206
wandb_logger.log_eval_result()
207207
if args.wandb_log_samples and samples is not None:
208208
wandb_logger.log_eval_samples(samples)
209209

210+
wandb_logger.finish()
211+
210212
except Exception as e:
211213
traceback.print_exc()
212214
eval_logger.error(f"Error during evaluation: {e}")
@@ -312,7 +314,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
312314
for task_name, config in results["configs"].items():
313315
filename = args.output_path.joinpath(f"{task_name}.json")
314316
# Structure the data with 'args' and 'logs' keys
315-
data_to_dump = {"args": vars(args), "config": config, "logs": sorted(samples[task_name], key=lambda x: x["doc_id"])} # Convert Namespace to dict
317+
data_to_dump = {"args": vars(args), "model_configs": config, "logs": sorted(samples[task_name], key=lambda x: x["doc_id"])} # Convert Namespace to dict
316318
samples_dumped = json.dumps(data_to_dump, indent=4, default=_handle_non_serializable)
317319
filename.open("w").write(samples_dumped)
318320
eval_logger.info(f"Saved samples to {filename}")

lmms_eval/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def simple_evaluate(
134134

135135
if lm.rank == 0:
136136
# add info about the model and few shot config
137-
results["config"] = {
137+
results["model_configs"] = {
138138
"model": model if isinstance(model, str) else model.model.config._name_or_path,
139139
"model_args": model_args,
140140
"batch_size": batch_size,

lmms_eval/logging_utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def __init__(self, args):
8282
os.environ["WANDB_MODE"] = "offline"
8383
self.init_run()
8484

85+
def finish(self):
86+
self.run.finish()
87+
8588
@tenacity.retry(wait=tenacity.wait_fixed(5), stop=tenacity.stop_after_attempt(5))
8689
def init_run(self):
8790
if "name" not in self.wandb_args:
@@ -152,6 +155,9 @@ def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
152155
def _log_results_as_table(self) -> None:
153156
"""Generate and log evaluation results as a table to W&B."""
154157
columns = [
158+
"Model",
159+
"Args",
160+
"Tasks",
155161
"Version",
156162
"Filter",
157163
"num_fewshot",
@@ -164,6 +170,9 @@ def make_table(columns: List[str], key: str = "results"):
164170
table = wandb.Table(columns=columns)
165171
results = copy.deepcopy(self.results)
166172

173+
model_name = results.get("model_configs").get("model")
174+
model_args = results.get("model_configs").get("model_args")
175+
167176
for k, dic in results.get(key).items():
168177
if k in self.group_names and not key == "groups":
169178
continue
@@ -183,14 +192,14 @@ def make_table(columns: List[str], key: str = "results"):
183192
se = dic[m + "_stderr" + "," + f]
184193
if se != "N/A":
185194
se = "%.4f" % se
186-
table.add_data(*[k, version, f, n, m, str(v), str(se)])
195+
table.add_data(*[model_name, model_args, k, version, f, n, m, str(v), str(se)])
187196
else:
188-
table.add_data(*[k, version, f, n, m, str(v), ""])
197+
table.add_data(*[model_name, model_args, k, version, f, n, m, str(v), ""])
189198

190199
return table
191200

192201
# log the complete eval result to W&B Table
193-
table = make_table(["Tasks"] + columns, "results")
202+
table = make_table(columns, "results")
194203
self.run.log({"evaluation/eval_results": table})
195204

196205
if "groups" in self.results.keys():
@@ -209,7 +218,7 @@ def log_eval_result(self) -> None:
209218
"""Log evaluation results to W&B."""
210219
# Log configs to wandb
211220
configs = self._get_config()
212-
self.run.config.update(configs)
221+
self.run.config.update(configs, allow_val_change=True)
213222

214223
wandb_summary, self.wandb_results = self._sanitize_results_dict()
215224
# update wandb.run.summary with items that were removed

lmms_eval/tasks/internal_eval/d170_cn.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ dataset_kwargs:
44
task: "d170_cn"
55
test_split: test
66
output_type: generate_until
7-
doc_to_visual: !function utils.doc_to_visual
87
doc_to_text: !function utils.doc_to_text # Such that {{prompt}} will be replaced by doc["question"]
8+
doc_to_visual: !function d170_cn_utils.doc_to_visual
99
doc_to_target: "{{annotation}}"
1010
generation_kwargs:
1111
until:

lmms_eval/tasks/mme/mme_test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ generation_kwargs:
1414
num_beams: 1
1515
do_sample: false
1616
# The return value of process_results will be used by metrics
17-
process_results: !function utils.mme_process_result
17+
process_results: !function utils.mme_process_results
1818
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
1919
metric_list:
2020
- metric: mme_percetion_score

lmms_eval/tasks/mme/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def parse_pred_ans(pred_ans):
6767
return pred_label
6868

6969

70-
def mme_process_result(doc, results):
70+
def mme_process_results(doc, results):
7171
"""
7272
Args:
7373
doc: a instance of the eval dataset

0 commit comments

Comments
 (0)