From b7e56068942b6ca56832e0dfcc4f2368e60bcae8 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Wed, 16 Oct 2024 19:00:10 +0200 Subject: [PATCH 1/2] llm: use vLLM --- runner/Dockerfile | 6 +- runner/app/pipelines/llm.py | 261 ++++++-------- runner/app/pipelines/utils/__init__.py | 1 + runner/app/pipelines/utils/utils.py | 20 +- runner/app/routes/llm.py | 6 +- runner/gateway.openapi.yaml | 10 +- runner/openapi.yaml | 10 +- runner/requirements.in | 22 ++ runner/requirements.txt | 455 ++++++++++++++++++++++++- worker/runner.gen.go | 119 +++---- 10 files changed, 667 insertions(+), 243 deletions(-) create mode 100644 runner/requirements.in diff --git a/runner/Dockerfile b/runner/Dockerfile index 5d00e2d2..36aced7e 100644 --- a/runner/Dockerfile +++ b/runner/Dockerfile @@ -29,9 +29,9 @@ RUN pyenv install $PYTHON_VERSION && \ pyenv rehash # Upgrade pip and install your desired packages -ARG PIP_VERSION=23.3.2 +ARG PIP_VERSION=24.2 RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0 && \ - pip install --no-cache-dir torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 + pip install --no-cache-dir torch==2.4.0 torchvision torchaudio WORKDIR /app COPY ./requirements.txt /app @@ -48,6 +48,8 @@ ENV MAX_WORKERS=$max_workers ENV HUGGINGFACE_HUB_CACHE=/models ENV DIFFUSERS_CACHE=/models ENV MODEL_DIR=/models +# This ensures compatbility with how GPUs are addresses within go-livepeer +ENV CUDA_DEVICE_ORDER=PCI_BUS_ID COPY app/ /app/app COPY images/ /app/images diff --git a/runner/app/pipelines/llm.py b/runner/app/pipelines/llm.py index 7d3440d7..6f6f5479 100644 --- a/runner/app/pipelines/llm.py +++ b/runner/app/pipelines/llm.py @@ -1,201 +1,140 @@ import asyncio import logging import os -import psutil +import time from typing import Dict, Any, List, Optional, AsyncGenerator, Union - -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig -from accelerate import init_empty_weights, load_checkpoint_and_dispatch from app.pipelines.base import Pipeline -from app.pipelines.utils import get_model_dir, get_torch_device -from huggingface_hub import file_download, snapshot_download -from threading import Thread +from app.pipelines.utils import get_model_dir, get_max_memory +from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams +from vllm.outputs import RequestOutput +from huggingface_hub import file_download logger = logging.getLogger(__name__) - -def get_max_memory(): - num_gpus = torch.cuda.device_count() - gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)} - cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB" - max_memory = {**gpu_memory, "cpu": cpu_memory} - - logger.info(f"Max memory configuration: {max_memory}") - return max_memory - - -def load_model_8bit(model_id: str, **kwargs): - max_memory = get_max_memory() - - quantization_config = BitsAndBytesConfig( - load_in_8bit=True, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - ) - - tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - - model = AutoModelForCausalLM.from_pretrained( - model_id, - quantization_config=quantization_config, - device_map="auto", - max_memory=max_memory, - offload_folder="offload", - low_cpu_mem_usage=True, - **kwargs - ) - - return tokenizer, model - - -def load_model_fp16(model_id: str, **kwargs): - device = get_torch_device() - max_memory = get_max_memory() - - # Check for fp16 variant - local_model_path = os.path.join( 
- get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model")) - has_fp16_variant = any(".fp16.safetensors" in fname for _, _, - files in os.walk(local_model_path) for fname in files) - - if device != "cpu" and has_fp16_variant: - logger.info("Loading fp16 variant for %s", model_id) - kwargs["torch_dtype"] = torch.float16 - kwargs["variant"] = "fp16" - elif device != "cpu": - kwargs["torch_dtype"] = torch.bfloat16 - - tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - - config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config(config) - - checkpoint_dir = snapshot_download( - model_id, cache_dir=get_model_dir(), local_files_only=True) - - model = load_checkpoint_and_dispatch( - model, - checkpoint_dir, - device_map="auto", - max_memory=max_memory, - # Adjust based on your model architecture - no_split_module_classes=["LlamaDecoderLayer"], - dtype=kwargs.get("torch_dtype", torch.float32), - offload_folder="offload", - offload_state_dict=True, - ) - - return tokenizer, model - - class LLMPipeline(Pipeline): def __init__(self, model_id: str): self.model_id = model_id - kwargs = { - "cache_dir": get_model_dir(), - "local_files_only": True, - } - self.device = get_torch_device() - - # Generate the correct folder name - folder_path = file_download.repo_folder_name( - repo_id=model_id, repo_type="model") - self.local_model_path = os.path.join(get_model_dir(), folder_path) - self.checkpoint_dir = snapshot_download( - model_id, cache_dir=get_model_dir(), local_files_only=True) - - logger.info(f"Local model path: {self.local_model_path}") - logger.info(f"Directory contents: {os.listdir(self.local_model_path)}") + folder_name = file_download.repo_folder_name(repo_id=model_id, repo_type="model") + base_path = os.path.join(get_model_dir(), folder_name) + + # Find the actual model path + self.local_model_path = self._find_model_path(base_path) + + if not self.local_model_path: + raise ValueError(f"Could not find model files for {model_id}") use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true" + max_batch_size = int(os.getenv("MAX_BATCH_SIZE", "4096")) + max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "128")) + mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.80")) + + # Get available GPU memory + max_memory = get_max_memory() + logger.info(f"Available GPU memory: {max_memory.gpu_memory}") + + engine_args = AsyncEngineArgs( + model=self.local_model_path, + tokenizer=self.local_model_path, + trust_remote_code=True, + dtype="auto", # This specifies BFloat16 precision, TODO: Check GPU capabilities to set best type + kv_cache_dtype="auto", # or "fp16" if you want to force it + tensor_parallel_size=max_memory.num_gpus, + max_num_batched_tokens=max_batch_size, + gpu_memory_utilization=mem_utilization, + max_num_seqs=max_num_seqs, + enforce_eager=False, + enable_prefix_caching=True, + ) if use_8bit: + engine_args.quantization = "bitsandbytes" logger.info("Using 8-bit quantization") - self.tokenizer, self.model = load_model_8bit(model_id, **kwargs) else: - logger.info("Using fp16/bf16 precision") - self.tokenizer, self.model = load_model_fp16(model_id, **kwargs) + logger.info("Using BFloat16 precision") - logger.info( - f"Model loaded and distributed. 
Device map: {self.model.hf_device_map}" - ) + self.engine = AsyncLLMEngine.from_engine_args(engine_args) - # Set up generation config - self.generation_config = self.model.generation_config - - self.terminators = [ - self.tokenizer.eos_token_id, - self.tokenizer.convert_tokens_to_ids("<|eot_id|>") - ] - - # Optional: Add optimizations - sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" - if sfast_enabled: - logger.info( - "LLMPipeline will be dynamically compiled with stable-fast for %s", - model_id, - ) - from app.pipelines.optim.sfast import compile_model - self.model = compile_model(self.model) + logger.info(f"Model loaded: {self.model_id}") + logger.info(f"Using GPU memory utilization: {mem_utilization}") + self.engine.start_background_loop() async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: + start_time = time.time() + conversation = [] if system_msg: conversation.append({"role": "system", "content": system_msg}) if history: - conversation.extend(history) + for user_msg, assistant_msg in history: + conversation.append({"role": "user", "content": user_msg}) + if assistant_msg: + conversation.append({"role": "assistant", "content": assistant_msg}) conversation.append({"role": "user", "content": prompt}) - input_ids = self.tokenizer.apply_chat_template( - conversation, return_tensors="pt").to(self.model.device) - attention_mask = torch.ones_like(input_ids) + tokenizer = await self.engine.get_tokenizer() + full_prompt = tokenizer.apply_chat_template(conversation, tokenize=False) - max_new_tokens = kwargs.get("max_tokens", 256) - temperature = kwargs.get("temperature", 0.7) - - streamer = TextIteratorStreamer( - self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + sampling_params = SamplingParams( + temperature=kwargs.get("temperature", 0.7), + max_tokens=kwargs.get("max_tokens", 256), + top_p=kwargs.get("top_p", 1.0), + top_k=kwargs.get("top_k", -1), + ) - generate_kwargs = self.generation_config.to_dict() - generate_kwargs.update({ - "input_ids": input_ids, - "attention_mask": attention_mask, - "streamer": streamer, - "max_new_tokens": max_new_tokens, - "do_sample": temperature > 0, - "temperature": temperature, - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.eos_token_id, - }) + request_id = str(time.monotonic()) + results_generator = self.engine.generate(prompt=full_prompt, sampling_params=sampling_params, request_id=request_id) - thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs) - thread.start() + generated_tokens = 0 + first_token_time = None + previous_text = "" - total_tokens = 0 try: - for text in streamer: - total_tokens += 1 - yield text + async for request_output in results_generator: + if first_token_time is None: + first_token_time = time.time() + + text = request_output.outputs[0].text + new_text = text[len(previous_text):] + generated_tokens += len(tokenizer.encode(new_text)) + + yield new_text + previous_text = text await asyncio.sleep(0) # Allow other tasks to run except Exception as e: - logger.error(f"Error during streaming: {str(e)}") + logger.error(f"Error during generation: {e}") raise - input_length = input_ids.size(1) - yield {"tokens_used": input_length + total_tokens} + end_time = time.time() - def model_generate_wrapper(self, **kwargs): - try: - logger.debug("Entering model.generate") - with torch.cuda.amp.autocast(): # Use automatic mixed 
precision
-                self.model.generate(**kwargs)
-            logger.debug("Exiting model.generate")
-        except Exception as e:
-            logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
-            raise
+        # Calculate total tokens and timing
+        prompt_tokens = len(tokenizer.encode(full_prompt))
+        total_tokens = prompt_tokens + generated_tokens
+        total_time = end_time - start_time
+        generation_time = end_time - first_token_time if first_token_time else 0
+
+        # Log benchmarking information
+        logger.info(f"Generation completed:")
+        logger.info(f"  Total tokens: {total_tokens}")
+        logger.info(f"  Prompt tokens: {prompt_tokens}")
+        logger.info(f"  Generated tokens: {generated_tokens}")
+        logger.info(f"  Total time: {total_time:.2f} seconds")
+        logger.info(f"  Time to first token: {(first_token_time - start_time):.2f} seconds")
+        logger.info(f"  Generation time: {generation_time:.2f} seconds")
+        logger.info(f"  Tokens per second: {total_tokens / generation_time:.2f}")
+
+        yield {"tokens_used": total_tokens}
 
     def __str__(self):
         return f"LLMPipeline(model_id={self.model_id})"
+
+    def _find_model_path(self, base_path):
+        # Check if the model files are directly in the base path
+        if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)):
+            return base_path
+
+        # If not, look in subdirectories
+        for root, dirs, files in os.walk(base_path):
+            if any(file.endswith('.bin') or file.endswith('.safetensors') for file in files):
+                return root
+
+        return None
\ No newline at end of file
diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py
index 99e06686..777eb6c8 100644
--- a/runner/app/pipelines/utils/__init__.py
+++ b/runner/app/pipelines/utils/__init__.py
@@ -14,4 +14,5 @@
     is_numeric,
     split_prompt,
     validate_torch_device,
+    get_max_memory
 )
diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py
index 5c4b8ccd..22cbca50 100644
--- a/runner/app/pipelines/utils/utils.py
+++ b/runner/app/pipelines/utils/utils.py
@@ -6,6 +6,7 @@
 import re
 from pathlib import Path
 from typing import Any, Dict, List, Optional
+import psutil
 
 import numpy as np
 import torch
@@ -37,7 +38,24 @@ def get_torch_device():
         return torch.device("mps")
     else:
         return torch.device("cpu")
-    
+
+class MemoryInfo:
+    def __init__(self, gpu_memory, cpu_memory, num_gpus):
+        self.gpu_memory = gpu_memory
+        self.cpu_memory = cpu_memory
+        self.num_gpus = num_gpus
+
+    def __repr__(self):
+        return f"<MemoryInfo: GPUs: {self.num_gpus}, GPU Memory: {self.gpu_memory}, CPU Memory: {self.cpu_memory}>"
+
+def get_max_memory() -> MemoryInfo:
+    num_gpus = torch.cuda.device_count()
+    gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
+    cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"
+
+    memory_info = MemoryInfo(gpu_memory=gpu_memory, cpu_memory=cpu_memory, num_gpus=num_gpus)
+
+    return memory_info
 
 def validate_torch_device(device_name: str) -> bool:
     """Checks if the given PyTorch device name is valid and available. 
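The MemoryInfo/get_max_memory helper added above replaces the old dict-based get_max_memory that the previous llm.py defined, and it is what the vLLM pipeline uses to size tensor parallelism. A minimal sketch of exercising it on its own follows; it is not part of the patch, it assumes the runner's app package is importable and torch can see the GPUs, and the printed values are only illustrative:

    # Minimal usage sketch; example values are hypothetical.
    from app.pipelines.utils import get_max_memory

    mem = get_max_memory()
    print(mem.num_gpus)     # number of visible CUDA devices, e.g. 2
    print(mem.gpu_memory)   # per-GPU totals, e.g. {0: "79GiB", 1: "79GiB"}
    print(mem.cpu_memory)   # available system RAM, e.g. "503GiB"

    # LLMPipeline passes mem.num_gpus to AsyncEngineArgs(tensor_parallel_size=...),
    # so every GPU visible to the runner participates in tensor parallelism.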
diff --git a/runner/app/routes/llm.py b/runner/app/routes/llm.py index 1080280c..e57fd3a3 100644 --- a/runner/app/routes/llm.py +++ b/runner/app/routes/llm.py @@ -38,6 +38,8 @@ async def llm( system_msg: Annotated[str, Form()] = "", temperature: Annotated[float, Form()] = 0.7, max_tokens: Annotated[int, Form()] = 256, + top_p: Annotated[float, Form()] = 1.0, + top_k: Annotated[int, Form()] = -1, history: Annotated[str, Form()] = "[]", # We'll parse this as JSON stream: Annotated[bool, Form()] = False, pipeline: Pipeline = Depends(get_pipeline), @@ -71,7 +73,9 @@ async def llm( history=history_list, system_msg=system_msg if system_msg else None, temperature=temperature, - max_tokens=max_tokens + max_tokens=max_tokens, + top_p=top_p, + top_k=top_k ) if stream: diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 92f8de04..dbabc73c 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.0 info: title: Livepeer AI Runner description: An application to run AI pipelines - version: v0.7.0 + version: 0.7.1 servers: - url: https://dream-gateway.livepeer.cloud description: Livepeer Cloud Community Gateway @@ -541,6 +541,14 @@ components: type: integer title: Max Tokens default: 256 + top_p: + type: number + title: Top P + default: 1.0 + top_k: + type: integer + title: Top K + default: -1 history: type: string title: History diff --git a/runner/openapi.yaml b/runner/openapi.yaml index 57a8b01b..250ff868 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.0 info: title: Livepeer AI Runner description: An application to run AI pipelines - version: v0.7.0 + version: 0.7.1 servers: - url: https://dream-gateway.livepeer.cloud description: Livepeer Cloud Community Gateway @@ -549,6 +549,14 @@ components: type: integer title: Max Tokens default: 256 + top_p: + type: number + title: Top P + default: 1.0 + top_k: + type: integer + title: Top K + default: -1 history: type: string title: History diff --git a/runner/requirements.in b/runner/requirements.in new file mode 100644 index 00000000..c4b004c2 --- /dev/null +++ b/runner/requirements.in @@ -0,0 +1,22 @@ +vllm==0.6.3 +diffusers +accelerate +transformers +fastapi +pydantic +Pillow +python-multipart +uvicorn +huggingface_hub +xformers +triton +peft +deepcache +safetensors +scipy +numpy +av +sentencepiece +protobuf +bitsandbytes +psutil \ No newline at end of file diff --git a/runner/requirements.txt b/runner/requirements.txt index 87f72e43..10581a71 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -1,21 +1,440 @@ -diffusers==0.30.0 -accelerate==0.30.1 -transformers==4.43.3 -fastapi==0.111.0 -pydantic==2.7.2 -Pillow==10.3.0 -python-multipart==0.0.9 -uvicorn==0.30.0 -huggingface_hub==0.23.2 -xformers==0.0.23 -triton>=2.1.0 -peft==0.11.1 +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile requirements.in +# +accelerate==1.0.1 + # via + # -r requirements.in + # peft +aiohappyeyeballs==2.4.3 + # via aiohttp +aiohttp==3.10.10 + # via + # datasets + # fsspec + # vllm +aiosignal==1.3.1 + # via + # aiohttp + # ray +annotated-types==0.7.0 + # via pydantic +anyio==4.6.2.post1 + # via + # httpx + # openai + # starlette + # watchfiles +attrs==24.2.0 + # via + # aiohttp + # jsonschema + # referencing +av==13.1.0 + # via -r requirements.in +bitsandbytes==0.44.1 + # via -r requirements.in +certifi==2024.8.30 + # via + # httpcore + # httpx + # requests +charset-normalizer==3.4.0 + # via 
requests +click==8.1.7 + # via + # ray + # uvicorn +cloudpickle==3.1.0 + # via outlines +datasets==3.0.1 + # via outlines deepcache==0.1.1 -safetensors==0.4.3 -scipy==1.13.0 + # via -r requirements.in +diffusers==0.30.3 + # via + # -r requirements.in + # deepcache +dill==0.3.8 + # via + # datasets + # multiprocess +diskcache==5.6.3 + # via outlines +distro==1.9.0 + # via openai +einops==0.8.0 + # via vllm +fastapi==0.115.2 + # via + # -r requirements.in + # vllm +filelock==3.16.1 + # via + # datasets + # diffusers + # huggingface-hub + # ray + # torch + # transformers + # triton + # vllm +frozenlist==1.4.1 + # via + # aiohttp + # aiosignal + # ray +fsspec[http]==2024.6.1 + # via + # datasets + # huggingface-hub + # torch +gguf==0.10.0 + # via vllm +h11==0.14.0 + # via + # httpcore + # uvicorn +httpcore==1.0.6 + # via httpx +httptools==0.6.4 + # via uvicorn +httpx==0.27.2 + # via openai +huggingface-hub==0.25.2 + # via + # -r requirements.in + # accelerate + # datasets + # diffusers + # peft + # tokenizers + # transformers +idna==3.10 + # via + # anyio + # httpx + # requests + # yarl +importlib-metadata==8.5.0 + # via + # diffusers + # vllm +interegular==0.3.3 + # via + # lm-format-enforcer + # outlines +jinja2==3.1.4 + # via + # outlines + # torch +jiter==0.6.1 + # via openai +jsonschema==4.23.0 + # via + # mistral-common + # outlines + # ray +jsonschema-specifications==2024.10.1 + # via jsonschema +lark==1.2.2 + # via outlines +llvmlite==0.43.0 + # via numba +lm-format-enforcer==0.10.6 + # via vllm +markupsafe==3.0.1 + # via jinja2 +mistral-common[opencv]==1.4.4 + # via vllm +mpmath==1.3.0 + # via sympy +msgpack==1.1.0 + # via ray +msgspec==0.18.6 + # via vllm +multidict==6.1.0 + # via + # aiohttp + # yarl +multiprocess==0.70.16 + # via datasets +nest-asyncio==1.6.0 + # via outlines +networkx==3.4.1 + # via torch +numba==0.60.0 + # via outlines numpy==1.26.4 -av==12.1.0 -sentencepiece== 0.2.0 -protobuf==5.27.2 -bitsandbytes==0.43.3 + # via + # -r requirements.in + # accelerate + # bitsandbytes + # datasets + # diffusers + # gguf + # mistral-common + # numba + # opencv-python-headless + # outlines + # pandas + # peft + # pyarrow + # scipy + # torchvision + # transformers + # vllm + # xformers +nvidia-cublas-cu12==12.1.3.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-ml-py==12.560.30 + # via vllm +nvidia-nccl-cu12==2.20.5 + # via torch +nvidia-nvjitlink-cu12==12.6.77 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +openai==1.51.2 + # via vllm +opencv-python-headless==4.10.0.84 + # via mistral-common +outlines==0.0.46 + # via vllm +packaging==24.1 + # via + # accelerate + # datasets + # huggingface-hub + # lm-format-enforcer + # peft + # ray + # transformers +pandas==2.2.3 + # via datasets +partial-json-parser==0.2.1.1.post4 + # via vllm +peft==0.13.2 + # via -r requirements.in +pillow==10.4.0 + # via + # -r requirements.in + # diffusers + # mistral-common + # torchvision + # vllm +prometheus-client==0.21.0 + # via + # prometheus-fastapi-instrumentator + # vllm +prometheus-fastapi-instrumentator==7.0.0 + # via vllm 
+propcache==0.2.0 + # via yarl +protobuf==5.28.2 + # via + # -r requirements.in + # ray + # vllm psutil==6.0.0 + # via + # -r requirements.in + # accelerate + # peft + # vllm +py-cpuinfo==9.0.0 + # via vllm +pyairports==2.1.1 + # via outlines +pyarrow==17.0.0 + # via datasets +pycountry==24.6.1 + # via outlines +pydantic==2.9.2 + # via + # -r requirements.in + # fastapi + # lm-format-enforcer + # mistral-common + # openai + # outlines + # vllm +pydantic-core==2.23.4 + # via pydantic +python-dateutil==2.9.0.post0 + # via pandas +python-dotenv==1.0.1 + # via uvicorn +python-multipart==0.0.12 + # via -r requirements.in +pytz==2024.2 + # via pandas +pyyaml==6.0.2 + # via + # accelerate + # datasets + # gguf + # huggingface-hub + # lm-format-enforcer + # peft + # ray + # transformers + # uvicorn + # vllm +pyzmq==26.2.0 + # via vllm +ray==2.37.0 + # via vllm +referencing==0.35.1 + # via + # jsonschema + # jsonschema-specifications + # outlines +regex==2024.9.11 + # via + # diffusers + # tiktoken + # transformers +requests==2.32.3 + # via + # datasets + # diffusers + # huggingface-hub + # mistral-common + # outlines + # ray + # tiktoken + # transformers + # vllm +rpds-py==0.20.0 + # via + # jsonschema + # referencing +safetensors==0.4.5 + # via + # -r requirements.in + # accelerate + # diffusers + # peft + # transformers +scipy==1.14.1 + # via -r requirements.in +sentencepiece==0.2.0 + # via + # -r requirements.in + # mistral-common + # vllm +six==1.16.0 + # via python-dateutil +sniffio==1.3.1 + # via + # anyio + # httpx + # openai +starlette==0.40.0 + # via + # fastapi + # prometheus-fastapi-instrumentator +sympy==1.13.3 + # via torch +tiktoken==0.7.0 + # via + # mistral-common + # vllm +tokenizers==0.20.1 + # via + # transformers + # vllm +torch==2.4.0 + # via + # accelerate + # bitsandbytes + # deepcache + # peft + # torchvision + # vllm + # xformers +torchvision==0.19.0 + # via vllm +tqdm==4.66.5 + # via + # datasets + # gguf + # huggingface-hub + # openai + # outlines + # peft + # transformers + # vllm +transformers==4.45.2 + # via + # -r requirements.in + # deepcache + # peft + # vllm +triton==3.0.0 + # via + # -r requirements.in + # torch +typing-extensions==4.12.2 + # via + # fastapi + # huggingface-hub + # mistral-common + # openai + # outlines + # pydantic + # pydantic-core + # torch + # vllm +tzdata==2024.2 + # via pandas +urllib3==2.2.3 + # via requests +uvicorn[standard]==0.32.0 + # via + # -r requirements.in + # vllm +uvloop==0.21.0 + # via uvicorn +vllm==0.6.3 + # via -r requirements.in +watchfiles==0.24.0 + # via uvicorn +websockets==13.1 + # via uvicorn +xformers==0.0.27.post2 + # via + # -r requirements.in + # vllm +xxhash==3.5.0 + # via datasets +yarl==1.15.4 + # via aiohttp +zipp==3.20.2 + # via importlib-metadata diff --git a/worker/runner.gen.go b/worker/runner.gen.go index d966ef1b..0666ea14 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -122,6 +122,8 @@ type BodyGenLLM struct { Stream *bool `json:"stream,omitempty"` SystemMsg *string `json:"system_msg,omitempty"` Temperature *float32 `json:"temperature,omitempty"` + TopK *int `json:"top_k,omitempty"` + TopP *float32 `json:"top_p,omitempty"` } // BodyGenSegmentAnything2 defines model for Body_genSegmentAnything2. 
@@ -2013,64 +2015,65 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xb+28bt5P/V4i9A5oAkiW7dXMw8P3BSdPEODsNbKVpkRgCtTtaseaSWz4sqTn/7wcO", - "d1fch165xL1vqp/iiOTMZ4bzIjn7KYpllksBwujo7FOk4xlkFP88f3vxUimp3N8J6Fix3DApojM3QsAN", - "EQU6l0IDyWQC/CjqRbmSOSjDAGlkOm0vH82gWJ6B1jQFt84wwyE6i6506v63zN1/tFFMpNHDQy9S8Kdl", - "CpLo7ANSvV0tqYBW6+TkD4hN9NCLnstkOU5BnNuEyZEcwcI4QHWU1A22cb7LuaQJJATHyZRxIEaSCRCj", - "qHAzJ5A47FOpMmqis2jCBFXLQBpk25anF6G+xizxXKfUcrc+6jUgvLZpykRKfqZxoWNy8ROxGhIylarC", - "gdNrWvRTk62q9KIHyuxS2Aa9XmQ0hZHEf9qKTS1LqIhhrGPKoSbrs6PTprAvRSytoinoQlQjSQoCFDVA", - "WIYDMZca+JJwJu4gcTPMDIiBhSG5klluyJMZS2egyD3l1lGiS6IgsXFBgvxpKWdm+TRU16sCJ7lBnJW8", - "wmYTUE5eVgq4xkQ8bSMdcjZdkjkzM4SWsxw4E7DZTrz+OuwE6Y436PG4rcefIFWAYOYzFnsYpR5LpEyT", - "3OoZqnBOVaJxFhPMMMr9nKMmPrJdTVwqHzw22PQ5uZTX5+TJpZz3r6m4I+cJzQ11o0+LjaciIcxoEkvl", - "I0zinGAOLJ0ZNHwvRCGUs33yckGznMMZ+UQ+RpwaEKYfS6GZNiDi5YDHWd+h6+tkwT9GZ+T4aNgjHyMB", - "iv2hBzlbAO9TZfrl6MlDqIBLFOyrOXJLnp18uRcJSKlh9zD2xr8FxGjlJk/0U3QvyxIg8xk17n+wiLlN", - "gEyVzDpUfJEKqZwFTUndIMlHOxx+H5PjEPabAhp566F1obfZ2Pv1OAfVJcNxU4Q3aGpETsuAEMaIHFQh", - "Xg2IzciFn/wWVAsOEwZSb72IR0xBAYpmIK/b8vFwuB5PAkIy7fYYFx6RK6nA/02stpS7qAUUY1YRoopQ", - "VIoysYZoLuegSIXCkUksR8+dLIk2CkRqZi35yvnkBlF3SReqdxer2GST6/dU0ymY5TieQXxXU55RFpra", - "ewvKxURCiV9GcBmaojYsw7g/bcYuFxYsT1weltMpCO2MTCoyoyqbWh7CvPFUXyCYCuxESg5UIFqApK2R", - "GyjcUlGRyIz4+LZGFW5yp77LvappYXj0X2vCtZz6dO6TBJOC0DznbJXkFJR77HfmydCNHNcS2U3JsxWb", - "G3k/LzfQJ7aOAqCW2bdXAL+yBGS7Apg2XOjHXkcxOFU0A43uqyGWIkFl1LLWvSMfSvrzGiufYZKo8Tx9", - "1snVzyRMEAz+egemrz3xLr47FwhVtKKePkbbz6wOvkzy8TD2Tz6ZdLPHExvfgWmiOD551oTxrmTotpi5", - "Hx0op3KaSSuM2wBP0xe3s3r6wT3zgdMNFU7p/sxcpC1WzhnnLjQwgUOtLbzy054j6JpgYSKQTMOY2nS8", - "xomHJ62qphIBFxOaJCvXrQnsiyvyulamFiWqAg3ZhGORtXatL49ErIDqUu5aQkAA5zYl68PB9mR3cvpv", - "nOsOWajUxJwlDes9Hp780BUPceZe4fA90m5zbeSabSnGp44NKeby8qqdWWZMG6mW9dD34TaM1sWMrtBF", - "F2Mj70A0bf7HIFLQBRn5OV2KXRt7dwmdq1psh4rKKKBZjc2Ucg31rE+zbtNaagPZuLqH6cB5g1NI58VL", - "LzKQ5W7/rYJGDHy2IjEKJu1YeXSYg9vmDVZwA2kGwpyLpZkxkZ60TWIiFx2XVYRjGCE/EKoUXZKU3YMg", - "VBNKJnJRXhsU0RZ3tee84Lfff/ud+Jwc2vxzuVh7Tm8zvyizvvbgPzfPU303ZiK3plM+Oe8r0JJbTG1u", - "MsHJDaHMMmcxxmY84FGSK7hn0mr3R8JiXM1MEV16qzsMjI7Hi9eL9+TJ63+9/9fJ6Y8YmG7Or2rV55Xj", - "fIEw/9+dlDPLXSzXd2NpTaXIDVnhwtXjFnorDfraQoGxyhUXrmh3BDXiotmEpdYp06vem5XuETk1INx/", - "Exs7uSZgDKhipZlR4fIOEymHYBtqUpXIyS8eeZefC2dUnP0F41hKlej9xMslE4bgSiaoAV2VURXd1TGE", - "ihTIh2Hv+LYwEVxd8CWwyCE2fvoE/AQF2v3ofvLbl7DMZUwpdL1uKXiRF16GLkFDZm1neLM4KbxcTgup", - "io1o+MJ8BgoI0LiAT5jbOPLkt97vT1c5sHbkxWlNZEFIR2CcToB3ALvE36u6tgatRHNMmEhYjPqnbiqk", - "SlqRFLNd1TesTZnQ+C6c0obr2XbB9WY85jJlZg9r8cs0saLvPEDPJHd1Lpqnp0WY0MbVfnLqIGKMw/EQ", - "3bV3okvPvb3Pu1YQrZywIX+8y6vb03ra+Lvudr9MQLRerOTz7xC3HASenf6DLr120ubh9mvbuWPv26bS", - "OTv89/Vo9HbNQ6Qb2vElMgFDGcfXPs5/mUZnHz5F/6lgGp1F/zFYvYEOigfQQfWo+HDbvrBzpCApODNR", - "XdkdtSQv2AYSr8RZI+uvlLMEyVVSrxOFGcjwp02SNOk9rLB4SVZAMHWiDCHaJoEu3EC5mb0ozb6OVxtq", - "bONV6Jf/rt1a4oSut8rVzduKQQd/jLHXhQm07eS6Zhxr68iOtKC7n6+bTulW77QZV5AwGm6Bf5no2oJW", - "AtShGdUl7lDJ5eVVqJC6bCoYWeXkJrHgFIjn4LFLOuESfzwm7/QuQUAF9ANygUwh5A6J3PFC77XJfm15", - "+Fqzz2EJ1NxnRec9YkVQBa9qdE2e+KVPq7IOi/r6s2G9wKkf6bYaV4seqqAz8cRSrTNW1Md3LnGIKUsw", - "YfrpiBtr4DrLWoD3hLe2ZBTAdDm90OptA/vG/UXf6DjgZm6g3MxYCkOZvw0VwWPJRLoDb119bl17w4We", - "ztts3s/AlHfLnuGcajLlNE0hIVSTNzc/v6+VII7M7mnV7YQb8ZVb+BBQcdzpQs8q3k383fVlcZBYiRBT", - "4SoFGsegtW9WKRm8U3zrrlqcoz0UVFu4n7hdHfvoaqu93BTbNzaF4nhmxXZvQTJ+6s7xGKeH8fiFZ9WM", - 
"x73IFK1D2xCEOq53z6xRsvGTChlv66s3+YsbLx7o3lJFvbDfagPOl3zXa7W3bHjXO3S0HDpavt2OltN/", - "dEMLuYGcop7xdjjHK1h/W4h3PN/9z3fONLTNc6kKwNUd4uFC4G97iGzF7x0fIttPT+0U2pFntx7IuYxr", - "p3EqlsUNQ9MePrUg3j6EITlGNh3VR/GAt6q9sHG686iGP6ymImYycr9uq0ScHJ5VMTPQ1A6XAPiOu1fh", - "19V50ugfwtagbXVX2Ujj5tZKvz3P5M2Sr+w18iC2nNELqKHOagrp0JivPjtOPDiAhu9iGQYjSgzLQBua", - "5W01rS9OkUDhQUh1e33qxgtOa2iWwy3Cpb4D5Y0qWlv0Z8KJDligSa+olgYxZMVWMbO8cZvplfF6NHr7", - "HKgCVX3RgHHO/1QRmRmTRw+OhjtHduxC0YHnfdJFYWUFOb+orv11WE2xe8gBlBu/tkIgo3tQ2tO6Hx49", - "Oxo61cocBM1ZdBZ9f3R8NHQ7Sc0McQ+wGb9vZL/czlzqrm2tvj4IvkzwD1zF+UPmhUNdJK64bnbzO62D", - "Ns9lgu0T7kQNAhn5PEiVGbhE1E+ooauvQrb5UdenAw/1XXZZD3/wPoFinwyHDRSB2gd/aCfzrhBqRybk", - "3UhlFk/CU8vJalov+uELQljd73bwf04Tcu217/kePw7fd4JaM5OK/QUJMj7+/nEYF8KSl8K4wnAkJbmk", - "KvVaPz59LOlXJRwGe58OHYSTky8KoXXX3gazmkKq+/jTx7K/C2FACcrJDah7UCWCII5i2RJG0A+3D7e9", - "SNsso2pZftJERpKUyYOm2oXvMhu7sL3o6xzoHVC97AuaQV/eg1IsweBfiw69aDDDq328rQCUvR69/M1/", - "9BWDRvi2sGvMeAhVUkBEabA2dTG8elXuDuLnec6X5dNyre8bIzl1JxFX1gTVbiuqN1q0v3JYr3F75Lhe", - "f+04BPb1gf0Q0PYNaL5HbyRJ1aixZ0RjdccIg8B99VVEZxB41fUtwF6+X/bOPo7ve26P7Pv1U9TB9w++", - "/xV8v+pB/zzfLx2jFw04z3ZweDyWW7xdpYRTkVoHpLpYbLm7741e7+Whihf9+XzeR2+3ioOIZeKv9fbz", - "ecfykV09bAA4OPrB0b+coxffFuzp3c6X0amLDpI+LRpN+yfrfbzoSS36FbCtmIoNmbyjh/UrZ/MWx0d2", - "83onyMHRD47+5Ry99L7SuMnJZ/i9bjtILxq4nL3Dyf5Vo2MCa/qgQUJ3RoHgJWrnRL//3Wj9retwiD+4", - "/Tfi9vjm/384w5vA/dDZbfC1SKebFx3rVW4nk2X5UTb2JhpNVh/ldbr8quf9K+f7ktHB3w/+/o34e/C9", - "yJ6ebkNn0AhAI7vGB3vlg+4LLm1CXsgss4KZJXlFDczpMio6cfEZWZ8NBokCmvVTP3rEi+VHsVuOnR9r", - "6N8YfFBZR7YipHHegOZsMAFDB5W8D7cP/xsAAP//D8iZJcZPAAA=", + "H4sIAAAAAAAC/+xbe28bt7L/KsTeCzQBJMty6+bCQP9w0jQxaqeGrTQtEkOgdkcr1lxyy4clNdff/YDD", + "3RX3IUvKSdxzUv0VRyRnfjOcF8nZj1Ess1wKEEZHJx8jHc8go/jn6eXZS6Wkcn8noGPFcsOkiE7cCAE3", + "RBToXAoNJJMJ8IOoF+VK5qAMA6SR6bS9fDSDYnkGWtMU3DrDDIfoJLrQqfvfMnf/0UYxkUb3971IwZ+W", + "KUiik/dI9Wa1pAJarZOTPyA20X0vei6T5TgFcWoTJkdyBAvjANVRUjfYxvk255ImkBAcJ1PGgRhJJkCM", + "osLNnEDisE+lyqiJTqIJE1QtA2mQbVueXoT6GrPEc51Sy936qNeA8NqmKRMp+YnGhY7J2Y/EakjIVKoK", + "B06vadFPTTaq0oseKLNLYQ/o9SyjKYwk/tNWbGpZQkUMYx1TDjVZnx0cN4V9KWJpFU1BF6IaSVIQoKgB", + "wjIciLnUwJeEM3ELiZthZkAMLAzJlcxyQ57MWDoDRe4ot44SXRIFiY0LEuRPSzkzy6ehul4VOMk14qzk", + "FTabgHLyslLANSbiaRvpkLPpksyZmSG0nOXAmYCH7cTrr8NOkO74AT0O23r8EVIFCGY+Y7GHUeqxRMo0", + "ya2eoQrnVCUaZzHBDKPczzlo4iOb1cSl8sHjAZs+Jefy6pQ8OZfz/hUVt+Q0obmhbvRpsfFUJIQZTWKp", + "fIRJnBPMgaUzg4bvhSiEcrZPXi5olnM4IR/Jh4hTA8L0Yyk00wZEvBzwOOs7dH2dLPiH6IQMDw575EMk", + "QLE/9CBnC+B9qky/HD26DxVwjoJ9MUduybOVL/ciASk17A7G3vg3gBit3OSJforuZVkCZD6jxv0PFjG3", + "CZCpklmHis9SIZWzoCmpGyT5YA8Pv43JMIT9poBGLj20LvQ2G3u/HuegumQYNkV4g6ZG5LQMCGGMyEEV", + "4tWA2Iyc+cmXoFpwmDCQeutFPGIKClA0A3ndloeHh+vxJCAk026PceEBuZAK/N/Eaku5i1pAMWYVIaoI", + "RaUoE2uI5nIOilQoHJnEcvTcyZJoo0CkZtaSr5xPrhF1l3Sherexiodscv2eajoFsxzHM4hva8ozykJT", + "e5egXEwklPhlBJehKWrDMoz702bscmHB8sTlYTmdgtDOyKQiM6qyqeUhzGtP9QWCqcBOpORABaIFSNoa", + "uYbCLRUVicyIj29rVOEmd+q73KuaFg4P/m9NuJZTn859kmBSEJrnnK2SnIJyj/3OPDl0I8NaIrsuebZi", + "cyPv5+UG+sTWUQDUMvvmCuBXloBsVwDThgt93+soBqeKZqDRfTXEUiSojFrWunPkQ0l/WmPlM0wSNZ7H", + "zzq5+pmECYLBX2/B9LUn3sV36wKhilbU08do+4nVwedJPh7G7sknk272eGLjWzBNFMOjZ00Yb0uGbouZ", + "+9GBciqnmbTCuA3wNH1xO6unH9wzHzjdUOGU7s/MRdpi5Zxx7kIDEzjU2sILP+05gq4JFiYCyTSMqU3H", + "a5z48KhV1VQi4GJCk2TlujWBfXFFXtfK1KJEVaAhm3Asstau9eWRiBVQXcpdSwgI4NSmZH042Jzsjo7/", + "i3PdPguVmpizpGG9w8Oj77riIc7cKRy+Q9ptro1csynF+NTxQIo5P79oZ5YZ00aqZT30vb8Jo3Uxoyt0", + "0cXYyFsQTZv/PogUdEFGfk6XYtfG3m1C56oW26KiMgpoVmMzpVxDPevTrNu0ltpANq7uYTpwXuMU0nnx", + 
"0osMZLnbf6ugEQOfrUiMgkkdocbIfFx3w/4wWCxz8nOnit26vHkeCJddblvndBifM6oHbO4a0gyEORVL", + "M2MiPWob4EQuOq7GCMegRb4jVCm6JCm7A0GoJpRM5KK8pChiO9pQz/ncb7//9jvxFUDoYc/lYu2tQJv5", + "WVljaA/+U6sKqm/HTOTWdMon530FWnKLidRNJji5IZRZ5izGTIDHSUpyBXdMWu3+SFiMq5kpYllvdWOC", + "sXi4eL14R568/uHdD0fH32MYvD69qNW6F47zGcL8jzuXZ5a7zKFvx9KaSpEP5KAzV/1b6K006CsZBcYq", + "V8q4I4IjqBEXzSYstU6ZXvXerHSPyKkB4f6b2NjJNQFjQBUrzYwKl+WYSDkE21CTqkROfvHIu6KKcEbF", + "2V8wjqVUid5NvFwyYQiuZIIa0FXRVtFdHXqoSIG8P+wNbwoTwdUFXwKLHGLjp0/AT1Cg3Y/uJ799Cctc", + "fpZC16ukghd54WXoEjRk1naGN4ujwsvltJCq2IiGL8xnoIAAjQv4hLmNI09+6/3+dJVxawdsnNZEFiQQ", + "BMbpBHgHsHP8vaqia9BKNEPCRMJi1D91UyFV0oqkmO1qzMPalAmNb8MpbbiebRdcb8ZjLlNmdrAWv0wT", + "K/rOA/RMcldVo3l6WoQJbVylKacOIsY4HA/RXXknOvfc2/u8bb3SygkP5I+3eXVXW08bf9dN8ucJiNaL", + "lXz6jeWGY8ez43/QFdtW2tzftW065ex8t1U6Z4f/vh6NLtc8e7qhLd89EzCUcXxb5PyXaXTy/mP0vwqm", + "0Un0P4PVi+ugeG4dVE+Y9zft60FHCpKCMxPVBeFBS/KCbSDxSpw1sv5KOUuQXCX1OlGYgQx/ekiSJr37", + "FRYvyQoIpk6UIUTbJNCFGyg3sxel2dfxakONbbxB/fJz7Y4UJ3S9jK7u+VYMOvhjjL0qTKBtJ1c141hb", + "R3akBd39WN50Srd6q824gITRcAv8O0jXFrQSoA7NqC5xh0rOzy9ChdRlU8HIKic3iQVnTjx1j13SCZf4", + "wzh5q7cJAiqgH5ALZAohd0jkjhd6p032a8vD15p9Dkug5j4rOu8RK4IqeFWja/LEL31alXVY1NcfKesF", + "Tv1It9G4WvRQBZ2JJ5ZqnbGiPr5xiUNMWYIJ009H3FgD11nWArwnvLEBpACmy+mFVm8a2B/cX/SNjgNu", + "5gbKzYylMJT5u1cRPM1MpDvw1tXn1rU3XOjpvM3m3QxMeZPtGc6pJlNO0xQSQjV5c/3Tu1oJ4shsn1bd", + "TrgRX7mFzw4Vx62uD63i3cTfXp0XB4mVCDEVrlKgcQxa+9aYksFbxTfuqsU52kNBtYX7idvVsY+uttrJ", + "TbFZ5KFQHM+s2OwtSMZP3Toe4/QwHr/wrJrxuBeZolFpE4JQx/VenTVKNn5SIeNNffVD/uLGi+fAS6qo", + "F/Zrbff5nK+IrWaaB14R9/0z+/6Zr7d/5vgf3T5DriGnqGe8Hc7xCtbfFuIdzzf//40zDW3zXKoCcHWH", + "uL8Q+NuePVvxe8tnz/bTUzuFduTZjQdyLuPaaZyKZXHD0LSHjy2IN/dhSI6RTUf1UTwXrmovbNPuPKrh", + "D6upiJmM3K+bKhEnh2dVzAw0tcUlAL4a71T4dfW5NLqVsBFpU91Vtu24ubXSb8czebPkKzubPIgNZ/QC", + "aqizmkI6NOarz44TDw6g4btYhsGIEsMy0IZmeVtN64tTJFB4EFLdXJ+68YLTGprlcItwqe9AeaOK1gb9", + "mXCiAxZo0iuqpUEMWbFVzCyv3WZ6ZbwejS6fA1Wgqu8nMM75nyoiM2Py6N7RcOfIjl0o+v28T7oorKwg", + "p2fVtb8Oqyl2BzmAcuNXVghkdAdKe1qHB88Ohk6zMgdBcxadRN8eDA8O3UZSM0PYA+z87xvZL3czl7pr", + "V6tPHYLPIPz7VnH8kHnhT2eJq62bnw44pYM2z2WCvRruQA0CGfk0SJUZuDzUT6ihq09QNrlR13cK9/VN", + "dkkPf/AugWIfHR42UARaH/yhnczbQqidmJB3I5NZPAhPLSerab3ou88IYXW928H/OU3Ilde+5zt8HL5v", + "BbVmJhX7CxJkPPz2cRgXwpKXwri6cCQlOacq9VofHj+W9KsKDmO9z4YOwtHRZ4XQumpvg1lNIdV1/PFj", + "2d+ZMKAE5eQa1B2oEkEQRrFqCQPo+5v7m16kbZZRtSy/nyIjScrcQVPtoneZjF3UXvR1DvQWqF72Bc2g", + "L+9AKZZg7K9Fh140mOHNPl5WAMpej17+4j/6gkEjfFrYNmbchyopIKI0WJq6GF49KncH8dM858vyZbnW", + "ZI6RnLqDiKtqgmK3FdUb/eBfOKzXuD1yXK8/duwD+/rAvg9ouwY036I3kqTq09gxorG6Y4RB4K76BKMz", + "CLzq+vBgJ98vG3Ufx/c9t0f2/fohau/7e9//Ar5fNbx/mu+XjtGLBpxnWzg8nsotXq5SwqlIrQNS3Su2", + "3N23Rq/38lDFi/58Pu+jt1vFQcQy8bd6u/m8Y/nIrh6+/+8dfe/on8/Ri08LdvRu58vo1EUDSZ8Wfab9", + "o/U+XrSkFu0K2FVMxQOZvKOF9Qtn8xbHR3bzeiPI3tH3jv75HL30vtK4ydEn+L1uO0gvGricvcXJ/lWj", + "YQJr+qA/QndGgeAhautEv/vdaP2pa3+I37v9V+L2+OT/b5zhTeB+6Ow2+Fik082LhvUqt5PJsvwCHFsT", + "jSarb/I6XX7V8v6F833JaO/ve3//Svw9+FxkR0+3oTNoBKCRXeN7vfI99wWXNiEvZJZZwcySvKIG5nQZ", + "FY24+IqsTwaDRAHN+qkfPeDF8oPYLcfGjzX0rw0+qKwjWxHSOG9AczaYgKGDSt77m/t/BQAA//8QlnQv", + "M1AAAA==", } // GetSwagger returns the content of the embedded swagger specification file From 6f9256de251246eb1533cd501f1166c0015530d4 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Fri, 18 Oct 2024 18:56:09 +0000 Subject: [PATCH 2/2] fixup: additional params --- runner/Dockerfile | 2 +- runner/app/pipelines/llm.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/runner/Dockerfile b/runner/Dockerfile index 
36aced7e..60c19dff 100644
--- a/runner/Dockerfile
+++ b/runner/Dockerfile
@@ -49,7 +49,7 @@ ENV HUGGINGFACE_HUB_CACHE=/models
 ENV DIFFUSERS_CACHE=/models
 ENV MODEL_DIR=/models
 # This ensures compatbility with how GPUs are addresses within go-livepeer
-ENV CUDA_DEVICE_ORDER=PCI_BUS_ID
+# ENV CUDA_DEVICE_ORDER=PCI_BUS_ID
 
 COPY app/ /app/app
 COPY images/ /app/images
diff --git a/runner/app/pipelines/llm.py b/runner/app/pipelines/llm.py
index 6f6f5479..ab3538e3 100644
--- a/runner/app/pipelines/llm.py
+++ b/runner/app/pipelines/llm.py
@@ -13,6 +13,7 @@ class LLMPipeline(Pipeline):
 
     def __init__(self, model_id: str):
+        logger.info("Initializing LLM pipeline")
         self.model_id = model_id
         folder_name = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
         base_path = os.path.join(get_model_dir(), folder_name)
@@ -24,13 +25,20 @@ def __init__(self, model_id: str):
             raise ValueError(f"Could not find model files for {model_id}")
 
         use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"
-        max_batch_size = int(os.getenv("MAX_BATCH_SIZE", "4096"))
+        max_num_batched_tokens = int(os.getenv("MAX_NUM_BATCHED_TOKENS", "8192"))
         max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "128"))
-        mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.80"))
+        max_model_len = int(os.getenv("MAX_MODEL_LEN", "8192"))
+        mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.85"))
+
+        if max_num_batched_tokens < max_model_len:
+            logger.info(f"max_num_batched_tokens ({max_num_batched_tokens}) is smaller than max_model_len ({max_model_len}); vLLM would reject sequences longer than max_num_batched_tokens.")
+            logger.info("Raising max_num_batched_tokens to match max_model_len.")
+            max_num_batched_tokens = max_model_len
 
         # Get available GPU memory
         max_memory = get_max_memory()
         logger.info(f"Available GPU memory: {max_memory.gpu_memory}")
+        logger.info(f"Tensor parallel size: {max_memory.num_gpus}")
 
         engine_args = AsyncEngineArgs(
             model=self.local_model_path,
@@ -39,11 +47,12 @@ def __init__(self, model_id: str):
             dtype="auto",  # This specifies BFloat16 precision, TODO: Check GPU capabilities to set best type
             kv_cache_dtype="auto",  # or "fp16" if you want to force it
             tensor_parallel_size=max_memory.num_gpus,
-            max_num_batched_tokens=max_batch_size,
+            max_num_batched_tokens=max_num_batched_tokens,
             gpu_memory_utilization=mem_utilization,
             max_num_seqs=max_num_seqs,
             enforce_eager=False,
             enable_prefix_caching=True,
+            max_model_len=max_model_len
         )
 
         if use_8bit:
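Taken together, the two patches expose vLLM's main tuning knobs through environment variables (MAX_NUM_BATCHED_TOKENS, MAX_NUM_SEQS, MAX_MODEL_LEN, GPU_MEMORY_UTILIZATION, USE_8BIT) and per-request sampling fields (top_p and top_k alongside temperature and max_tokens). A rough end-to-end sketch of driving the pipeline directly is shown below; it assumes the model weights are already cached under MODEL_DIR, and the model id and numeric values are placeholders, not taken from the patches:

    import asyncio
    import os

    # Engine-level knobs, read once in LLMPipeline.__init__ (values illustrative).
    os.environ["MAX_NUM_BATCHED_TOKENS"] = "8192"
    os.environ["MAX_MODEL_LEN"] = "8192"
    os.environ["MAX_NUM_SEQS"] = "128"
    os.environ["GPU_MEMORY_UTILIZATION"] = "0.85"

    from app.pipelines.llm import LLMPipeline

    pipeline = LLMPipeline(model_id="meta-llama/Llama-3.1-8B-Instruct")  # placeholder id

    async def main():
        # __call__ yields text chunks, then a final dict with token usage.
        async for chunk in pipeline(
            prompt="Explain tensor parallelism in two sentences.",
            system_msg="You are a helpful assistant.",
            temperature=0.7,
            max_tokens=256,
            top_p=0.9,
            top_k=40,
        ):
            if isinstance(chunk, dict):
                print(f"\ntokens_used: {chunk['tokens_used']}")
            else:
                print(chunk, end="", flush=True)

    asyncio.run(main())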