diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index 032acc93c..2aa42327a 100755 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -430,7 +430,7 @@ def build_all_requests( if cache_requests and (not cached_instances or rewrite_requests_cache) and limit is not None: limit = None - doc_id_docs = self.doc_iterator(rank=rank, limit=limit, world_size=world_size) + doc_id_docs = utils.create_iterator(enumerate(self.eval_docs_no_media), rank=rank, limit=int(limit) if limit else None, world_size=world_size) doc_iterator_for_counting = itertools.islice(range(len(self.test_docs())), rank, limit, world_size) if self.has_test_docs() else itertools.islice(range(len(self.validation_docs())), rank, limit, world_size) num_docs = sum(1 for _ in doc_iterator_for_counting)