From fe8da513987362d681c6b5113fb6e2f5342c7d13 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Mon, 2 Sep 2024 15:59:38 +0000 Subject: [PATCH] fix --- src/lighteval/models/vllm_model.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/lighteval/models/vllm_model.py b/src/lighteval/models/vllm_model.py index 93fde27a..5f51a10d 100644 --- a/src/lighteval/models/vllm_model.py +++ b/src/lighteval/models/vllm_model.py @@ -24,11 +24,8 @@ import os from typing import Optional -import ray from more_itertools import distribute from tqdm import tqdm -from vllm import LLM, SamplingParams -from vllm.transformers_utils.tokenizer import get_tokenizer from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset from lighteval.logging.hierarchical_logger import hlog_warn @@ -43,9 +40,20 @@ GreedyUntilRequest, LoglikelihoodRequest, ) +from lighteval.utils.imports import is_vllm_available from lighteval.utils.utils import EnvConfig, as_list +if is_vllm_available(): + import ray + from vllm import LLM, SamplingParams + from vllm.transformers_utils.tokenizer import get_tokenizer +else: + LLM = None + SamplingParams = None + get_tokenizer = None + ray = None + os.environ["TOKENIZERS_PARALLELISM"] = "false" STARTING_BATCH_SIZE = 512 @@ -242,11 +250,12 @@ def greedy_until( num_samples=num_samples, ) + print(f"{len(vllm_outputs)} vllm_outputs") for vllm_output in vllm_outputs: - output_token_ids = vllm_output.outputs[0].token_ids - logprobs = vllm_output.outputs[0].logprobs or [] - logprobs = [logprob[token_id].logprob for token_id, logprob in zip(output_token_ids, logprobs)] - result = vllm_output.outputs[0].text + output_token_ids = [outputs.token_ids for outputs in vllm_output.outputs] + logprobs = [output.logprobs for output in vllm_output.outputs] or [] + logprobs = [logprob[token_id].logprob for token_id, logprob in zip(output_token_ids[0], logprobs[0])] + result = [output.text for output in vllm_output.outputs] input_token_ids = vllm_output.prompt_token_ids cur_response = GenerativeResponse(