Refactor caching module and fix serialization issue (#255)
* Refactor caching module for LM evaluation harness

* feat: Fix issue with serializing non-serializable objects in caching module

* chore: Add error logging and retry mechanism in get_chat_response function
Luodian authored Sep 15, 2024
1 parent e738871 commit 10be8c3
Showing 5 changed files with 35 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
**/.cache
env
*.pyc
output/
9 changes: 9 additions & 0 deletions lmms_eval/api/task.py
@@ -472,6 +472,15 @@ def build_all_requests(
if cache_requests and (not cached_instances or rewrite_requests_cache):
save_to_cache(file_name=cache_key, obj=instances)

# FIXME: Bo - We need to check whether doc_to_visual exists and restore it. If we use the cache, doc_to_visual will be None since it is not serializable
for instance in self._instances:
if instance.arguments[2] is None:
arguments = (instance.arguments[0], instance.arguments[1], self.doc_to_visual, *instance.arguments[3:])
else:
arguments = instance.arguments

instance.arguments = arguments

@abc.abstractmethod
def construct_requests(self, doc_id, ctx, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
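The restore step above pairs with the save path in lmms_eval/caching/cache.py below: callables are replaced with None before pickling, so they must be re-attached from the live task after a cache hit. A minimal, self-contained sketch of that pattern follows; SimpleInstance and restore_visual are hypothetical stand-ins, not harness APIs.

# Illustrative sketch, not part of this commit: re-attach a callable that was
# replaced with None when the instances were cached. SimpleInstance is a
# hypothetical stand-in for the harness's Instance class.
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass
class SimpleInstance:
    arguments: Tuple[Any, ...]

def restore_visual(instances, doc_to_visual):
    """Put doc_to_visual back into slot 2 of each instance's arguments tuple."""
    for instance in instances:
        if instance.arguments[2] is None:
            instance.arguments = (
                instance.arguments[0],
                instance.arguments[1],
                doc_to_visual,
                *instance.arguments[3:],
            )
    return instances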
14 changes: 13 additions & 1 deletion lmms_eval/caching/cache.py
@@ -3,6 +3,7 @@

import dill

from lmms_eval.loggers.utils import _handle_non_serializable
from lmms_eval.utils import eval_logger

MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -39,9 +40,20 @@ def save_to_cache(file_name, obj):

file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"

serializable_obj = []

for item in obj:
sub_serializable_obj = []
for subitem in item:
if hasattr(subitem, "arguments"):  # handle the arguments specially since doc_to_visual is a callable method and not serializable
serializable_arguments = tuple(arg if not callable(arg) else None for arg in subitem.arguments)
subitem.arguments = serializable_arguments
sub_serializable_obj.append(_handle_non_serializable(subitem))
serializable_obj.append(sub_serializable_obj)

eval_logger.debug(f"Saving {file_path} to cache...")
with open(file_path, "wb") as file:
file.write(dill.dumps(obj))
file.write(dill.dumps(serializable_obj))


# NOTE the "key" param is to allow for flexibility
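The diff above only shows the save path. For context, a hypothetical load-side counterpart might look like the sketch below; the real load_from_cache in lmms_eval/caching/cache.py may differ. Note that after loading, any callable such as doc_to_visual is still None and must be restored by the caller, as in the task.py change above.

# Hypothetical sketch only -- not the harness's actual load_from_cache.
import os
import dill

def load_items(file_path):
    """Read back a dill payload written by the save path sketched in cache.py."""
    if not os.path.exists(file_path):
        return None  # cache miss
    with open(file_path, "rb") as f:
        return dill.loads(f.read())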
14 changes: 7 additions & 7 deletions lmms_eval/evaluator.py
@@ -413,13 +413,13 @@ def evaluate(
limit=limit,
rank=lm.rank,
world_size=lm.world_size,
# cache_requests=cache_requests, # later we will add them
# rewrite_requests_cache=rewrite_requests_cache,
# system_instruction=system_instruction,
# apply_chat_template=apply_chat_template,
# fewshot_as_multiturn=fewshot_as_multiturn,
# chat_template=getattr(lm, "apply_chat_template") if apply_chat_template else None,
# tokenizer_name=getattr(lm, "tokenizer_name", "") if apply_chat_template else "",
cache_requests=cache_requests, # later we will add them
rewrite_requests_cache=rewrite_requests_cache,
system_instruction=system_instruction,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
chat_template=getattr(lm, "apply_chat_template") if apply_chat_template else None,
tokenizer_name=getattr(lm, "tokenizer_name", "") if apply_chat_template else "",
)
eval_logger.debug(f"Task: {task_output.task_name}; number of requests on this rank: {len(task._instances)}")
if write_out:
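The newly enabled keyword arguments pull the chat-template hooks off the model object only when the caller asked for them. A small illustrative sketch of that getattr-guard pattern; DummyLM is a made-up stand-in, not an lmms-eval model class.

# Illustrative sketch, not part of this commit.
class DummyLM:
    tokenizer_name = "dummy-tokenizer"

    def apply_chat_template(self, messages):
        return "\n".join(m["content"] for m in messages)

def chat_kwargs(lm, apply_chat_template: bool):
    # Only expose the template hooks when chat templating was requested.
    return {
        "chat_template": getattr(lm, "apply_chat_template") if apply_chat_template else None,
        "tokenizer_name": getattr(lm, "tokenizer_name", "") if apply_chat_template else "",
    }

print(chat_kwargs(DummyLM(), apply_chat_template=True))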
8 changes: 5 additions & 3 deletions lmms_eval/tasks/wild_vision_bench/utils.py
@@ -102,12 +102,14 @@ def get_chat_response(base64_image, prompt, max_retries=5, wait_time=10):
# print(response_data)
return response_data["choices"][0]["message"]["content"], GPT_EVAL_MODEL_NAME
except requests.exceptions.RequestException as e:
print(f"Request failed on attempt {attempt+1}: {e}")
eval_logger.error(f"Request failed on attempt {attempt+1}: {e}")
time.sleep(wait_time)
if attempt == max_retries - 1:
print(f"Failed to get response after {max_retries} attempts")
eval_logger.error(f"Failed to get response after {max_retries} attempts")
return "", GPT_EVAL_MODEL_NAME
except Exception as e:
print(f"Error on attempt {attempt+1}: {e}")
eval_logger.error(f"Error on attempt {attempt+1}: {e}")
time.sleep(wait_time)
return "", GPT_EVAL_MODEL_NAME


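The change above swaps print for eval_logger.error inside the existing retry loop. Below is a standalone sketch of the same retry-and-log pattern, using a plain logging logger and a placeholder endpoint rather than the harness's real judge configuration; unlike the original, it skips the final sleep once the last attempt has failed.

# Standalone sketch of the retry-and-log pattern; endpoint, payload, and logger
# setup are placeholders, not lmms-eval configuration.
import logging
import time

import requests

logger = logging.getLogger("lmms-eval")

def post_with_retries(url, payload, headers=None, max_retries=5, wait_time=10):
    for attempt in range(max_retries):
        try:
            response = requests.post(url, json=payload, headers=headers, timeout=60)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Request failed on attempt {attempt + 1}: {e}")
            if attempt == max_retries - 1:
                logger.error(f"Failed to get response after {max_retries} attempts")
                return None
            time.sleep(wait_time)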
