diff --git a/runner/app/pipelines/llm.py b/runner/app/pipelines/llm.py
index ef91b48c..1278856c 100644
--- a/runner/app/pipelines/llm.py
+++ b/runner/app/pipelines/llm.py
@@ -1,11 +1,11 @@
 import asyncio
 import logging
 import os
+import time
 from typing import Dict, Any, List, Optional, AsyncGenerator, Union
 from app.pipelines.base import Pipeline
 from app.pipelines.utils import get_model_dir, get_max_memory
-from torch import cuda
-from vllm import LLM, SamplingParams
+from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
 from vllm.outputs import RequestOutput
 from huggingface_hub import file_download
 
@@ -25,72 +25,110 @@ def __init__(self, model_id: str):
 
         use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"
         max_batch_size = int(os.getenv("MAX_BATCH_SIZE", "4096"))
-        max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "256"))
-        mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.90"))
+        max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "128"))
+        mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.80"))
 
         # Get available GPU memory
         max_memory = get_max_memory()
         logger.info(f"Available GPU memory: {max_memory.gpu_memory}")
 
-        llm_kwargs = {
-            "model": self.local_model_path,
-            "tokenizer": self.local_model_path,
-            "load_format": "auto",
-            "trust_remote_code": True,
-            "dtype": "Bfloat16",  # This specifies FP16 precision, TODO: Check GPU capabilities to set best type
-            "tensor_parallel_size": max_memory.num_gpus,
-            "max_num_batched_tokens": max_batch_size,
-            "gpu_memory_utilization": mem_utilization,
-            "max_num_seqs": max_num_seqs,
-        }
+        engine_args = AsyncEngineArgs(
+            model=self.local_model_path,
+            tokenizer=self.local_model_path,
+            trust_remote_code=True,
+            dtype="auto",  # Derive the dtype from the model config (BF16 for most modern checkpoints). TODO: Check GPU capabilities to set best type
+            kv_cache_dtype="auto",  # "auto" matches the model dtype; "fp8" would quantize the KV cache
+            tensor_parallel_size=max_memory.num_gpus,
+            max_num_batched_tokens=max_batch_size,
+            gpu_memory_utilization=mem_utilization,
+            max_num_seqs=max_num_seqs,
+            enforce_eager=False,
+            enable_prefix_caching=True,
+            seed=42,
+            swap_space=4,  # 4 GiB of CPU swap space instead of 8 GiB
+        )
 
         if use_8bit:
-            llm_kwargs["quantization"] = "bitsandbytes"  # or another supported 8-bit quantization method
-            llm_kwargs["load_format"] = "bitsandbytes"
+            engine_args.quantization = "bitsandbytes"
+            engine_args.load_format = "bitsandbytes"  # bitsandbytes quantization requires the matching load format
             logger.info("Using 8-bit quantization")
         else:
-            logger.info("Using FP16 precision")
+            logger.info("Using BFloat16 precision")
 
-        self.llm = LLM(**llm_kwargs)
+        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         logger.info(f"Model loaded: {self.model_id}")
         logger.info(f"Using GPU memory utilization: {mem_utilization}")
+        self.engine.start_background_loop()
 
     async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
+        start_time = time.time()
+
         conversation = []
         if system_msg:
             conversation.append({"role": "system", "content": system_msg})
         if history:
-            conversation.extend(history)
+            for user_msg, assistant_msg in history:
+                conversation.append({"role": "user", "content": user_msg})
+                if assistant_msg:
+                    conversation.append({"role": "assistant", "content": assistant_msg})
         conversation.append({"role": "user", "content": prompt})
 
-        # Apply chat template
-        full_prompt = self.llm.get_tokenizer().apply_chat_template(conversation, tokenize=False)
-
-        max_tokens = kwargs.get("max_tokens", 256)
-        temperature = kwargs.get("temperature", 0.7)
+        tokenizer = await self.engine.get_tokenizer()
+        full_prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
 
         sampling_params = SamplingParams(
-            temperature=temperature,
-            max_tokens=max_tokens,
+            temperature=kwargs.get("temperature", 0.7),
+            max_tokens=kwargs.get("max_tokens", 256),
             top_p=kwargs.get("top_p", 1.0),
             top_k=kwargs.get("top_k", -1),
         )
 
-        async for output in self.llm.generate(prompt=full_prompt, sampling_params=sampling_params, stream=True):
-            if isinstance(output, RequestOutput):
-                generated_text = output.outputs[0].text
-                yield generated_text
+        request_id = str(time.monotonic())
+        results_generator = self.engine.generate(prompt=full_prompt, sampling_params=sampling_params, request_id=request_id)
+
+        generated_tokens = 0
+        first_token_time = None
+        previous_text = ""
+
+        try:
+            async for request_output in results_generator:
+                if first_token_time is None:
+                    first_token_time = time.time()
+
+                text = request_output.outputs[0].text
+                new_text = text[len(previous_text):]
+                generated_tokens += len(tokenizer.encode(new_text))
+
+                yield new_text
+                previous_text = text
                 await asyncio.sleep(0)  # Allow other tasks to run
-
-        # Get the final output to calculate total tokens
-        final_output = await self.llm.generate(prompt=full_prompt, sampling_params=sampling_params)
-        if isinstance(final_output, RequestOutput):
-            total_tokens = final_output.prompt_token_ids.shape[1] + len(final_output.outputs[0].token_ids)
-            yield {"tokens_used": total_tokens}
+        except Exception as e:
+            logger.error(f"Error during generation: {e}")
+            raise
+
+        end_time = time.time()
+
+        # Calculate total tokens and timing
+        prompt_tokens = len(tokenizer.encode(full_prompt))
+        total_tokens = prompt_tokens + generated_tokens
+        total_time = end_time - start_time
+        generation_time = end_time - first_token_time if first_token_time else 0
+
+        # Log benchmarking information
+        logger.info(f"Generation completed:")
+        logger.info(f"  Total tokens: {total_tokens}")
+        logger.info(f"  Prompt tokens: {prompt_tokens}")
+        logger.info(f"  Generated tokens: {generated_tokens}")
+        logger.info(f"  Total time: {total_time:.2f} seconds")
+        logger.info(f"  Time to first token: {(first_token_time - start_time):.2f} seconds")
+        logger.info(f"  Generation time: {generation_time:.2f} seconds")
+        logger.info(f"  Tokens per second: {total_tokens / generation_time:.2f}")
+
+        yield {"tokens_used": total_tokens}
 
     def __str__(self):
         return f"LLMPipeline(model_id={self.model_id})"
+
     def _find_model_path(self, base_path):
         # Check if the model files are directly in the base path
         if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)):
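For context on how the reworked pipeline is meant to be driven: `__call__` is now an async generator that yields incremental text deltas and finishes with a usage dict. The sketch below is a minimal consumer, not part of this diff; the model id, prompt, and the standalone `asyncio.run` driver are illustrative assumptions (in the runner the pipeline is invoked from the API layer instead).

```python
# Minimal consumer sketch for the async streaming LLMPipeline (hypothetical driver).
import asyncio

from app.pipelines.llm import LLMPipeline


async def main():
    # Placeholder model id; any chat model supported by vLLM would do.
    pipeline = LLMPipeline(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")

    async for chunk in pipeline(
        prompt="What is the capital of France?",
        system_msg="You are a helpful assistant.",
        max_tokens=128,
        temperature=0.7,
    ):
        if isinstance(chunk, dict):
            # Final item carries usage metadata, e.g. {"tokens_used": 137}.
            print(f"\n[tokens used: {chunk['tokens_used']}]")
        else:
            # Intermediate items are incremental text deltas.
            print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
```

The dict-vs-str check is how a caller distinguishes streamed text from the trailing usage record, since both are yielded from the same generator.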