update lmms-eval feature for middle checkpoint eval
choiszt committed Sep 3, 2024
1 parent 45569f0 commit 0c14015
Showing 5 changed files with 32 additions and 16 deletions.
1 change: 1 addition & 0 deletions lmms_eval/__init__.py
@@ -0,0 +1 @@
from .evaluator import evaluate, simple_evaluate
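
With this re-export, training-side code can import the evaluation entry points straight from the package, which is presumably how the middle-checkpoint evaluation hook is wired up. A minimal sketch of the import (the calling code is hypothetical, not part of this commit):

    from lmms_eval import evaluate, simple_evaluate

    # A training loop can now call simple_evaluate(...) directly on an in-memory checkpoint;
    # the exact arguments are those of simple_evaluate in lmms_eval/evaluator.py (see the diff below).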
19 changes: 15 additions & 4 deletions lmms_eval/evaluator.py
@@ -44,6 +44,7 @@

@positional_deprecated
def simple_evaluate(
tuned_model,
model,
model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
@@ -250,7 +251,7 @@ def _adjust_config(task_dict):
verbosity=verbosity,
cli_args=cli_args,
)

print(lm.rank)
if lm.rank == 0:
if isinstance(model, str):
model_name = model
@@ -572,6 +573,10 @@ def evaluate(
if os.path.exists(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt"):
os.remove(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt")


if not cli_args.output_path.exists():
cli_args.output_path.mkdir(parents=True, exist_ok=True)

if lm.rank == 0:
### Get task ordering for correct sample-wide aggregation
group_to_task = {}
@@ -772,16 +777,22 @@ def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
}
if log_samples:
results_dict["samples"] = dict(samples)

with open(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt", "w") as f:
f.write(f"rank {int(os.environ.get('RANK', 0))} eval done")
return results_dict

else:
results_dict = None


with open(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt", "w") as f:
f.write(f"rank {int(os.environ.get('RANK', 0))} eval done")
while len([file for file in os.listdir(cli_args.output_path) if file.endswith("metric_eval_done.txt")]) < lm._world_size:
time.sleep(1)

lm.accelerator.wait_for_everyone()
return results_dict
else:
return None


def request_caching_arg_to_dict(cache_requests: str) -> dict:
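The evaluator change above coordinates ranks through marker files rather than relying only on a collective: each rank writes a rank{N}_metric_eval_done.txt file once its metrics are computed, every process spins until all world_size marker files exist, and only then does it hit accelerator.wait_for_everyone(). A standalone sketch of that file-based barrier, assuming output_path, rank, and world_size are supplied by the caller (placeholder names, not the exact lmms-eval variables):

    import os
    import time

    def file_barrier(output_path: str, rank: int, world_size: int, poll_seconds: float = 1.0) -> None:
        # Drop this rank's marker file once its metric evaluation is done.
        done_file = os.path.join(output_path, f"rank{rank}_metric_eval_done.txt")
        with open(done_file, "w") as f:
            f.write(f"rank {rank} eval done")
        # Spin until every rank's marker file is present, mirroring the while loop in evaluate().
        while len([name for name in os.listdir(output_path) if name.endswith("metric_eval_done.txt")]) < world_size:
            time.sleep(poll_seconds)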
25 changes: 14 additions & 11 deletions lmms_eval/models/llava_onevision.py
@@ -176,7 +176,8 @@ def __init__(
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._rank = self.accelerator.process_index
# print(self.accelerator.local_process_index)
self._world_size = self.accelerator.num_processes

elif accelerator.num_processes == 1 and device_map == "auto":
@@ -447,11 +448,12 @@ def _collate(x):
placeholder_count = 1

elif type(visual[0]) == PIL.Image.Image: # For image, multi-image tasks
image_tensor = process_images(visual, self._image_processor, self._config)
if type(image_tensor) is list:
image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
else:
image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
# image_tensor = process_images(visual, self._image_processor, self._config)
# if type(image_tensor) is list:
# image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
# else:
# image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
image_tensor = None

task_type = "image"
placeholder_count = len(visual) if isinstance(visual, list) else 1
@@ -537,7 +539,7 @@ def _collate(x):
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
gen_kwargs["modalities"] = ["video"]
gen_kwargs["stopping_criteria"] = [stopping_criteria]
# gen_kwargs["stopping_criteria"] = [stopping_criteria]
self._config.mm_spatial_pool_stride = self.mm_spatial_pool_stride
self._config.mm_spatial_pool_mode = self.mm_spatial_pool_mode

@@ -546,11 +548,12 @@
if "image_aspect_ratio" in gen_kwargs.keys():
gen_kwargs.pop("image_aspect_ratio")
try:
with torch.inference_mode():
cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
# cont = self.model.generate(qwen_input_ids, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
# with torch.inference_mode():
# cont = self.model.generate(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)
# # cont = self.model.generate(qwen_input_ids, pad_token_id=pad_token_ids, images=image_tensor, use_cache=self.use_cache, **gen_kwargs)

text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
# text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
text_outputs = "hi"
except Exception as e:
raise e

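One change above that is easy to miss: self._rank now comes from accelerator.process_index instead of accelerator.local_process_index. In Hugging Face Accelerate, process_index is the global rank across all launched processes, while local_process_index restarts from 0 on every node, so checks like lm.rank == 0 would match one process per node rather than a single global main process in multi-node runs. A small sketch of the distinction, assuming a hypothetical 2-node x 8-GPU launch:

    from accelerate import Accelerator

    accelerator = Accelerator()
    # In a hypothetical 2-node x 8-GPU job:
    #   accelerator.process_index        -> 0..15 (unique global rank)
    #   accelerator.local_process_index  -> 0..7 on each node (only unique within a node)
    #   accelerator.num_processes        -> 16
    print(accelerator.process_index, accelerator.local_process_index, accelerator.num_processes)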
1 change: 1 addition & 0 deletions lmms_eval/tasks/mme/utils.py
@@ -113,6 +113,7 @@ def mme_aggregate_results(results):
for category, question2scores in category2score.items():
total_score = 0
for question_id, scores in question2scores.items():
print(scores)
assert len(scores) == 2, "MME only supports pairwise evaluation"
acc = sum(scores) / len(scores) * 100.0
acc_plus = (sum(scores) == 2) * 100.0
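For context on the loop the debug print sits in: MME pairs two yes/no questions per image, so each entry of question2scores holds exactly two binary scores; acc is the per-pair accuracy and acc_plus gives credit only when both answers are correct. A rough sketch of that pairwise aggregation, assuming (as the truncated loop suggests) that the category score averages acc + acc_plus over the pairs:

    def aggregate_category(question2scores):
        # question2scores: {question_id: [score_q1, score_q2]} with scores in {0, 1}
        total_score = 0.0
        for scores in question2scores.values():
            assert len(scores) == 2, "MME only supports pairwise evaluation"
            acc = sum(scores) / len(scores) * 100.0   # 0, 50, or 100 for the pair
            acc_plus = (sum(scores) == 2) * 100.0     # 100 only if both answers are correct
            total_score += acc + acc_plus
        # A category can reach at most 200 (100 from acc + 100 from acc_plus).
        return total_score / len(question2scores)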
2 changes: 1 addition & 1 deletion lmms_eval/utils.py
@@ -32,7 +32,7 @@

import gc
from itertools import islice

import numpy as np
import pytz
import torch
import transformers
