diff --git a/.github/workflows/ai-runner-pipelines-docker.yaml b/.github/workflows/ai-runner-pipelines-docker.yaml index 5ca5d73e..9b38adbd 100644 --- a/.github/workflows/ai-runner-pipelines-docker.yaml +++ b/.github/workflows/ai-runner-pipelines-docker.yaml @@ -34,6 +34,7 @@ jobs: - docker/Dockerfile.segment_anything_2 - docker/Dockerfile.text_to_speech - docker/Dockerfile.audio_to_text + - docker/Dockerfile.llm steps: - name: Check out code uses: actions/checkout@v4.1.1 diff --git a/runner/app/pipelines/llm.py b/runner/app/pipelines/llm.py index 7d3440d7..4e240139 100644 --- a/runner/app/pipelines/llm.py +++ b/runner/app/pipelines/llm.py @@ -1,201 +1,309 @@ import asyncio import logging import os -import psutil -from typing import Dict, Any, List, Optional, AsyncGenerator, Union +import time +from dataclasses import dataclass +from typing import Dict, Any, List, AsyncGenerator, Union, Optional -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig -from accelerate import init_empty_weights, load_checkpoint_and_dispatch from app.pipelines.base import Pipeline -from app.pipelines.utils import get_model_dir, get_torch_device -from huggingface_hub import file_download, snapshot_download -from threading import Thread +from app.pipelines.utils import get_model_dir, get_max_memory +from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams +from huggingface_hub import file_download +from transformers import AutoConfig logger = logging.getLogger(__name__) -def get_max_memory(): - num_gpus = torch.cuda.device_count() - gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)} - cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB" - max_memory = {**gpu_memory, "cpu": cpu_memory} +@dataclass +class GenerationConfig: + max_tokens: int = 256 + temperature: float = 0.7 + top_p: float = 1.0 + top_k: int = -1 + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + + def validate(self): + """Validate generation parameters""" + if not 0 <= self.temperature <= 2: + raise ValueError("Temperature must be between 0 and 2") + if not 0 <= self.top_p <= 1: + raise ValueError("Top_p must be between 0 and 1") + if self.max_tokens < 1: + raise ValueError("Max_tokens must be positive") + if not -2.0 <= self.presence_penalty <= 2.0: + raise ValueError("Presence penalty must be between -2.0 and 2.0") + if not -2.0 <= self.frequency_penalty <= 2.0: + raise ValueError("Frequency penalty must be between -2.0 and 2.0") - logger.info(f"Max memory configuration: {max_memory}") - return max_memory +class LLMPipeline(Pipeline): + def __init__( + self, + model_id: str, + ): + """ + Initialize the LLM Pipeline. + + Args: + model_id: The identifier for the model to load + use_8bit: Whether to use 8-bit quantization + max_batch_size: Maximum batch size for inference + max_num_seqs: Maximum number of sequences + mem_utilization: GPU memory utilization target + max_num_batched_tokens: Maximum number of batched tokens + """ + logger.info("Initializing LLM pipeline") -def load_model_8bit(model_id: str, **kwargs): - max_memory = get_max_memory() - - quantization_config = BitsAndBytesConfig( - load_in_8bit=True, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - ) - - tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - - model = AutoModelForCausalLM.from_pretrained( - model_id, - quantization_config=quantization_config, - device_map="auto", - max_memory=max_memory, - offload_folder="offload", - low_cpu_mem_usage=True, - **kwargs - ) - - return tokenizer, model - - -def load_model_fp16(model_id: str, **kwargs): - device = get_torch_device() - max_memory = get_max_memory() - - # Check for fp16 variant - local_model_path = os.path.join( - get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model")) - has_fp16_variant = any(".fp16.safetensors" in fname for _, _, - files in os.walk(local_model_path) for fname in files) - - if device != "cpu" and has_fp16_variant: - logger.info("Loading fp16 variant for %s", model_id) - kwargs["torch_dtype"] = torch.float16 - kwargs["variant"] = "fp16" - elif device != "cpu": - kwargs["torch_dtype"] = torch.bfloat16 - - tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - - config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config(config) - - checkpoint_dir = snapshot_download( - model_id, cache_dir=get_model_dir(), local_files_only=True) + self.model_id = model_id + folder_name = file_download.repo_folder_name( + repo_id=model_id, repo_type="model") + base_path = os.path.join(get_model_dir(), folder_name) - model = load_checkpoint_and_dispatch( - model, - checkpoint_dir, - device_map="auto", - max_memory=max_memory, - # Adjust based on your model architecture - no_split_module_classes=["LlamaDecoderLayer"], - dtype=kwargs.get("torch_dtype", torch.float32), - offload_folder="offload", - offload_state_dict=True, - ) + # Find the actual model path + self.local_model_path = self._find_model_path(base_path) - return tokenizer, model + if not self.local_model_path: + raise ValueError(f"Could not find model files for {model_id}") + use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true" + max_num_batched_tokens = int(os.getenv("MAX_NUM_BATCHED_TOKENS", "8192")) + max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "128")) + max_model_len = int(os.getenv("MAX_MODEL_LEN", "8192")) + mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.85")) + tensor_parallel_size = int(os.getenv("TENSOR_PARALLEL_SIZE", "1")) + pipeline_parallel_size = int(os.getenv("PIPELINE_PARALLEL_SIZE", "1")) + + if max_num_batched_tokens < max_model_len: + max_num_batched_tokens = max_model_len + logger.info( + f"max_num_batched_tokens ({max_num_batched_tokens}) is smaller than max_model_len ({max_model_len}). This effectively limits the maximum sequence length to max_num_batched_tokens and makes vLLM reject longer sequences.") + logger.info(f"setting 'max_model_len' to equal 'max_num_batched_tokens'") -class LLMPipeline(Pipeline): - def __init__(self, model_id: str): - self.model_id = model_id - kwargs = { - "cache_dir": get_model_dir(), - "local_files_only": True, - } - self.device = get_torch_device() - - # Generate the correct folder name - folder_path = file_download.repo_folder_name( - repo_id=model_id, repo_type="model") - self.local_model_path = os.path.join(get_model_dir(), folder_path) - self.checkpoint_dir = snapshot_download( - model_id, cache_dir=get_model_dir(), local_files_only=True) + # Load config to check model compatibility + try: + config = AutoConfig.from_pretrained(self.local_model_path) + num_heads = config.num_attention_heads + num_layers = config.num_hidden_layers + logger.info( + f"Model has {num_heads} attention heads and {num_layers} layers") + + # Validate tensor parallelism + if num_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({num_heads}) must be divisible " + f"by tensor parallel size ({tensor_parallel_size})." + ) + + # Validate pipeline parallelism + if num_layers < pipeline_parallel_size: + raise ValueError( + f"Pipeline parallel size ({pipeline_parallel_size}) cannot be larger " + f"than number of layers ({num_layers})." + ) + + # Validate total GPU requirement + total_gpus_needed = tensor_parallel_size * pipeline_parallel_size + max_memory = get_max_memory() + if total_gpus_needed > max_memory.num_gpus: + raise ValueError( + f"Total GPUs needed ({total_gpus_needed}) exceeds available GPUs " + f"({max_memory.num_gpus}). Reduce tensor_parallel_size ({tensor_parallel_size}) " + f"or pipeline_parallel_size ({pipeline_parallel_size})." + ) + + logger.info(f"Using tensor parallel size: {tensor_parallel_size}") + logger.info(f"Using pipeline parallel size: {pipeline_parallel_size}") + logger.info(f"Total GPUs used: {total_gpus_needed}") - logger.info(f"Local model path: {self.local_model_path}") - logger.info(f"Directory contents: {os.listdir(self.local_model_path)}") + except Exception as e: + logger.error(f"Error in parallelism configuration: {e}") + raise - use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true" + engine_args = AsyncEngineArgs( + model=self.local_model_path, + tokenizer=self.local_model_path, + trust_remote_code=True, + dtype="auto", # This specifies BFloat16 precision, TODO: Check GPU capabilities to set best type + kv_cache_dtype="auto", # or "fp16" if you want to force it + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_num_batched_tokens=max_num_batched_tokens, + gpu_memory_utilization=mem_utilization, + max_num_seqs=max_num_seqs, + enforce_eager=False, + enable_prefix_caching=True, + max_model_len=max_model_len + ) if use_8bit: + engine_args.quantization = "bitsandbytes" + engine_args.load_format = "bitsandbytes" logger.info("Using 8-bit quantization") - self.tokenizer, self.model = load_model_8bit(model_id, **kwargs) else: - logger.info("Using fp16/bf16 precision") - self.tokenizer, self.model = load_model_fp16(model_id, **kwargs) + logger.info("Using BFloat16 precision") + + self.engine_args = engine_args + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + + logger.info(f"Model loaded: {self.model_id}") + logger.info(f"Using GPU memory utilization: {mem_utilization}") + self.engine.start_background_loop() + + @staticmethod + def _get_model_dir() -> str: + """Get the model directory from environment or default""" + return os.getenv("MODEL_DIR", "/models") + + def validate_messages(self, messages: List[Dict[str, str]]): + """Validate message format""" + if not messages: + raise ValueError("Messages cannot be empty") + + for msg in messages: + if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: + raise ValueError(f"Invalid message format: {msg}") + if msg['role'] not in ['system', 'user', 'assistant']: + raise ValueError(f"Invalid role in message: {msg['role']}") + + async def generate( + self, + messages: List[Dict[str, str]], + generation_config: Optional[GenerationConfig] = None, + ) -> AsyncGenerator[Dict[str, Any], None]: + """Internal generation method""" + start_time = time.time() + config = generation_config or GenerationConfig() + tokenizer = await self.engine.get_tokenizer() - logger.info( - f"Model loaded and distributed. Device map: {self.model.hf_device_map}" + try: + full_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False + ) + except Exception as e: + logger.error(f"Error applying chat template: {e}") + raise + + sampling_params = SamplingParams( + temperature=config.temperature, + max_tokens=config.max_tokens, + top_p=config.top_p, + top_k=config.top_k, + presence_penalty=config.presence_penalty, + frequency_penalty=config.frequency_penalty, ) - # Set up generation config - self.generation_config = self.model.generation_config + request_id = f"chatcmpl-{int(time.time())}" - self.terminators = [ - self.tokenizer.eos_token_id, - self.tokenizer.convert_tokens_to_ids("<|eot_id|>") - ] + results_generator = self.engine.generate( + prompt=full_prompt, sampling_params=sampling_params, request_id=request_id) - # Optional: Add optimizations - sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" - if sfast_enabled: - logger.info( - "LLMPipeline will be dynamically compiled with stable-fast for %s", - model_id, - ) - from app.pipelines.optim.sfast import compile_model - self.model = compile_model(self.model) - - async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: - conversation = [] - if system_msg: - conversation.append({"role": "system", "content": system_msg}) - if history: - conversation.extend(history) - conversation.append({"role": "user", "content": prompt}) - - input_ids = self.tokenizer.apply_chat_template( - conversation, return_tensors="pt").to(self.model.device) - attention_mask = torch.ones_like(input_ids) - - max_new_tokens = kwargs.get("max_tokens", 256) - temperature = kwargs.get("temperature", 0.7) - - streamer = TextIteratorStreamer( - self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) - - generate_kwargs = self.generation_config.to_dict() - generate_kwargs.update({ - "input_ids": input_ids, - "attention_mask": attention_mask, - "streamer": streamer, - "max_new_tokens": max_new_tokens, - "do_sample": temperature > 0, - "temperature": temperature, - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.eos_token_id, - }) - - thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs) - thread.start() + input_tokens = len(tokenizer.encode(full_prompt)) + if input_tokens > self.engine_args.max_model_len: + raise ValueError( + f"Input sequence length ({input_tokens}) exceeds maximum allowed ({self.engine.engine_args.max_model_len})") total_tokens = 0 + current_response = "" + first_token_time = None + try: - for text in streamer: - total_tokens += 1 - yield text - await asyncio.sleep(0) # Allow other tasks to run + async for output in results_generator: + if output.outputs: + if first_token_time is None: + first_token_time = time.time() + + generated_text = output.outputs[0].text + delta = generated_text[len(current_response):] + current_response = generated_text + total_tokens += len(tokenizer.encode(delta)) + + yield { + "choices": [{ + "delta": {"content": delta}, + "finish_reason": None + }], + "created": int(time.time()), + "model": self.model_id, + "id": request_id + } + + await asyncio.sleep(0) + + # Final message + end_time = time.time() + duration = end_time - start_time + logger.info(f"Generation completed in {duration:.2f}s") + logger.info( + f" Time to first token: {(first_token_time - start_time):.2f} seconds") + logger.info(f" Total tokens: {total_tokens}") + logger.info(f" Prompt tokens: {input_tokens}") + logger.info(f" Generated tokens: {total_tokens}") + generation_time = end_time - first_token_time if first_token_time else 0 + logger.info(f" Tokens per second: {total_tokens / generation_time:.2f}") + yield { + "choices": [{ + "delta": {"content": ""}, + "finish_reason": "stop" + }], + "created": int(time.time()), + "model": self.model_id, + "id": request_id, + "usage": { + "prompt_tokens": input_tokens, + "completion_tokens": total_tokens, + "total_tokens": input_tokens + total_tokens + } + } + except Exception as e: - logger.error(f"Error during streaming: {str(e)}") + if "CUDA out of memory" in str(e): + logger.error( + "GPU memory exhausted, consider reducing batch size or model parameters") + elif "tokenizer" in str(e).lower(): + logger.error("Tokenizer error, check input text format") + else: + logger.error(f"Error generating response: {e}") raise - input_length = input_ids.size(1) - yield {"tokens_used": input_length + total_tokens} + async def __call__( + self, + messages: List[Dict[str, str]], + **kwargs + ) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: + """ + Generate responses for messages. + + Args: + messages: List of message dictionaries in OpenAI format + **kwargs: Generation parameters + """ + logger.debug(f"Generating response for messages: {messages}") + start_time = time.time() - def model_generate_wrapper(self, **kwargs): try: - logger.debug("Entering model.generate") - with torch.cuda.amp.autocast(): # Use automatic mixed precision - self.model.generate(**kwargs) - logger.debug("Exiting model.generate") + # Validate inputs + self.validate_messages(messages) + config = GenerationConfig(**kwargs) + config.validate() + + async for response in self.generate(messages, config): + yield response + except Exception as e: - logger.error(f"Error in model.generate: {str(e)}", exc_info=True) + logger.error(f"Error in pipeline: {e}") raise def __str__(self): return f"LLMPipeline(model_id={self.model_id})" + + def _find_model_path(self, base_path): + # Check if the model files are directly in the base path + if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)): + return base_path + + # If not, look in subdirectories + for root, dirs, files in os.walk(base_path): + if any(file.endswith('.bin') or file.endswith('.safetensors') for file in files): + return root diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py index 6155d0ff..2bafb448 100644 --- a/runner/app/pipelines/utils/__init__.py +++ b/runner/app/pipelines/utils/__init__.py @@ -14,4 +14,5 @@ is_turbo_model, split_prompt, validate_torch_device, + get_max_memory, ) diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py index 714aeb4d..7379afb6 100644 --- a/runner/app/pipelines/utils/utils.py +++ b/runner/app/pipelines/utils/utils.py @@ -4,6 +4,7 @@ import logging import os import re +import psutil from pathlib import Path from typing import Any, Dict, List, Optional @@ -365,3 +366,25 @@ def enable_loras(self) -> None: if not self.loras_enabled: self.pipeline.enable_lora() self.loras_enabled = True + + +class MemoryInfo: + def __init__(self, gpu_memory, cpu_memory, num_gpus): + self.gpu_memory = gpu_memory + self.cpu_memory = cpu_memory + self.num_gpus = num_gpus + + def __repr__(self): + return f"" + + +def get_max_memory() -> MemoryInfo: + num_gpus = torch.cuda.device_count() + gpu_memory = { + i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)} + cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB" + + memory_info = MemoryInfo(gpu_memory=gpu_memory, + cpu_memory=cpu_memory, num_gpus=num_gpus) + + return memory_info diff --git a/runner/app/routes/llm.py b/runner/app/routes/llm.py index 1080280c..076efef4 100644 --- a/runner/app/routes/llm.py +++ b/runner/app/routes/llm.py @@ -1,12 +1,12 @@ import logging import os -from typing import Annotated -from fastapi import APIRouter, Depends, Form, status +import time +from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.utils import HTTPError, LLMResponse, http_error +from app.routes.utils import HTTPError, LLMRequest, LLMResponse, http_error import json router = APIRouter() @@ -33,13 +33,7 @@ ) @router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False) async def llm( - prompt: Annotated[str, Form()], - model_id: Annotated[str, Form()] = "", - system_msg: Annotated[str, Form()] = "", - temperature: Annotated[float, Form()] = 0.7, - max_tokens: Annotated[int, Form()] = 256, - history: Annotated[str, Form()] = "[]", # We'll parse this as JSON - stream: Annotated[bool, Form()] = False, + request: LLMRequest, pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -52,50 +46,50 @@ async def llm( content=http_error("Invalid bearer token"), ) - if model_id != "" and model_id != pipeline.model_id: + if request.model != "" and request.model != pipeline.model_id: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=http_error( - f"pipeline configured with {pipeline.model_id} but called with " - f"{model_id}" + f"pipeline configured with {pipeline.model_id} but called with {request.model}" ), ) try: - history_list = json.loads(history) - if not isinstance(history_list, list): - raise ValueError("History must be a JSON array") - generator = pipeline( - prompt=prompt, - history=history_list, - system_msg=system_msg if system_msg else None, - temperature=temperature, - max_tokens=max_tokens + messages=[msg.dict() for msg in request.messages], + temperature=request.temperature, + max_tokens=request.max_tokens, + top_p=request.top_p, + top_k=request.top_k ) - if stream: - return StreamingResponse(stream_generator(generator), media_type="text/event-stream") + if request.stream: + return StreamingResponse( + stream_generator(generator), + media_type="text/event-stream" + ) else: full_response = "" + last_chunk = None + async for chunk in generator: if isinstance(chunk, dict): - tokens_used = chunk["tokens_used"] - break - full_response += chunk + if "choices" in chunk: + if "delta" in chunk["choices"][0]: + full_response += chunk["choices"][0]["delta"].get( + "content", "") + last_chunk = chunk - return LLMResponse(response=full_response, tokens_used=tokens_used) + usage = last_chunk.get("usage", {}) + + return LLMResponse( + response=full_response, + tokens_used=usage.get("total_tokens", 0), + id=last_chunk.get("id", ""), + model=last_chunk.get("model", pipeline.model_id), + created=last_chunk.get("created", int(time.time())) + ) - except json.JSONDecodeError: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": "Invalid JSON format for history"} - ) - except ValueError as ve: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(ve)} - ) except Exception as e: logger.error(f"LLM processing error: {str(e)}") return JSONResponse( @@ -107,11 +101,12 @@ async def llm( async def stream_generator(generator): try: async for chunk in generator: - if isinstance(chunk, dict): # This is the final result - yield f"data: {json.dumps(chunk)}\n\n" - break - else: - yield f"data: {json.dumps({'chunk': chunk})}\n\n" + if isinstance(chunk, dict): + if "choices" in chunk: + # Regular streaming chunk or final chunk + yield f"data: {json.dumps(chunk)}\n\n" + if chunk["choices"][0].get("finish_reason") == "stop": + break yield "data: [DONE]\n\n" except Exception as e: logger.error(f"Streaming error: {str(e)}") diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index 8868989b..6f8271c0 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -72,9 +72,27 @@ class TextResponse(BaseModel): chunks: List[Chunk] = Field(..., description="The generated text chunks.") +class LLMMessage(BaseModel): + role: str + content: str + + +class LLMRequest(BaseModel): + messages: List[LLMMessage] + model: str = "" + temperature: float = 0.7 + max_tokens: int = 256 + top_p: float = 1.0 + top_k: int = -1 + stream: bool = False + + class LLMResponse(BaseModel): response: str tokens_used: int + id: str + model: str + created: int class ImageToTextResponse(BaseModel): @@ -101,6 +119,7 @@ class LiveVideoToVideoResponse(BaseModel): description="URL for subscribing to events for pipeline status and logs", ) + class APIError(BaseModel): """API error response model.""" diff --git a/runner/docker/Dockerfile.llm b/runner/docker/Dockerfile.llm new file mode 100644 index 00000000..5ab4b967 --- /dev/null +++ b/runner/docker/Dockerfile.llm @@ -0,0 +1,74 @@ +# Based on https://github.com/huggingface/api-inference-community/blob/main/docker_images/diffusers/Dockerfile + +FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu20.04 +LABEL maintainer="Yondon Fu " + +# Add any system dependency here +# RUN apt-get update -y && apt-get install libXXX -y + +ENV DEBIAN_FRONTEND=noninteractive + +# Install prerequisites +RUN apt-get update && \ + apt-get install -y build-essential libssl-dev zlib1g-dev libbz2-dev \ + libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \ + xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git \ + ffmpeg + +# Install pyenv +RUN curl https://pyenv.run | bash + +# Set environment variables for pyenv +ENV PYENV_ROOT /root/.pyenv +ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH + +# Install your desired Python version +ARG PYTHON_VERSION=3.11 +RUN pyenv install $PYTHON_VERSION && \ + pyenv global $PYTHON_VERSION && \ + pyenv rehash + +# Upgrade pip and install your desired packages +ARG PIP_VERSION=24.2 +RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0 && \ + pip install --no-cache-dir torch==2.4.0 torchvision torchaudio pip-tools + +WORKDIR /app + +COPY ./requirements.llm.in /app +RUN pip-compile requirements.llm.in -o requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +# Most DL models are quite large in terms of memory, using workers is a HUGE +# slowdown because of the fork and GIL with python. +# Using multiple pods seems like a better default strategy. +# Feel free to override if it does not make sense for your library. +ARG max_workers=1 +ENV MAX_WORKERS=$max_workers +ENV HUGGINGFACE_HUB_CACHE=/models +ENV DIFFUSERS_CACHE=/models +ENV MODEL_DIR=/models +# This ensures compatbility with how GPUs are addressed within go-livepeer +ENV CUDA_DEVICE_ORDER=PCI_BUS_ID + +# vLLM configuration +ENV USE_8BIT=false +ENV MAX_NUM_BATCHED_TOKENS=8192 +ENV MAX_NUM_SEQS=128 +ENV MAX_MODEL_LEN=8192 +ENV GPU_MEMORY_UTILIZATION=0.85 +ENV TENSOR_PARALLEL_SIZE=1 +ENV PIPELINE_PARALLEL_SIZE=1 +# To use multiple GPUs, set TENSOR_PARALLEL_SIZE and PIPELINE_PARALLEL_SIZE +# Total GPUs used = TENSOR_PARALLEL_SIZE × PIPELINE_PARALLEL_SIZE +# Example for 4 GPUs: +# - Option 1: TENSOR_PARALLEL_SIZE=2, PIPELINE_PARALLEL_SIZE=2 +# - Option 2: TENSOR_PARALLEL_SIZE=4, PIPELINE_PARALLEL_SIZE=1 +# - Option 3: TENSOR_PARALLEL_SIZE=1, PIPELINE_PARALLEL_SIZE=4 + +COPY app/ /app/app +COPY images/ /app/images +COPY bench.py /app/bench.py +COPY example_data/ /app/example_data + +CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"] diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index ec565f3b..c6ae67cb 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -319,9 +319,9 @@ paths: operationId: genLLM requestBody: content: - application/x-www-form-urlencoded: + application/json: schema: - $ref: '#/components/schemas/Body_genLLM' + $ref: '#/components/schemas/LLMRequest' required: true responses: '200': @@ -762,40 +762,6 @@ components: - image - model_id title: Body_genImageToVideo - Body_genLLM: - properties: - prompt: - type: string - title: Prompt - model_id: - type: string - title: Model Id - default: '' - system_msg: - type: string - title: System Msg - default: '' - temperature: - type: number - title: Temperature - default: 0.7 - max_tokens: - type: integer - title: Max Tokens - default: 256 - history: - type: string - title: History - default: '[]' - stream: - type: boolean - title: Stream - default: false - type: object - required: - - prompt - - model_id - title: Body_genLLM Body_genSegmentAnything2: properties: image: @@ -1033,6 +999,54 @@ components: - text title: ImageToTextResponse description: Response model for text generation. + LLMMessage: + properties: + role: + type: string + title: Role + content: + type: string + title: Content + type: object + required: + - role + - content + title: LLMMessage + LLMRequest: + properties: + messages: + items: + $ref: '#/components/schemas/LLMMessage' + type: array + title: Messages + model: + type: string + title: Model + default: '' + temperature: + type: number + title: Temperature + default: 0.7 + max_tokens: + type: integer + title: Max Tokens + default: 256 + top_p: + type: number + title: Top P + default: 1.0 + top_k: + type: integer + title: Top K + default: -1 + stream: + type: boolean + title: Stream + default: false + type: object + required: + - messages + title: LLMRequest LLMResponse: properties: response: @@ -1041,10 +1055,22 @@ components: tokens_used: type: integer title: Tokens Used + id: + type: string + title: Id + model: + type: string + title: Model + created: + type: integer + title: Created type: object required: - response - tokens_used + - id + - model + - created title: LLMResponse LiveVideoToVideoParams: properties: diff --git a/runner/openapi.yaml b/runner/openapi.yaml index f5a19f1f..fe6e13a0 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -319,9 +319,9 @@ paths: operationId: genLLM requestBody: content: - application/x-www-form-urlencoded: + application/json: schema: - $ref: '#/components/schemas/Body_genLLM' + $ref: '#/components/schemas/LLMRequest' required: true responses: '200': @@ -796,39 +796,6 @@ components: required: - image title: Body_genImageToVideo - Body_genLLM: - properties: - prompt: - type: string - title: Prompt - model_id: - type: string - title: Model Id - default: '' - system_msg: - type: string - title: System Msg - default: '' - temperature: - type: number - title: Temperature - default: 0.7 - max_tokens: - type: integer - title: Max Tokens - default: 256 - history: - type: string - title: History - default: '[]' - stream: - type: boolean - title: Stream - default: false - type: object - required: - - prompt - title: Body_genLLM Body_genSegmentAnything2: properties: image: @@ -1178,6 +1145,54 @@ components: - text title: ImageToTextResponse description: Response model for text generation. + LLMMessage: + properties: + role: + type: string + title: Role + content: + type: string + title: Content + type: object + required: + - role + - content + title: LLMMessage + LLMRequest: + properties: + messages: + items: + $ref: '#/components/schemas/LLMMessage' + type: array + title: Messages + model: + type: string + title: Model + default: '' + temperature: + type: number + title: Temperature + default: 0.7 + max_tokens: + type: integer + title: Max Tokens + default: 256 + top_p: + type: number + title: Top P + default: 1.0 + top_k: + type: integer + title: Top K + default: -1 + stream: + type: boolean + title: Stream + default: false + type: object + required: + - messages + title: LLMRequest LLMResponse: properties: response: @@ -1186,10 +1201,22 @@ components: tokens_used: type: integer title: Tokens Used + id: + type: string + title: Id + model: + type: string + title: Model + created: + type: integer + title: Created type: object required: - response - tokens_used + - id + - model + - created title: LLMResponse LiveVideoToVideoParams: properties: diff --git a/runner/requirements.llm.in b/runner/requirements.llm.in new file mode 100644 index 00000000..bd148088 --- /dev/null +++ b/runner/requirements.llm.in @@ -0,0 +1,22 @@ +vllm==0.6.5 +diffusers +accelerate +transformers +fastapi +pydantic +Pillow +python-multipart +uvicorn +huggingface_hub +xformers +triton +peft +deepcache +safetensors +scipy +numpy +av +sentencepiece +protobuf +bitsandbytes +psutil diff --git a/runner/requirements.txt b/runner/requirements.txt index 17237444..12107f9b 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -21,4 +21,4 @@ bitsandbytes==0.43.3 psutil==6.0.0 PyYAML==6.0.2 nvidia-ml-py==12.560.30 -pynvml==12.0.0 +pynvml==12.0.0 \ No newline at end of file diff --git a/worker/docker.go b/worker/docker.go index 2bb965ae..6ece9586 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -62,6 +62,7 @@ var pipelineToImage = map[string]string{ "segment-anything-2": "livepeer/ai-runner:segment-anything-2", "text-to-speech": "livepeer/ai-runner:text-to-speech", "audio-to-text": "livepeer/ai-runner:audio-to-text", + "llm": "livepeer/ai-runner:llm", } var livePipelineToImage = map[string]string{ diff --git a/worker/multipart.go b/worker/multipart.go index 25b00341..f8b93844 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -259,56 +259,6 @@ func NewAudioToTextMultipartWriter(w io.Writer, req GenAudioToTextMultipartReque return mw, nil } -func NewLLMMultipartWriter(w io.Writer, req BodyGenLLM) (*multipart.Writer, error) { - mw := multipart.NewWriter(w) - - if err := mw.WriteField("prompt", req.Prompt); err != nil { - return nil, fmt.Errorf("failed to write prompt field: %w", err) - } - - if req.History != nil { - if err := mw.WriteField("history", *req.History); err != nil { - return nil, fmt.Errorf("failed to write history field: %w", err) - } - } - - if req.ModelId != nil { - if err := mw.WriteField("model_id", *req.ModelId); err != nil { - return nil, fmt.Errorf("failed to write model_id field: %w", err) - } - } - - if req.SystemMsg != nil { - if err := mw.WriteField("system_msg", *req.SystemMsg); err != nil { - return nil, fmt.Errorf("failed to write system_msg field: %w", err) - } - } - - if req.Temperature != nil { - if err := mw.WriteField("temperature", fmt.Sprintf("%f", *req.Temperature)); err != nil { - return nil, fmt.Errorf("failed to write temperature field: %w", err) - } - } - - if req.MaxTokens != nil { - if err := mw.WriteField("max_tokens", strconv.Itoa(*req.MaxTokens)); err != nil { - return nil, fmt.Errorf("failed to write max_tokens field: %w", err) - } - } - - if req.Stream != nil { - if err := mw.WriteField("stream", fmt.Sprintf("%v", *req.Stream)); err != nil { - return nil, fmt.Errorf("failed to write stream field: %w", err) - } - } - - if err := mw.Close(); err != nil { - return nil, fmt.Errorf("failed to close multipart writer: %w", err) - } - - return mw, nil -} - func NewSegmentAnything2MultipartWriter(w io.Writer, req GenSegmentAnything2MultipartRequestBody) (*multipart.Writer, error) { mw := multipart.NewWriter(w) writer, err := mw.CreateFormFile("image", req.Image.Filename()) diff --git a/worker/runner.gen.go b/worker/runner.gen.go index adaa6f1d..1a06c48d 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -189,17 +189,6 @@ type BodyGenImageToVideo struct { Width *int `json:"width,omitempty"` } -// BodyGenLLM defines model for Body_genLLM. -type BodyGenLLM struct { - History *string `json:"history,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - ModelId *string `json:"model_id,omitempty"` - Prompt string `json:"prompt"` - Stream *bool `json:"stream,omitempty"` - SystemMsg *string `json:"system_msg,omitempty"` - Temperature *float32 `json:"temperature,omitempty"` -} - // BodyGenSegmentAnything2 defines model for Body_genSegmentAnything2. type BodyGenSegmentAnything2 struct { // Box A length 4 array given as a box prompt to the model, in XYXY format. @@ -326,8 +315,28 @@ type ImageToTextResponse struct { Text string `json:"text"` } +// LLMMessage defines model for LLMMessage. +type LLMMessage struct { + Content string `json:"content"` + Role string `json:"role"` +} + +// LLMRequest defines model for LLMRequest. +type LLMRequest struct { + MaxTokens *int `json:"max_tokens,omitempty"` + Messages []LLMMessage `json:"messages"` + Model *string `json:"model,omitempty"` + Stream *bool `json:"stream,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopK *int `json:"top_k,omitempty"` + TopP *float32 `json:"top_p,omitempty"` +} + // LLMResponse defines model for LLMResponse. type LLMResponse struct { + Created int `json:"created"` + Id string `json:"id"` + Model string `json:"model"` Response string `json:"response"` TokensUsed int `json:"tokens_used"` } @@ -497,8 +506,8 @@ type GenImageToVideoMultipartRequestBody = BodyGenImageToVideo // GenLiveVideoToVideoJSONRequestBody defines body for GenLiveVideoToVideo for application/json ContentType. type GenLiveVideoToVideoJSONRequestBody = LiveVideoToVideoParams -// GenLLMFormdataRequestBody defines body for GenLLM for application/x-www-form-urlencoded ContentType. -type GenLLMFormdataRequestBody = BodyGenLLM +// GenLLMJSONRequestBody defines body for GenLLM for application/json ContentType. +type GenLLMJSONRequestBody = LLMRequest // GenSegmentAnything2MultipartRequestBody defines body for GenSegmentAnything2 for multipart/form-data ContentType. type GenSegmentAnything2MultipartRequestBody = BodyGenSegmentAnything2 @@ -679,7 +688,7 @@ type ClientInterface interface { // GenLLMWithBody request with any body GenLLMWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) - GenLLMWithFormdataBody(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + GenLLM(ctx context.Context, body GenLLMJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) // GenSegmentAnything2WithBody request with any body GenSegmentAnything2WithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -830,8 +839,8 @@ func (c *Client) GenLLMWithBody(ctx context.Context, contentType string, body io return c.Client.Do(req) } -func (c *Client) GenLLMWithFormdataBody(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { - req, err := NewGenLLMRequestWithFormdataBody(c.Server, body) +func (c *Client) GenLLM(ctx context.Context, body GenLLMJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGenLLMRequest(c.Server, body) if err != nil { return nil, err } @@ -1180,15 +1189,15 @@ func NewGenLiveVideoToVideoRequestWithBody(server string, contentType string, bo return req, nil } -// NewGenLLMRequestWithFormdataBody calls the generic GenLLM builder with application/x-www-form-urlencoded body -func NewGenLLMRequestWithFormdataBody(server string, body GenLLMFormdataRequestBody) (*http.Request, error) { +// NewGenLLMRequest calls the generic GenLLM builder with application/json body +func NewGenLLMRequest(server string, body GenLLMJSONRequestBody) (*http.Request, error) { var bodyReader io.Reader - bodyStr, err := runtime.MarshalForm(body, nil) + buf, err := json.Marshal(body) if err != nil { return nil, err } - bodyReader = strings.NewReader(bodyStr.Encode()) - return NewGenLLMRequestWithBody(server, "application/x-www-form-urlencoded", bodyReader) + bodyReader = bytes.NewReader(buf) + return NewGenLLMRequestWithBody(server, "application/json", bodyReader) } // NewGenLLMRequestWithBody generates requests for GenLLM with any type of body @@ -1433,7 +1442,7 @@ type ClientWithResponsesInterface interface { // GenLLMWithBodyWithResponse request with any body GenLLMWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) - GenLLMWithFormdataBodyWithResponse(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) + GenLLMWithResponse(ctx context.Context, body GenLLMJSONRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) // GenSegmentAnything2WithBodyWithResponse request with any body GenSegmentAnything2WithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenSegmentAnything2Response, error) @@ -1905,8 +1914,8 @@ func (c *ClientWithResponses) GenLLMWithBodyWithResponse(ctx context.Context, co return ParseGenLLMResponse(rsp) } -func (c *ClientWithResponses) GenLLMWithFormdataBodyWithResponse(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) { - rsp, err := c.GenLLMWithFormdataBody(ctx, body, reqEditors...) +func (c *ClientWithResponses) GenLLMWithResponse(ctx context.Context, body GenLLMJSONRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) { + rsp, err := c.GenLLM(ctx, body, reqEditors...) if err != nil { return nil, err } @@ -3195,92 +3204,87 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xde3Pbtpb/KhjuzjiZkWTZbdpdz9w/nDRNPOukHj9u20k9uhB5RCEmAV4AtKxm/d13", - "cMAHSIJ6uLbbTfVXZBKP8/ydA+AQ+RKEIs0EB65VcPQlUOEcUoo/j89O3koppPkdgQolyzQTPDgybwiY", - "V0SCygRXQFIRQTIKBkEmRQZSM8AxUhV3u1/OoeieglI0BtNPM51AcBR8ULH5a5mZP5SWjMfB/f0gkPDv", - "nEmIgqNPOOp13aUitOonpp8h1MH9IDjOIybOCyq7pJw36CczIQk1PUgMHCQ1rbpMYQv8kSQ/zYKjT1+C", - "/5QwC46C/9ivpblfiHL/A0SMXp2fBvfXA48kipkgsjOPOtza6Vx+Gzx5mH4touUkBo4NL8Ul3GlDbg8X", - "TZKuskTQqKSGzFgCRAsyBaIl5ablFCIjk5mQKdXBUTBlnMpl0KKvq8RBkIKmEdXUzjqjeWL6f7kP2nI5", - "jiJmftKEfBZTwridjAle0JJRpSAyf+g5kIxlkDDetKNyLh8dRtkTFjXp6FDxPo9jxmPyIw1LAzn5geRm", - "YmMopTyy0kqqqW3TyDe1BJ1LPtEsBaVpmqkmDVrm0KHjHPuQuo+dft5QCdFwp0fkIs8yIY013dIkB3VE", - "9hRwDTyEvQHZWwgZ7Q2IMXNiiSJTIRKgnLzYM5PvmXd7M5oo2Hs5Ij9YyghTpHj9oh7v5ahsSVKgXBEu", - "HCJHxWzFO/N7OKWotbqNI7WCy8taMutgoOMYPrtf4R4nKY3hUuA/Xf+IcxZRHsJEhTSBhpq+H71q6+gt", - "D0UuaQyqsBRdYQgQluKLMBEKkiVJGL+pjdfojWRSpJkmL+YsnoMsdEdSuiQSojwshiD/zmnC9PKlK7d3", - "BZ3kAums+OV5OgVp+GUlgz2ebsfWwlDOZkuyYHre8at+d7fy89g6jjtZIceDrhx/gFgCErOYs9CSUSOk", - "pZQpkuVqjiJcUBkpbMU404wmts2oTR9ZL6ZESKrWQMIxORXnx+TFqVgMzym/IccRzTQi08tC8ZRHhGlF", - "QiFtdIyMly2AxXONjmuZcAIMeXtH0yyBI/KF/BYkVAPXw1BwxZRxtOV+EqZDQ91QRXfJb8ERORiNB+S3", - "gINkn9V+xu4gGVKph+Xbw3tXAKfI2JPhYIefDaGQQ0w1u4WJNf41RFzWbvJCvUT3ylkEZDGn2vwFd2GS", - "R0BmUqQeEZ/EXEhjQTPSNEjyWz4efxOSA5fsjwVp5MyS5qM+TyfWrycZSB8PB20WPqKpETErAcHFiAxk", - "wV6DkDwlJ7bxGcgOOYxriK31Ij18BhKQNQ2t0HIwHvfTEwEXTBkdY8cR+SAk2N8kVzlNDGoBRcwqIKqA", - "opKVaa6JSsQCJKmoMMNEeYKeO12aeAM81vMOf2V7coFU+7hzxbuJVayyyX6dKjoDvZyEcwhvGsIzoa8t", - "vTOQBhNNIMVuBLuhKSrNUsT9WRu7DCzkSWRSGDGbAVfGyIQkcyrTWZ64ZF7YUd8gMRWxRbRGagGirkQu", - "oHBLSXkkUmLxrUcUprFX3qWuGlIYj/6rB67FzKYidZpGsyxhdZCTUOrYaubF2Lw5aASyi3LODja34n5W", - "KtAGNk8C0IjsG2YA74yEWNhNBELBtRRJHciARy25/Lcvsc9AhsC1YdfIR2iaFC5FtRPd3tjhP4ImSgvz", - "NsuSJeOxK5uiUR3F3vLIF8MKWjnoSWgijyGH8dgTe8e+2KtBpoyDInOxIGkezsu4pQWhSrGYlwp1RyeM", - "Z7lWHno5aMNf3bI3+K7IEoJP341eDcjBeDS+Dv76eVcr//kT8q6/YSJD1c2kR9wfqLqpRR2VNl54oISY", - "Ca6Mh1JeN5uxJCGMWzZ5RhnXhnlN1c1qpeBsvZrZJVx/n4Qr+PTNeEAOX3Ux66+ed9Uy36Vdu7SrL+1y", - "0GxNBlZmVusTMf9O5cZxtBLGo4XSR9wprJS0KVz/IZTun7LPD1vaXpdbb7i59k8WgeiqdNYCy+98CfRM", - "0hQUArUCk3CiwTc2hG7N8C53P/YsIOeYtjTmfPW9d1bb0gR/TEfUBpO+t4P75t3Ydqu4RO34GFf/VKu1", - "ZGyfZqTCtJ5M8/AGdJuKg8Pv22RclRM21hRG5DQVOddGAXbMat/bTTRQZzY2mlcF8JqfqQmmRc+FSeim", - "YNRqXnVU+ME2e41ENxhzQ75gCiY0jyc9QD0+7OTZFQvYmdAoquG5uYjCfUvyvrESKVYhEhSk0wSXG719", - "bcLOQwlUlXw3Yj4ScJzHpB/y16c1h6/+H28j7TKNUhILFrWs92B8+K0PD7HlVnD4M47dnXXLCGNDx4oQ", - "c3r6oRtZ5kxpIZetVPzaReuihXcdeTfR4gZ42+a/c1d4d+TStvEJthd7twv5G2TNWgJNG9PgYVwzs6Op", - "37SWSkM6qY7nPXReYBPiPY8fBBrSzOg/l+1tpe/rIS6dRhtmlx5zMGpeYQUXEKfA9TFf6jnj8WHXJKbi", - "zlPDQBKEEfItoVLSJYnZLXBCFaFkKu7KnaECbVGrA+MFv/z6y6/ExmTX5l+Luy22gE7KqK8s8Q+N87jt", - "wbNce/kTi6EEJZIcQ1uK+yCmcYspvcxYiNiMS3lKMgm3TOTK/IhYiL2ZLtBlUOfWiI4Hd+/vfiYv3v/j", - "538cvvoOgeni+ENjhWH3RJDMv9yeSJonBsvVzUTkuhLkiqhwYtZcOQxqCdrcQhaH9HOzMDMD2lN6mk5Z", - "nBthWtFbs1IDImYauPkzykM8hgetQRY99ZxyE3cYjxNw1NDgqqSc/GQp9/k5N0aVsN9hEgohI7Ude5lg", - "XBPsyTjVoKo0qhq3XmpSHgP5NB4cXBcmgr2LeQncZRBq23wKtoEEZR6aR1Z9EUtNxBRcNfOWYi7yxvLg", - "Y9SdrOsMH+8OCy8Xs4KrQhEtX1jMQQIBGhbkE2YUR178Mvj1ZR0DG8spbNamzIF0JCyhU0g8hJ3i8yqv", - "bZBWUnNAGI9YiPKnpinEUuQ8KlqbrG/caDKl4Y3bpEuunXZFfUoiYqa3sBbbTZGcD40HqLlITJ6L5mnH", - "IowrbXI/MTMkIsbhe08NyKmdvavnTTOITkxYET+usurI4YHbDo+8ff84gJhbtqKH7xavWQh8/+pvdJ68", - "kTR3O5zr1h1bH+SWzunx3zfznN/48p7QvMBlilEmeiWta8665Zy62HTsLn1wgGK9g6O6LDY3wJzMuJyp", - "Z8zydWdgpiE1BN07c1RjVRNhGOtIUrsNDWGOLK2gPBJ8d3b1RqRZruGEzzyloB+qmtgINGXG/N+dXZHQ", - "9nGrMrtCtfBVYZ0/96KfbXVxnSh+dqt43VUVpEIuJzMJ0OiAj8mP5vGKbnj27ul3ic+9HRlvkYYPvJtC", - "NG3Q9NH8vXZ31QiE25YNIpusljIqCXK02lKeX71XmiXsd1TROhUbzeZ1c6I01UxpFqoHKveZNbaZGgaB", - "w+OksGS3myMxUsjXO507jKW5bxRL+QbbIZuahI8DL0FNa2nbgsdi3l9envWU/JtXG9b8W7DYvD6+Kt/v", - "1sf/UOKOnbmBOG35FdM6TNfs9PD6T5qwCIeruO5jpQTnlZy0x3OQ3HLig3GX2vYAPrqpjBZUotcXstjo", - "swbj3ysRO85yk+/ZDxuq+vuzRptVzLcAyeHsXZaTPpNzk92NdsiK3LqxR1Y+W4e6Wd2wmndQM+6ajkfK", - "K5RxoalWG6lBAk2GJlqjQlaBrCFLleM+UCFtn28pxZL919WKZd+jll7C3wNN9PxNmWc3JWqGy5U/LZtj", - "R2KblKmZQxnwPDXE/vQ/wSB4e37+03kwCE5+OH3rkndhJ1jHcEGHy5dDtocrXCpu9RGTb5niWd32CKO9", - "tnCz0/WfOrlWZktp1mWvBS3XrX6rPnFyTqG3Egym9Kvk0r8aqKWCn9qsXQq0U/RWWu7jwMPo6ekHl8Em", - "sdJ5U++htAdz1iZ4bjHJFTRc2x5nkCu1yaJNOuM7wzmcuST7OGK3gGc7xRHPGZXUmpa/7jaXyZqtkKvz", - "U1Suyqf4NRTjMblllFxKFt7gBqrQIhRJsTES4T5ZcRKcmIUyHmENtRi2z7xJhsT5SnGvZOITMdwaj9iC", - "6CyfJkzNDc22bz/pJR6VMEV5RBIRN8h7a8fooW7DvSWTPLdBkGhBZM67cjMv7I/PYjoiH4VmIRCNZYJz", - "pghTxKS1ESknLw/7y7JQu0YXeg6SSJFrUAPc/2GaRAIU4ULbSiYsQybejS9bBwB3NNT22Qv1kkSQAY8U", - "EbzJCUuzBFLguiia4hFJsehjiufqMxbnkk4TQE2Ynv+yZvAvQmWclwc2GwXGyrYraZv1fWsXtfhyCRuD", - "BulsCHu+bCz8xeNZhSnVxtfMpZVm3HJsbK9Qr8h1LOxOnQSaGhEXwzTmtI/6zKp0PPBPfSFyGYI7K+Oh", - "SJuzVmMQ3ThNvqieeydvx9YGJU2RuBDlx6AN0GqraLMaXLrBZ3vIQzQr62VWT/eEGOYCrxYljq2CrI0R", - "63Ftmmjx51n1cxv1qiD8gaobtZUt277liXWPAbvnRu38SdLFgOTcOTqsDzYVeWG7vqygD09Cm9X/zVOh", - "5jn42lS2Mx6KwKv3UMi+1BjlsadsmIjwlME2R7rx4LA5ZQPK7MBrrzcoCFNl80Kq1y3aV+oXM3HP7nhq", - "XpTKNHhDWfENj1NFPBW5blX5YL+uwrmaLbrT/DwHXRbk2QkXVJFZQuMYIkIV+Xjx48+NcxszzOZnEUYT", - "5o097nKrJ6sZN6qC8vq1Gdw4tT19rVkIKTcJAg1DUMregVDt923gxNZ1lSUFxebqE9XVp8er81OfKhF9", - "pUiLT1l6qWxq7Ll5bnNpmPEw+vgrOTxYUZus5ewZzObLXHuict86YukucwdPvJoclDxeN3uvAgbzvqjL", - "71t4fT03Hzxm1XfnXoEVVd+7qwR2X7Z9vV+2vfpb3yRALsAs1DUQrB3M7CYN1pLhPsbe/+4Z01DVRTzT", - "ZV1htisX+dPK1Dv4vWGZercwuRtCe+PsRQYQzvsCbYMLF7KOSWrwRGVAb0CSCMzKXiqj48SAf7IkcJdJ", - "UKg3EyYoR1VHpg+E87LsxRgd2qp5HGHLjOkQPaezlC7/MrIrpzZLWA1QpFvmLzu+X4/OIE/4Nd0mlKyK", - "FnVStjpE2Ipd3N9YNVVvvta0l4YpeAxm7WF0IsLGSTTly+J0vc3hl45NX9+7MTxsHWhW6WrxPUCdr+P1", - "fF4Z4oO6KdJMLs3Tdamr4cNOVbR0XGuDA/Dtd+HW77vZLw3XJerld3mmbWOtsOXZWHuNUH66aIlYc1ZW", - "kOrKbPVeDyJ0mEumlxeGFMvn+8vLs9dAJcjq1keEdfuoGmSudRbcmzGYt37ouPjkOKwu55M5J8cn1b6f", - "u9F3ym4hM1hyfELOc85xIoNrdqzxaDwaG4GIDDjNWHAUfDM6GI2NtqieI9n7eOfbUIth6cSZUL5oXl2M", - "59xjaIu9i9WWyAprOInMUqJ9aZwROSj9WkTLcmMWOE5koz6Vet+E3WF5n6FV8zoj8N1Qd99UsYnx+MAq", - "FNk+HI9bVDhS3/+sbPzYjITGAhHnbgXuHBf7szwhdbNB8O0jklAX5njmf00jcm6lb+c9eJ55rzjN9VxI", - "9jtEOPHBN88zccEsecu1SYMvhSCnVMZW6gevnov7OmFFpLJYbkg4PHxUEjpFUl1i6iakKqR69Vz2d8I1", - "SE4TcgHyFmRJgQOjGHNdAP10fX89CFSeplQuywtQyaUgZWpAY2WwuwwlBr3vhjbFomo55DSFobgFKVmE", - "yN9Ah0GwPy/qXvZLFI4BRdAEMbdoKXhCBPEVR20KJPeunMqBbHVYk9Oq9Gklq2Uh0JPzaif6Y1yWYxg2", - "seCnnz37+in5ciqOHsaVJRG5waWVCcrVJzP+qHycZcmy/G6mcUGFskf7mRQmyXIWa50w3bra7YnjdGO2", - "Zw7UzRqoXaTuj9S7CLVthLIfIF+K+qqyLUMUazpGBwSGsXOd4mOAAaFhiN+hxtUNGHPBQlumRNWNGhA2", - "gtGAiFyX97UNnLvbBkTdgA7nh5tBS31n0TMiTDnpDmh2QPNVAg1xrgL7A4BT+4mLOxvsCOBGuYWT9RsC", - "zYuungcG/owNAV8h8g4E/uL7AjskejASPXBRzhoe6gLPbXXHnRd53vludttqsVPehPQ8GGRne2YQam5i", - "7+Bnl4M8gedXN4o9zPVLxxgE+wm7hWGz0nrdSse7xnG+orA1w+7drTqXHCICPMJrXJQXItpFvyth4uE6", - "6imYf2aU6K1w3gHGDjAeDzCMmVmw+COokbQ90yJHkm6QKmCNQ451VJQklMe5gbCqhKiLAnhH3maOfzdc", - "LBZDzBNymQAPRWQLeLbLFsyUz+3+zoeFO4/fefwjery9Y3JbD09S69TFRzFDWlw4Njzs9/HibrLiEwz8", - "DrX8nye8ru25y+yJ1wGdGZ/ZzZsft+wcfefoj+fopfeVxk0OH+D3qusgg2DfxOwNDkHftb6NwN0A51MI", - "f5rv1Jw+UYbfrWrdHUPs3P4rcXus5/0Dx53acb+Gs9vK4I02/5pd3P9f2f63uOUtBeW2oK5rkCmPnGLw", - "xn863IMUttr4SaGiUdD8zFjR/C+wd1ixw4rHx4rKhR4GFkV3RIvcuWPYCxPFPaf1/0E3XZb/lQd+nK0V", - "qa9y97p9fVPqE68Oyol22cHO478Sj3duGd7S1XPXGRQSoHC61jXv5ZcPbxKRR+SNSNOcM70k76iGBV0G", - "xVUE+L2FOtrfjyTQdBjbt6Ok6D4KTXf8wKdn/AuNWUXfsNVACtvt04ztT0HT/Yrf++v7/wsAAP//DDfM", - "rxOEAAA=", + "H4sIAAAAAAAC/+xdeW/jtrb/KoTeA5IB7GzttA8B7h+ZpTPBTaZBljst2sCXlo5lTiRSJakk7rx89wcu", + "kkiJsuU0Sft6/dc4Epez/s4heaj5GsUsLxgFKkV0+DUS8RxyrH8enR2/55xx9TsBEXNSSMJodKjeIFCv", + "EAdRMCoA5SyBbCcaRQVnBXBJQI+Ri7Tb/XIOtnsOQuAUVD9JZAbRYXQqUvXXolB/CMkJTaOHh1HE4beS", + "cEiiw1/0qNdNl5rQuh+bfoFYRg+j6KhMCDu3VHZJOffoRzPGEVY9UAoUOFatukzpFvpHlv04iw5/+Rr9", + "N4dZdBj9124jzV0ryt1TSAi+Oj+JHq5HAUnYmSAxM+90uDXTufx6PAWYfsOSxSQFqhtesku4l4rcHi58", + "kq6KjOGkogbNSAZIMjQFJDmmquUUEiWTGeM5ltFhNCUU80XUoq+rxFGUg8QJltjMOsNlpvp/fYjacjlK", + "EqJ+4gx9YVNEqJmMMGppKbAQkKg/5BxQQQrICPXtqJorRIdS9oQkPh0dKj6WaUpoin7AcWUgx+9QqSZW", + "hlLJo6ispJ7aNE1CU3OQJacTSXIQEueF8GmQvIQOHee6D2r6mOnnnkqQhHu5gy7KomBcWdMtzkoQh2hL", + "AJVAY9gaoa07xpOtEVJmjgxRaMpYBpii7S01+ZZ6tzXDmYCtVzvonaEMEYHs6+1mvFc7VUuUA6YCUeYQ", + "uWNns+/U7/EUa601bRypWS4vG8msgoGOY4Tsfol7HOc4hUum/+n6R1qSBNMYJiLGGXhq+n7ndVtH72nM", + "So5TENZSZI0hgEiuX8QZE5AtUEboTWO8Sm+o4CwvJNqek3QO3OoO5XiBOCRlbIdAv5U4I3LxypXbB0sn", + "utB01vzSMp8CV/ySisEeTzdjS6YoJ7MFuiNy3vGrfnc38gvYuh53skSO+105voOUgybmbk5iQ0aDkIZS", + "IlBRirkW4R3midCtCCWS4My02WnTh1aLKWMcixWQcIRO2PkR2j5hd+NzTG/QUYILqZHplVU8pgkiUqCY", + "cRMdE+Vld0DSudSOa5hwAgx6f4/zIoND9BX9GmVYApXjmFFBhHK0xW4W52NF3Vgk99mv0SHa39kboV8j", + "Cpx8EbsFuYdsjLkcV28PHlwBnGjGng0HO/wMhEIKKZbkFibG+FcQcdm4ybZ4pd2rJAmguzmW6i+4j7My", + "ATTjLA+I+DiljCsLmiHfINGv5d7eNzHad8n+ZElDZ4a0EPVlPjF+PSmAh3jYb7PwSZsaYrMKEFyMKIBb", + "9jxCyhwdm8ZnwDvkECohNdar6aEz4KBZk9AKLft7e/30JEAZEUrHuuMOOmUczG9UihJnCrUAa8yyEGWh", + "qGJlWkokMnYHHNVUqGGSMtOeO12oeAM0lfMOf1V7dKGpDnHnineIVSyzyX6dCjwDuZjEc4hvPOGp0NeW", + "3hlwhYkqkOpuSHfTpigkyTXuz9rYpWChzBKVwrDZDKhQRsY4mmOez8rMJfPCjPpWE1MTa6O1phYg6Urk", + "AqxbckwTliODbz2iUI2D8q505Ulhb+d/euCazUwq0qRpuCgy0gQ5DpWOjWa299SbfS+QXVRzdrC5FfeL", + "SoEmsAUSAC+yr84Awgny4LBZs/5kkfMJE9RaJUNh+Q+hcf+UfV7X0u0qlQ7M6f5FEmBdlc5aoPhdaEE2", + "4zgHoQFZQMxoos3by0Nu1fAudz/04NZch31vztffB2c1LRGhSIdzMWDSj2bw0LyDbbeOP9iMr+Pnn2q1", + "hoz104mcqdaTaRnfgGxTsX/wfZuMq2pCpWK92lREKZHjnJVUKgWYMevllptQaJ2ZUKheWZhVP3MVO23P", + "O5JlCuwJ1a86Kjw1zd5ooj3G3NDOiIAJLtNJDyzvHXTy1JoF3RnhJGnA2GPYpMvoo7fwsIsODgLyaabT", + "5t6+JuGlMQcsKr69EK8JOCpT1A/wq9OXg9f/j7OXTV5RSeKOJC3r3d87+DaEh7rlWnD4WY/dnXXNCGNC", + "x5IQcwFpDlQe0YWcE5oedMPMlN0HNk1Rpg0IfYsw53iBUnILFGGBMJqy+2oLwPqZxsWR4v+nn3/6GRk0", + "drl9w+5719zdyY8rvBeG+MciPBY3E0KLUgb5Y3djDoJlpQY11Rjpxi2m5KIgsfZKvVjDqOBwS1gp1I+E", + "xLo3kdauRk1Wpf1i//7j/We0/fEfn/9x8Po7bZIXR6deJnmqZj7WZP7lVr15mSkvFjcTVspakEvw4Fjl", + "1iWMGgmaqMLtruBcJeBqQLMtiPMpSUslTCN6Y1ZihNhMAlV/JmWs9/1ASuC2p5xjqhCH0DQDRw0eVxXl", + "6EdDeQg8qDKqjPwOk5gxnoj12CsYoRLpnoRiCaIOoPW4zZIC0xTQL3uj/WtrIrq3nRfBfQGxNM2nYBpw", + "EOqhemTUl5BcYSWjwo9Ydi701vAQYtSdrOsMn+4PrJezmeXKKqLlC3dz4IAAx5Z8RJTi0PZPo59fNejn", + "JdK6WZsyJ3/XhGV4ClmAsBP9vM5oPNIqavYRoQmJtfyxagopZyVNbGsV7/e8JlMc37hNuuSaaZdsiGcs", + "JXINazHdBCrpWHmAmLNMZTjaPM1YiFAhVdRnM0Wixjj9PrDpfGJm7+p5aOzoxIQl8eOqqHdCH7ngfOJ9", + "2qcBxNKwlTx+P3BFCvj96/+gDaxB0tzsZK3KONfeOaqcM+C/b+clvQnlPbF6oRNUpUztlbg55OqeH0u7", + "3dRNevUANtPVo7os+lsfja7rmXrGrF53BiYSckXQgzNHPVY9kQ5jHUlKt6EizJGlEVRAgh/Ort6yvCgl", + "HNNZ4Oz5tD6ET0Biosz/w9kVik0f9xi4K1QDXzXWhXMv/MWUMzSJ4he3bMDx1xxyxheTGQfwOujH6Af1", + "eEk3ySTOAv0u9fNgR0JbpOkHwe0AnHs0fVJ/r9xXUwKhpqVHpM9qJaOKIEerLeWF1XslSUZ+1ypapWKl", + "2bJpjoTEkghJYvFI5b6wxoapYRQ5PE6sJbvdHIkhK9/gdO4whua+UQzlAxbCQ00ixEGQIN9a2rYQsJiP", + "l5dnPTVG6tXAIiMDFsMLcup6oW5BzrsKd8zMHuK05WendZhu2Onh9V84I4kerua6j5UKnJdy0h7PQXLD", + "SQjGXWrbA4Toxjy5w1x7vZXFoDoq5d9LETstSpXvmUqquuDnzGuzjPkWIDmcfShK1GdybrI76DjE5tZu", + "+7Pq2SrULZqG9byjhnHXdAJSXqKMC4mlGKQGDjgbq2itFbIMZBVZohr3kQpp+3xLKYbsv65WDPsBtfQS", + "/hFwJudvqzzbl6garhThtGyuOyLTpErNHMqAlrki9sd/RqPo/fn5j+fRKDp+d/LeJe/CTLCKYUuHy5dD", + "doArvVRcq2oytEwJrG57hNFeW7jZ6eraStfKTLHEquzV0nLd6resptI5f1xLMDqlXyaX/tVAIxVd27dy", + "KdBO0VtpeYiDAKMnJ6enpiy3a9AxoxKodL3urX0U2tphmeeg5yxb7ZzcNKpmcuh3CAuTfQ6/lSACJ/k5", + "vp9IdgO0fab0nbtnfI8uTZtwmqhnFoNDs0Ptg1uVaodpG6iFos4WjAeGwZW+5IBzr5+uyfTrKnAeXHhL", + "yAtlYiWH1pHi966xNY0CJ3eSFRN/j2G873RmBfpnUKKqX9EumHK7na0sBKmV4ltJZQd9VtJ4b8u6OShv", + "86zbPgqesQ9ZlFRKXalH7pDV7E623dRZ9WtbnZTCJ9iYMLoSQ7ZDuDO+M9zILA1yS2ollraQl0AIuQV9", + "jGZP084wx8ZjunDCWTYpebZi7/Hq/ESjqSinut6Z0BTdEowuOYlv9IkFkyxmmd2JTPTGtD10z8itPXkf", + "SzZulxegQhPnoutbQxa64kFNwa1y9DWILsppRsRc0Wz69pNeJQBVXoBpgjKWeuS9N2P0UDdwM1etVttZ", + "B5IM8ZJ25aZemB9f2HQHfWKSxICkrrycE4GIQGodmaBq8qquoqpgNZtiTM6BI85KCWKkN1yJRAkDgSiT", + "pkRMzYRRcKfZlFzAPY6lebYtXqEECqCJQIz6nJC8yCAHKm01Gk1QrutrprqEYUbSkuNpBloTque/jRn8", + "G2GeltUJ6aBMtLbtWtpfHzrHFrY2WTcGCdw5gQncXbD+EvAsa0qN8fmLVyEJNRwr27PqZaVMmdkaV4FA", + "idgO481pHvWZVeV4EJ76gpU8BndWQmOW+7PWYyDpHdxf1M+Dk7eTWY8SXyQuRIUxaABarZXeLQeXbra3", + "PuRpNKtKk5ZP94wY5gKvZBWOLYOswYj1tDaNJPvzrPqljXpZED7F4kasZcumb1Ui0mPA7kFte8HC8d0I", + "ldQ5q28qCQTaNl1f1dCnSw/8iwr+MaxfeLJy7dgZT4sgqPeY8b61qJbHljBhItHHeqa5pluf1PtTelBm", + "Bl55gdESJqrmVqrXLdqX6lcvfQPHUbl6USlT4Q0mplrPuUWHp6yUrYIq3a+rcCpmd91pPs9BVrWPZsI7", + "LNAsw2kKCcICfbr44bN3UKqGGX74pzSh3pjzZbdQtZ5xUMFZ0K/V4MqpTblDw0KMqUoQcByDEOaWY73B", + "PsCJjesKQ4oWm6tPra4+PV6dn4RUqdGXs9xehuql0tfYS/Pc5lIxE2D06bdO9EmmGLJ5Yg49h+8rmSPM", + "h9aZZmjZ/rzbN6OKx2u/9zJgUO/t3Ye+hdff527jUxbYd24OLimw31wW3FwW/PteFnz9H31XEF2AWqhL", + "QLpYtzCbNLp4U+9jbP3vljINUV+1ny6aks5NfdafdiOgg98DbwRYg2mFWD+E9sbZiwIgnvcFWo8LF7KO", + "UK7wRBSAb4CjBNTKngul40yBf7ZAcF9wEFpvKkxgqlWdqD4Qz6s6M2V02lbV40S3LIiMted0ltLVX0p2", + "1dRqCSsBbLql/jLjh/XoDPKMFxeHULIsWjRJ2fIQYUrk9f7Gsql68zXfXjxTCBjMyuqPjMXe+RKmC1vO", + "0ubwa8emrx/cGB63KgiaUybzQZ7WeVRQhvpB01TTjC7V01Wpq+LDTGVbOq41oOJk/V241ftu5lLnqkS9", + "ugKp2nprhTUPo9trhOqWqCFixeG0JdWV2fK9Ho3QccmJXFwoUgyfHy8vz94A5sDr7zppWDeP6kHmUhbR", + "gxqDBAv2juxd7rj+/A4vKTo6rvf93I2+E3ILhcKSo2N0XlKqJ1K4Zsba29nb2VMCYQVQXJDoMPpmZ39n", + "T2kLy7kme1d/1WUs2bhy4oKJUDSvP33jfKnI3K6wqy1WWGs4TtRSov1ZGG5OCd+wZNE62jZRH3O5q8Lu", + "uPpikVHzKiMIfYPmwVexivHOiZ9m+2Bvr0WFI/XdL8LEj2EkeAtEPXcrcJd6sT8rM9Q0G0XfPiEJTSVc", + "YP43OEHVGa2ed/9l5r2iuJRzxsnvkOiJ9795mYkts+g9lSoNvmQMnWBuKgO+3X/9Utw3CatGKoPlioSD", + "gycloVOV2CWmaYLqysXXL2V/x1QCpzhDF8BvgVcUODCqY64LoL9cP1yPIlHmOeaL6hNn6JKhKjXAqVDY", + "XYUShd73Y5NiYbEYU5zDmN0C5yTRyO+hwyjandtCs90KhVPQIvBBzK0SjJ4RQULViEOB5MGVUzWQKcf0", + "Oa1rDZeyWlXePTuvZqI/xmU1hmJTV9j1s2dePydfTonf47gyJGpu9NJKBeX6jlo4Kh8VRbaoLqp53wIR", + "5mi/4EwlWc5irROmWx9veeY47c32woHaLzrcROr+SL2JUOtGKHPj/5Kh+trnmiGK+I7hgsCAzFxvWBkc", + "WJ2Y+9/2eRmH/zMS81AF7sbr/+L5+QZ6Hg09j0yOieehLvDc1p/1CiLPh9DHrNZKOqqPv7wMBpnZXhiE", + "/M2kDfxsko5n8Pz6I0qPc/3KMUbRbkZuYexXPK5afgQXHk41s6ndcz9OKUtOIUFAE/39EhGEiHbx3VKY", + "eLyOegpXXxgleisNN4CxAYynAwxlZgYs/ghqZG3PNMiR5QNSBX3WWOp6BowyTNNSQVh9lN9FgZPT53L8", + "5ubSSzu7c51n498b/35C/9besrY/Z7lxYVuKPsb2u1rjg36Ptp/gsoXP+vYXpksy/sAnu5456+/M+MJu", + "7peUbxx94+hP5+iV91XGjQ4e4fei6yCjaFdF6AFHDx9aFcl67e8UIIeTeqfS65nCereWbHPKsHH7v4nb", + "6yq6P3DIIB3385zd1OMN2urzu7j/b5n576aqu8HVJqBsKv8wTZwSTO8/8+pBClPj96xQ4ZURvjBW+P+1", + "3AYrNljx9FhRu9DjwMJ212hROp/SDcKE/ZxnvRJA00X1fxXoK5FSoOaL5UG3bz4I+syrg2qiTXaw8fi/", + "icc7H9Nd09VL1xmEJkDo6VpfM6/qjd9mrEzQW5bnJSVygT5gCXd4EdkLwLrKWRzu7iYccD5OzdudzHbf", + "iVV3XVbfM/6F1FlF37D1QEK328UF2Z2CxLs1vw/XD/8XAAD//+pb7T5rdwAA", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index 7f0ae21e..c1b0c639 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -394,10 +394,11 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq return resp.JSON200, nil } -func (w *Worker) LLM(ctx context.Context, req GenLLMFormdataRequestBody) (interface{}, error) { +func (w *Worker) LLM(ctx context.Context, req GenLLMJSONRequestBody) (interface{}, error) { isStreaming := req.Stream != nil && *req.Stream - borrowCtx, cancel := context.WithCancel(context.Background()) - c, err := w.borrowContainer(borrowCtx, "llm", *req.ModelId) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + c, err := w.borrowContainer(ctx, "llm", *req.Model) if err != nil { return nil, err } @@ -408,17 +409,10 @@ func (w *Worker) LLM(ctx context.Context, req GenLLMFormdataRequestBody) (interf return nil, errors.New("container client is nil") } - slog.Info("Container borrowed successfully", "model_id", *req.ModelId) - - var buf bytes.Buffer - mw, err := NewLLMMultipartWriter(&buf, req) - if err != nil { - cancel() - return nil, err - } + slog.Info("Container borrowed successfully", "model_id", *req.Model) if isStreaming { - resp, err := c.Client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf) + resp, err := c.Client.GenLLM(ctx, req) if err != nil { cancel() return nil, err @@ -427,7 +421,7 @@ func (w *Worker) LLM(ctx context.Context, req GenLLMFormdataRequestBody) (interf } defer cancel() - resp, err := c.Client.GenLLMWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + resp, err := c.Client.GenLLMWithResponse(ctx, req) if err != nil { return nil, err }