diff --git a/runner/app/pipelines/llm.py b/runner/app/pipelines/llm.py index 7d3440d7b..4e240139f 100644 --- a/runner/app/pipelines/llm.py +++ b/runner/app/pipelines/llm.py @@ -1,201 +1,309 @@ import asyncio import logging import os -import psutil -from typing import Dict, Any, List, Optional, AsyncGenerator, Union +import time +from dataclasses import dataclass +from typing import Dict, Any, List, AsyncGenerator, Union, Optional -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig -from accelerate import init_empty_weights, load_checkpoint_and_dispatch from app.pipelines.base import Pipeline -from app.pipelines.utils import get_model_dir, get_torch_device -from huggingface_hub import file_download, snapshot_download -from threading import Thread +from app.pipelines.utils import get_model_dir, get_max_memory +from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams +from huggingface_hub import file_download +from transformers import AutoConfig logger = logging.getLogger(__name__) -def get_max_memory(): - num_gpus = torch.cuda.device_count() - gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)} - cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB" - max_memory = {**gpu_memory, "cpu": cpu_memory} +@dataclass +class GenerationConfig: + max_tokens: int = 256 + temperature: float = 0.7 + top_p: float = 1.0 + top_k: int = -1 + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + + def validate(self): + """Validate generation parameters""" + if not 0 <= self.temperature <= 2: + raise ValueError("Temperature must be between 0 and 2") + if not 0 <= self.top_p <= 1: + raise ValueError("Top_p must be between 0 and 1") + if self.max_tokens < 1: + raise ValueError("Max_tokens must be positive") + if not -2.0 <= self.presence_penalty <= 2.0: + raise ValueError("Presence penalty must be between -2.0 and 2.0") + if not -2.0 <= self.frequency_penalty <= 2.0: + raise ValueError("Frequency penalty must be between -2.0 and 2.0") - logger.info(f"Max memory configuration: {max_memory}") - return max_memory +class LLMPipeline(Pipeline): + def __init__( + self, + model_id: str, + ): + """ + Initialize the LLM Pipeline. 
+ + Args: + model_id: The identifier for the model to load + use_8bit: Whether to use 8-bit quantization + max_batch_size: Maximum batch size for inference + max_num_seqs: Maximum number of sequences + mem_utilization: GPU memory utilization target + max_num_batched_tokens: Maximum number of batched tokens + """ + logger.info("Initializing LLM pipeline") -def load_model_8bit(model_id: str, **kwargs): - max_memory = get_max_memory() - - quantization_config = BitsAndBytesConfig( - load_in_8bit=True, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - ) - - tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - - model = AutoModelForCausalLM.from_pretrained( - model_id, - quantization_config=quantization_config, - device_map="auto", - max_memory=max_memory, - offload_folder="offload", - low_cpu_mem_usage=True, - **kwargs - ) - - return tokenizer, model - - -def load_model_fp16(model_id: str, **kwargs): - device = get_torch_device() - max_memory = get_max_memory() - - # Check for fp16 variant - local_model_path = os.path.join( - get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model")) - has_fp16_variant = any(".fp16.safetensors" in fname for _, _, - files in os.walk(local_model_path) for fname in files) - - if device != "cpu" and has_fp16_variant: - logger.info("Loading fp16 variant for %s", model_id) - kwargs["torch_dtype"] = torch.float16 - kwargs["variant"] = "fp16" - elif device != "cpu": - kwargs["torch_dtype"] = torch.bfloat16 - - tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) - - config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config(config) - - checkpoint_dir = snapshot_download( - model_id, cache_dir=get_model_dir(), local_files_only=True) + self.model_id = model_id + folder_name = file_download.repo_folder_name( + repo_id=model_id, repo_type="model") + base_path = os.path.join(get_model_dir(), folder_name) - model = load_checkpoint_and_dispatch( - model, - checkpoint_dir, - device_map="auto", - max_memory=max_memory, - # Adjust based on your model architecture - no_split_module_classes=["LlamaDecoderLayer"], - dtype=kwargs.get("torch_dtype", torch.float32), - offload_folder="offload", - offload_state_dict=True, - ) + # Find the actual model path + self.local_model_path = self._find_model_path(base_path) - return tokenizer, model + if not self.local_model_path: + raise ValueError(f"Could not find model files for {model_id}") + use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true" + max_num_batched_tokens = int(os.getenv("MAX_NUM_BATCHED_TOKENS", "8192")) + max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "128")) + max_model_len = int(os.getenv("MAX_MODEL_LEN", "8192")) + mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.85")) + tensor_parallel_size = int(os.getenv("TENSOR_PARALLEL_SIZE", "1")) + pipeline_parallel_size = int(os.getenv("PIPELINE_PARALLEL_SIZE", "1")) + + if max_num_batched_tokens < max_model_len: + max_num_batched_tokens = max_model_len + logger.info( + f"max_num_batched_tokens ({max_num_batched_tokens}) is smaller than max_model_len ({max_model_len}). 
This effectively limits the maximum sequence length to max_num_batched_tokens and makes vLLM reject longer sequences.") + logger.info(f"setting 'max_model_len' to equal 'max_num_batched_tokens'") -class LLMPipeline(Pipeline): - def __init__(self, model_id: str): - self.model_id = model_id - kwargs = { - "cache_dir": get_model_dir(), - "local_files_only": True, - } - self.device = get_torch_device() - - # Generate the correct folder name - folder_path = file_download.repo_folder_name( - repo_id=model_id, repo_type="model") - self.local_model_path = os.path.join(get_model_dir(), folder_path) - self.checkpoint_dir = snapshot_download( - model_id, cache_dir=get_model_dir(), local_files_only=True) + # Load config to check model compatibility + try: + config = AutoConfig.from_pretrained(self.local_model_path) + num_heads = config.num_attention_heads + num_layers = config.num_hidden_layers + logger.info( + f"Model has {num_heads} attention heads and {num_layers} layers") + + # Validate tensor parallelism + if num_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({num_heads}) must be divisible " + f"by tensor parallel size ({tensor_parallel_size})." + ) + + # Validate pipeline parallelism + if num_layers < pipeline_parallel_size: + raise ValueError( + f"Pipeline parallel size ({pipeline_parallel_size}) cannot be larger " + f"than number of layers ({num_layers})." + ) + + # Validate total GPU requirement + total_gpus_needed = tensor_parallel_size * pipeline_parallel_size + max_memory = get_max_memory() + if total_gpus_needed > max_memory.num_gpus: + raise ValueError( + f"Total GPUs needed ({total_gpus_needed}) exceeds available GPUs " + f"({max_memory.num_gpus}). Reduce tensor_parallel_size ({tensor_parallel_size}) " + f"or pipeline_parallel_size ({pipeline_parallel_size})." 
+ ) + + logger.info(f"Using tensor parallel size: {tensor_parallel_size}") + logger.info(f"Using pipeline parallel size: {pipeline_parallel_size}") + logger.info(f"Total GPUs used: {total_gpus_needed}") - logger.info(f"Local model path: {self.local_model_path}") - logger.info(f"Directory contents: {os.listdir(self.local_model_path)}") + except Exception as e: + logger.error(f"Error in parallelism configuration: {e}") + raise - use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true" + engine_args = AsyncEngineArgs( + model=self.local_model_path, + tokenizer=self.local_model_path, + trust_remote_code=True, + dtype="auto", # This specifies BFloat16 precision, TODO: Check GPU capabilities to set best type + kv_cache_dtype="auto", # or "fp16" if you want to force it + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_num_batched_tokens=max_num_batched_tokens, + gpu_memory_utilization=mem_utilization, + max_num_seqs=max_num_seqs, + enforce_eager=False, + enable_prefix_caching=True, + max_model_len=max_model_len + ) if use_8bit: + engine_args.quantization = "bitsandbytes" + engine_args.load_format = "bitsandbytes" logger.info("Using 8-bit quantization") - self.tokenizer, self.model = load_model_8bit(model_id, **kwargs) else: - logger.info("Using fp16/bf16 precision") - self.tokenizer, self.model = load_model_fp16(model_id, **kwargs) + logger.info("Using BFloat16 precision") + + self.engine_args = engine_args + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + + logger.info(f"Model loaded: {self.model_id}") + logger.info(f"Using GPU memory utilization: {mem_utilization}") + self.engine.start_background_loop() + + @staticmethod + def _get_model_dir() -> str: + """Get the model directory from environment or default""" + return os.getenv("MODEL_DIR", "/models") + + def validate_messages(self, messages: List[Dict[str, str]]): + """Validate message format""" + if not messages: + raise ValueError("Messages cannot be empty") + + for msg in messages: + if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: + raise ValueError(f"Invalid message format: {msg}") + if msg['role'] not in ['system', 'user', 'assistant']: + raise ValueError(f"Invalid role in message: {msg['role']}") + + async def generate( + self, + messages: List[Dict[str, str]], + generation_config: Optional[GenerationConfig] = None, + ) -> AsyncGenerator[Dict[str, Any], None]: + """Internal generation method""" + start_time = time.time() + config = generation_config or GenerationConfig() + tokenizer = await self.engine.get_tokenizer() - logger.info( - f"Model loaded and distributed. 
Device map: {self.model.hf_device_map}" + try: + full_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False + ) + except Exception as e: + logger.error(f"Error applying chat template: {e}") + raise + + sampling_params = SamplingParams( + temperature=config.temperature, + max_tokens=config.max_tokens, + top_p=config.top_p, + top_k=config.top_k, + presence_penalty=config.presence_penalty, + frequency_penalty=config.frequency_penalty, ) - # Set up generation config - self.generation_config = self.model.generation_config + request_id = f"chatcmpl-{int(time.time())}" - self.terminators = [ - self.tokenizer.eos_token_id, - self.tokenizer.convert_tokens_to_ids("<|eot_id|>") - ] + results_generator = self.engine.generate( + prompt=full_prompt, sampling_params=sampling_params, request_id=request_id) - # Optional: Add optimizations - sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" - if sfast_enabled: - logger.info( - "LLMPipeline will be dynamically compiled with stable-fast for %s", - model_id, - ) - from app.pipelines.optim.sfast import compile_model - self.model = compile_model(self.model) - - async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: - conversation = [] - if system_msg: - conversation.append({"role": "system", "content": system_msg}) - if history: - conversation.extend(history) - conversation.append({"role": "user", "content": prompt}) - - input_ids = self.tokenizer.apply_chat_template( - conversation, return_tensors="pt").to(self.model.device) - attention_mask = torch.ones_like(input_ids) - - max_new_tokens = kwargs.get("max_tokens", 256) - temperature = kwargs.get("temperature", 0.7) - - streamer = TextIteratorStreamer( - self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) - - generate_kwargs = self.generation_config.to_dict() - generate_kwargs.update({ - "input_ids": input_ids, - "attention_mask": attention_mask, - "streamer": streamer, - "max_new_tokens": max_new_tokens, - "do_sample": temperature > 0, - "temperature": temperature, - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.eos_token_id, - }) - - thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs) - thread.start() + input_tokens = len(tokenizer.encode(full_prompt)) + if input_tokens > self.engine_args.max_model_len: + raise ValueError( + f"Input sequence length ({input_tokens}) exceeds maximum allowed ({self.engine.engine_args.max_model_len})") total_tokens = 0 + current_response = "" + first_token_time = None + try: - for text in streamer: - total_tokens += 1 - yield text - await asyncio.sleep(0) # Allow other tasks to run + async for output in results_generator: + if output.outputs: + if first_token_time is None: + first_token_time = time.time() + + generated_text = output.outputs[0].text + delta = generated_text[len(current_response):] + current_response = generated_text + total_tokens += len(tokenizer.encode(delta)) + + yield { + "choices": [{ + "delta": {"content": delta}, + "finish_reason": None + }], + "created": int(time.time()), + "model": self.model_id, + "id": request_id + } + + await asyncio.sleep(0) + + # Final message + end_time = time.time() + duration = end_time - start_time + logger.info(f"Generation completed in {duration:.2f}s") + logger.info( + f" Time to first token: {(first_token_time - start_time):.2f} seconds") + logger.info(f" Total tokens: {total_tokens}") + logger.info(f" 
Prompt tokens: {input_tokens}")
+            logger.info(f"  Generated tokens: {total_tokens}")
+            generation_time = end_time - first_token_time if first_token_time else 0
+            if generation_time > 0:
+                logger.info(f"  Tokens per second: {total_tokens / generation_time:.2f}")
+            yield {
+                "choices": [{
+                    "delta": {"content": ""},
+                    "finish_reason": "stop"
+                }],
+                "created": int(time.time()),
+                "model": self.model_id,
+                "id": request_id,
+                "usage": {
+                    "prompt_tokens": input_tokens,
+                    "completion_tokens": total_tokens,
+                    "total_tokens": input_tokens + total_tokens
+                }
+            }
+
         except Exception as e:
-            logger.error(f"Error during streaming: {str(e)}")
+            if "CUDA out of memory" in str(e):
+                logger.error(
+                    "GPU memory exhausted, consider reducing batch size or model parameters")
+            elif "tokenizer" in str(e).lower():
+                logger.error("Tokenizer error, check input text format")
+            else:
+                logger.error(f"Error generating response: {e}")
             raise

-        input_length = input_ids.size(1)
-        yield {"tokens_used": input_length + total_tokens}
+    async def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        **kwargs
+    ) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
+        """
+        Generate responses for messages.
+
+        Args:
+            messages: List of message dictionaries in OpenAI format
+            **kwargs: Generation parameters
+        """
+        logger.debug(f"Generating response for messages: {messages}")
+        start_time = time.time()

-    def model_generate_wrapper(self, **kwargs):
         try:
-            logger.debug("Entering model.generate")
-            with torch.cuda.amp.autocast():  # Use automatic mixed precision
-                self.model.generate(**kwargs)
-            logger.debug("Exiting model.generate")
+            # Validate inputs
+            self.validate_messages(messages)
+            config = GenerationConfig(**kwargs)
+            config.validate()
+
+            async for response in self.generate(messages, config):
+                yield response
+
         except Exception as e:
-            logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
+            logger.error(f"Error in pipeline: {e}")
             raise

     def __str__(self):
         return f"LLMPipeline(model_id={self.model_id})"
+
+    def _find_model_path(self, base_path):
+        # Check if the model files are directly in the base path
+        if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)):
+            return base_path
+
+        # If not, look in subdirectories
+        for root, dirs, files in os.walk(base_path):
+            if any(file.endswith('.bin') or file.endswith('.safetensors') for file in files):
+                return root
diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py
index 6155d0ffd..2bafb4483 100644
--- a/runner/app/pipelines/utils/__init__.py
+++ b/runner/app/pipelines/utils/__init__.py
@@ -14,4 +14,5 @@
     is_turbo_model,
     split_prompt,
     validate_torch_device,
+    get_max_memory,
 )
diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py
index 714aeb4d7..7379afb66 100644
--- a/runner/app/pipelines/utils/utils.py
+++ b/runner/app/pipelines/utils/utils.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import re
+import psutil
 from pathlib import Path
 from typing import Any, Dict, List, Optional

@@ -365,3 +366,25 @@ def enable_loras(self) -> None:
         if not self.loras_enabled:
             self.pipeline.enable_lora()
             self.loras_enabled = True
+
+
+class MemoryInfo:
+    def __init__(self, gpu_memory, cpu_memory, num_gpus):
+        self.gpu_memory = gpu_memory
+        self.cpu_memory = cpu_memory
+        self.num_gpus = num_gpus
+
+    def __repr__(self):
+        return f"MemoryInfo(num_gpus={self.num_gpus}, gpu_memory={self.gpu_memory}, cpu_memory={self.cpu_memory})"
+
+
+def get_max_memory() -> MemoryInfo:
+    num_gpus = torch.cuda.device_count()
+    gpu_memory = {
+        i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for
i in range(num_gpus)} + cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB" + + memory_info = MemoryInfo(gpu_memory=gpu_memory, + cpu_memory=cpu_memory, num_gpus=num_gpus) + + return memory_info diff --git a/runner/app/routes/llm.py b/runner/app/routes/llm.py index 1080280c2..076efef44 100644 --- a/runner/app/routes/llm.py +++ b/runner/app/routes/llm.py @@ -1,12 +1,12 @@ import logging import os -from typing import Annotated -from fastapi import APIRouter, Depends, Form, status +import time +from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.utils import HTTPError, LLMResponse, http_error +from app.routes.utils import HTTPError, LLMRequest, LLMResponse, http_error import json router = APIRouter() @@ -33,13 +33,7 @@ ) @router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False) async def llm( - prompt: Annotated[str, Form()], - model_id: Annotated[str, Form()] = "", - system_msg: Annotated[str, Form()] = "", - temperature: Annotated[float, Form()] = 0.7, - max_tokens: Annotated[int, Form()] = 256, - history: Annotated[str, Form()] = "[]", # We'll parse this as JSON - stream: Annotated[bool, Form()] = False, + request: LLMRequest, pipeline: Pipeline = Depends(get_pipeline), token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): @@ -52,50 +46,50 @@ async def llm( content=http_error("Invalid bearer token"), ) - if model_id != "" and model_id != pipeline.model_id: + if request.model != "" and request.model != pipeline.model_id: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=http_error( - f"pipeline configured with {pipeline.model_id} but called with " - f"{model_id}" + f"pipeline configured with {pipeline.model_id} but called with {request.model}" ), ) try: - history_list = json.loads(history) - if not isinstance(history_list, list): - raise ValueError("History must be a JSON array") - generator = pipeline( - prompt=prompt, - history=history_list, - system_msg=system_msg if system_msg else None, - temperature=temperature, - max_tokens=max_tokens + messages=[msg.dict() for msg in request.messages], + temperature=request.temperature, + max_tokens=request.max_tokens, + top_p=request.top_p, + top_k=request.top_k ) - if stream: - return StreamingResponse(stream_generator(generator), media_type="text/event-stream") + if request.stream: + return StreamingResponse( + stream_generator(generator), + media_type="text/event-stream" + ) else: full_response = "" + last_chunk = None + async for chunk in generator: if isinstance(chunk, dict): - tokens_used = chunk["tokens_used"] - break - full_response += chunk + if "choices" in chunk: + if "delta" in chunk["choices"][0]: + full_response += chunk["choices"][0]["delta"].get( + "content", "") + last_chunk = chunk - return LLMResponse(response=full_response, tokens_used=tokens_used) + usage = last_chunk.get("usage", {}) + + return LLMResponse( + response=full_response, + tokens_used=usage.get("total_tokens", 0), + id=last_chunk.get("id", ""), + model=last_chunk.get("model", pipeline.model_id), + created=last_chunk.get("created", int(time.time())) + ) - except json.JSONDecodeError: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": "Invalid JSON format for history"} - ) - except ValueError as ve: - return 
JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(ve)} - ) except Exception as e: logger.error(f"LLM processing error: {str(e)}") return JSONResponse( @@ -107,11 +101,12 @@ async def llm( async def stream_generator(generator): try: async for chunk in generator: - if isinstance(chunk, dict): # This is the final result - yield f"data: {json.dumps(chunk)}\n\n" - break - else: - yield f"data: {json.dumps({'chunk': chunk})}\n\n" + if isinstance(chunk, dict): + if "choices" in chunk: + # Regular streaming chunk or final chunk + yield f"data: {json.dumps(chunk)}\n\n" + if chunk["choices"][0].get("finish_reason") == "stop": + break yield "data: [DONE]\n\n" except Exception as e: logger.error(f"Streaming error: {str(e)}") diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index 8868989bb..6f8271c0d 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -72,9 +72,27 @@ class TextResponse(BaseModel): chunks: List[Chunk] = Field(..., description="The generated text chunks.") +class LLMMessage(BaseModel): + role: str + content: str + + +class LLMRequest(BaseModel): + messages: List[LLMMessage] + model: str = "" + temperature: float = 0.7 + max_tokens: int = 256 + top_p: float = 1.0 + top_k: int = -1 + stream: bool = False + + class LLMResponse(BaseModel): response: str tokens_used: int + id: str + model: str + created: int class ImageToTextResponse(BaseModel): @@ -101,6 +119,7 @@ class LiveVideoToVideoResponse(BaseModel): description="URL for subscribing to events for pipeline status and logs", ) + class APIError(BaseModel): """API error response model.""" diff --git a/runner/docker/Dockerfile.llm b/runner/docker/Dockerfile.llm new file mode 100644 index 000000000..5ab4b9675 --- /dev/null +++ b/runner/docker/Dockerfile.llm @@ -0,0 +1,74 @@ +# Based on https://github.com/huggingface/api-inference-community/blob/main/docker_images/diffusers/Dockerfile + +FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu20.04 +LABEL maintainer="Yondon Fu " + +# Add any system dependency here +# RUN apt-get update -y && apt-get install libXXX -y + +ENV DEBIAN_FRONTEND=noninteractive + +# Install prerequisites +RUN apt-get update && \ + apt-get install -y build-essential libssl-dev zlib1g-dev libbz2-dev \ + libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \ + xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git \ + ffmpeg + +# Install pyenv +RUN curl https://pyenv.run | bash + +# Set environment variables for pyenv +ENV PYENV_ROOT /root/.pyenv +ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH + +# Install your desired Python version +ARG PYTHON_VERSION=3.11 +RUN pyenv install $PYTHON_VERSION && \ + pyenv global $PYTHON_VERSION && \ + pyenv rehash + +# Upgrade pip and install your desired packages +ARG PIP_VERSION=24.2 +RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0 && \ + pip install --no-cache-dir torch==2.4.0 torchvision torchaudio pip-tools + +WORKDIR /app + +COPY ./requirements.llm.in /app +RUN pip-compile requirements.llm.in -o requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +# Most DL models are quite large in terms of memory, using workers is a HUGE +# slowdown because of the fork and GIL with python. +# Using multiple pods seems like a better default strategy. +# Feel free to override if it does not make sense for your library. 
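The TENSOR_PARALLEL_SIZE and PIPELINE_PARALLEL_SIZE variables configured a few lines below are checked in LLMPipeline.__init__ (attention heads divisible by TP, layers at least PP, TP * PP no larger than the visible GPU count). As a hedged illustration only, not part of this PR, the helper below enumerates the combinations that would pass those checks; the helper name and the example model shape are hypothetical.

```python
# Illustrative sketch (not part of the PR): list (tensor_parallel_size,
# pipeline_parallel_size) pairs that satisfy the checks in LLMPipeline.__init__:
#   - attention heads must be divisible by the tensor parallel size
#   - pipeline parallel size cannot exceed the number of layers
#   - tp * pp cannot exceed the number of visible GPUs
from itertools import product


def valid_parallel_configs(num_attention_heads: int, num_hidden_layers: int, num_gpus: int):
    configs = []
    for tp, pp in product(range(1, num_gpus + 1), repeat=2):
        if tp * pp > num_gpus:
            continue
        if num_attention_heads % tp != 0:
            continue
        if pp > num_hidden_layers:
            continue
        configs.append((tp, pp))
    return configs


# Example: a 32-head, 32-layer model on a 4-GPU host.
print(valid_parallel_configs(num_attention_heads=32, num_hidden_layers=32, num_gpus=4))
# -> [(1, 1), (1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (4, 1)]
```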
+ARG max_workers=1 +ENV MAX_WORKERS=$max_workers +ENV HUGGINGFACE_HUB_CACHE=/models +ENV DIFFUSERS_CACHE=/models +ENV MODEL_DIR=/models +# This ensures compatbility with how GPUs are addressed within go-livepeer +ENV CUDA_DEVICE_ORDER=PCI_BUS_ID + +# vLLM configuration +ENV USE_8BIT=false +ENV MAX_NUM_BATCHED_TOKENS=8192 +ENV MAX_NUM_SEQS=128 +ENV MAX_MODEL_LEN=8192 +ENV GPU_MEMORY_UTILIZATION=0.85 +ENV TENSOR_PARALLEL_SIZE=1 +ENV PIPELINE_PARALLEL_SIZE=1 +# To use multiple GPUs, set TENSOR_PARALLEL_SIZE and PIPELINE_PARALLEL_SIZE +# Total GPUs used = TENSOR_PARALLEL_SIZE × PIPELINE_PARALLEL_SIZE +# Example for 4 GPUs: +# - Option 1: TENSOR_PARALLEL_SIZE=2, PIPELINE_PARALLEL_SIZE=2 +# - Option 2: TENSOR_PARALLEL_SIZE=4, PIPELINE_PARALLEL_SIZE=1 +# - Option 3: TENSOR_PARALLEL_SIZE=1, PIPELINE_PARALLEL_SIZE=4 + +COPY app/ /app/app +COPY images/ /app/images +COPY bench.py /app/bench.py +COPY example_data/ /app/example_data + +CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"] diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index e19de55fb..933be8d6a 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -319,9 +319,9 @@ paths: operationId: genLLM requestBody: content: - application/x-www-form-urlencoded: + application/json: schema: - $ref: '#/components/schemas/Body_genLLM' + $ref: '#/components/schemas/LLMRequest' required: true responses: '200': @@ -525,8 +525,7 @@ components: AudioResponse: properties: audio: - allOf: - - $ref: '#/components/schemas/MediaURL' + $ref: '#/components/schemas/MediaURL' description: The generated audio. type: object required: @@ -714,40 +713,6 @@ components: - image - model_id title: Body_genImageToVideo - Body_genLLM: - properties: - prompt: - type: string - title: Prompt - model_id: - type: string - title: Model Id - default: '' - system_msg: - type: string - title: System Msg - default: '' - temperature: - type: number - title: Temperature - default: 0.7 - max_tokens: - type: integer - title: Max Tokens - default: 256 - history: - type: string - title: History - default: '[]' - stream: - type: boolean - title: Stream - default: false - type: object - required: - - prompt - - model_id - title: Body_genLLM Body_genSegmentAnything2: properties: image: @@ -861,8 +826,7 @@ components: HTTPError: properties: detail: - allOf: - - $ref: '#/components/schemas/APIError' + $ref: '#/components/schemas/APIError' description: Detailed error information. type: object required: @@ -902,6 +866,54 @@ components: - text title: ImageToTextResponse description: Response model for text generation. 
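The LLMRequest and LLMMessage schemas below replace the old form-encoded Body_genLLM fields. A minimal client sketch against these schemas follows; the host, port, and bearer token are placeholders, not values defined by this PR.

```python
# Minimal sketch of the new JSON request/response shape for POST /llm.
# Host, port, and token are placeholders; fields follow LLMRequest and LLMResponse.
import requests

resp = requests.post(
    "http://localhost:8000/llm",
    headers={"Authorization": "Bearer <token>"},
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        "model": "",        # empty string means "use the pipeline's configured model"
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 1.0,
        "top_k": -1,
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
print(body["response"], body["tokens_used"], body["id"], body["model"], body["created"])
```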
+ LLMMessage: + properties: + role: + type: string + title: Role + content: + type: string + title: Content + type: object + required: + - role + - content + title: LLMMessage + LLMRequest: + properties: + messages: + items: + $ref: '#/components/schemas/LLMMessage' + type: array + title: Messages + model: + type: string + title: Model + default: '' + temperature: + type: number + title: Temperature + default: 0.7 + max_tokens: + type: integer + title: Max Tokens + default: 256 + top_p: + type: number + title: Top P + default: 1.0 + top_k: + type: integer + title: Top K + default: -1 + stream: + type: boolean + title: Stream + default: false + type: object + required: + - messages + title: LLMRequest LLMResponse: properties: response: @@ -910,10 +922,22 @@ components: tokens_used: type: integer title: Tokens Used + id: + type: string + title: Id + model: + type: string + title: Model + created: + type: integer + title: Created type: object required: - response - tokens_used + - id + - model + - created title: LLMResponse LiveVideoToVideoParams: properties: diff --git a/runner/openapi.yaml b/runner/openapi.yaml index 757e7ae17..5b72dbb74 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -319,9 +319,9 @@ paths: operationId: genLLM requestBody: content: - application/x-www-form-urlencoded: + application/json: schema: - $ref: '#/components/schemas/Body_genLLM' + $ref: '#/components/schemas/LLMRequest' required: true responses: '200': @@ -558,8 +558,7 @@ components: AudioResponse: properties: audio: - allOf: - - $ref: '#/components/schemas/MediaURL' + $ref: '#/components/schemas/MediaURL' description: The generated audio. type: object required: @@ -748,39 +747,6 @@ components: required: - image title: Body_genImageToVideo - Body_genLLM: - properties: - prompt: - type: string - title: Prompt - model_id: - type: string - title: Model Id - default: '' - system_msg: - type: string - title: System Msg - default: '' - temperature: - type: number - title: Temperature - default: 0.7 - max_tokens: - type: integer - title: Max Tokens - default: 256 - history: - type: string - title: History - default: '[]' - stream: - type: boolean - title: Stream - default: false - type: object - required: - - prompt - title: Body_genLLM Body_genSegmentAnything2: properties: image: @@ -952,8 +918,7 @@ components: HTTPError: properties: detail: - allOf: - - $ref: '#/components/schemas/APIError' + $ref: '#/components/schemas/APIError' description: Detailed error information. type: object required: @@ -1047,6 +1012,54 @@ components: - text title: ImageToTextResponse description: Response model for text generation. 
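When "stream": true is set, the route returns text/event-stream as produced by stream_generator: each event is a "data: <json>" line whose choices[0].delta.content carries a text delta, the final chunk has finish_reason "stop" plus a usage object, and the stream ends with "data: [DONE]". A hedged consumer sketch, with URL and token as placeholders:

```python
# Sketch of consuming the streaming response; URL and token are placeholders.
import json
import requests

with requests.post(
    "http://localhost:8000/llm",
    headers={"Authorization": "Bearer <token>"},
    json={"messages": [{"role": "user", "content": "Hello!"}], "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank separator lines
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        choice = chunk["choices"][0]
        print(choice["delta"].get("content", ""), end="", flush=True)
        if choice.get("finish_reason") == "stop":
            print("\n", chunk.get("usage"))
```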
+ LLMMessage: + properties: + role: + type: string + title: Role + content: + type: string + title: Content + type: object + required: + - role + - content + title: LLMMessage + LLMRequest: + properties: + messages: + items: + $ref: '#/components/schemas/LLMMessage' + type: array + title: Messages + model: + type: string + title: Model + default: '' + temperature: + type: number + title: Temperature + default: 0.7 + max_tokens: + type: integer + title: Max Tokens + default: 256 + top_p: + type: number + title: Top P + default: 1.0 + top_k: + type: integer + title: Top K + default: -1 + stream: + type: boolean + title: Stream + default: false + type: object + required: + - messages + title: LLMRequest LLMResponse: properties: response: @@ -1055,10 +1068,22 @@ components: tokens_used: type: integer title: Tokens Used + id: + type: string + title: Id + model: + type: string + title: Model + created: + type: integer + title: Created type: object required: - response - tokens_used + - id + - model + - created title: LLMResponse LiveVideoToVideoParams: properties: diff --git a/runner/requirements.llm.in b/runner/requirements.llm.in new file mode 100644 index 000000000..bd1480880 --- /dev/null +++ b/runner/requirements.llm.in @@ -0,0 +1,22 @@ +vllm==0.6.5 +diffusers +accelerate +transformers +fastapi +pydantic +Pillow +python-multipart +uvicorn +huggingface_hub +xformers +triton +peft +deepcache +safetensors +scipy +numpy +av +sentencepiece +protobuf +bitsandbytes +psutil diff --git a/runner/requirements.txt b/runner/requirements.txt index 17237444e..12107f9b9 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -21,4 +21,4 @@ bitsandbytes==0.43.3 psutil==6.0.0 PyYAML==6.0.2 nvidia-ml-py==12.560.30 -pynvml==12.0.0 +pynvml==12.0.0 \ No newline at end of file diff --git a/worker/multipart.go b/worker/multipart.go index 551b0af84..bc70ba8f4 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -259,56 +259,6 @@ func NewAudioToTextMultipartWriter(w io.Writer, req GenAudioToTextMultipartReque return mw, nil } -func NewLLMMultipartWriter(w io.Writer, req BodyGenLLM) (*multipart.Writer, error) { - mw := multipart.NewWriter(w) - - if err := mw.WriteField("prompt", req.Prompt); err != nil { - return nil, fmt.Errorf("failed to write prompt field: %w", err) - } - - if req.History != nil { - if err := mw.WriteField("history", *req.History); err != nil { - return nil, fmt.Errorf("failed to write history field: %w", err) - } - } - - if req.ModelId != nil { - if err := mw.WriteField("model_id", *req.ModelId); err != nil { - return nil, fmt.Errorf("failed to write model_id field: %w", err) - } - } - - if req.SystemMsg != nil { - if err := mw.WriteField("system_msg", *req.SystemMsg); err != nil { - return nil, fmt.Errorf("failed to write system_msg field: %w", err) - } - } - - if req.Temperature != nil { - if err := mw.WriteField("temperature", fmt.Sprintf("%f", *req.Temperature)); err != nil { - return nil, fmt.Errorf("failed to write temperature field: %w", err) - } - } - - if req.MaxTokens != nil { - if err := mw.WriteField("max_tokens", strconv.Itoa(*req.MaxTokens)); err != nil { - return nil, fmt.Errorf("failed to write max_tokens field: %w", err) - } - } - - if req.Stream != nil { - if err := mw.WriteField("stream", fmt.Sprintf("%v", *req.Stream)); err != nil { - return nil, fmt.Errorf("failed to write stream field: %w", err) - } - } - - if err := mw.Close(); err != nil { - return nil, fmt.Errorf("failed to close multipart writer: %w", err) - } - - return mw, nil -} - func 
NewSegmentAnything2MultipartWriter(w io.Writer, req GenSegmentAnything2MultipartRequestBody) (*multipart.Writer, error) { mw := multipart.NewWriter(w) writer, err := mw.CreateFormFile("image", req.Image.Filename()) diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 9b6fa10f2..ee0ee96fb 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -41,7 +41,7 @@ type APIError struct { // AudioResponse Response model for audio generation. type AudioResponse struct { - // Audio The generated audio. + // Audio A URL from which media can be accessed. Audio MediaURL `json:"audio"` } @@ -144,17 +144,6 @@ type BodyGenImageToVideo struct { Width *int `json:"width,omitempty"` } -// BodyGenLLM defines model for Body_genLLM. -type BodyGenLLM struct { - History *string `json:"history,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - ModelId *string `json:"model_id,omitempty"` - Prompt string `json:"prompt"` - Stream *bool `json:"stream,omitempty"` - SystemMsg *string `json:"system_msg,omitempty"` - Temperature *float32 `json:"temperature,omitempty"` -} - // BodyGenSegmentAnything2 defines model for Body_genSegmentAnything2. type BodyGenSegmentAnything2 struct { // Box A length 4 array given as a box prompt to the model, in XYXY format. @@ -237,7 +226,7 @@ type GPUUtilizationInfo struct { // HTTPError HTTP error response model. type HTTPError struct { - // Detail Detailed error information. + // Detail API error response model. Detail APIError `json:"detail"` } @@ -281,8 +270,28 @@ type ImageToTextResponse struct { Text string `json:"text"` } +// LLMMessage defines model for LLMMessage. +type LLMMessage struct { + Content string `json:"content"` + Role string `json:"role"` +} + +// LLMRequest defines model for LLMRequest. +type LLMRequest struct { + MaxTokens *int `json:"max_tokens,omitempty"` + Messages []LLMMessage `json:"messages"` + Model *string `json:"model,omitempty"` + Stream *bool `json:"stream,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopK *int `json:"top_k,omitempty"` + TopP *float32 `json:"top_p,omitempty"` +} + // LLMResponse defines model for LLMResponse. type LLMResponse struct { + Created int `json:"created"` + Id string `json:"id"` + Model string `json:"model"` Response string `json:"response"` TokensUsed int `json:"tokens_used"` } @@ -449,8 +458,8 @@ type GenImageToVideoMultipartRequestBody = BodyGenImageToVideo // GenLiveVideoToVideoJSONRequestBody defines body for GenLiveVideoToVideo for application/json ContentType. type GenLiveVideoToVideoJSONRequestBody = LiveVideoToVideoParams -// GenLLMFormdataRequestBody defines body for GenLLM for application/x-www-form-urlencoded ContentType. -type GenLLMFormdataRequestBody = BodyGenLLM +// GenLLMJSONRequestBody defines body for GenLLM for application/json ContentType. +type GenLLMJSONRequestBody = LLMRequest // GenSegmentAnything2MultipartRequestBody defines body for GenSegmentAnything2 for multipart/form-data ContentType. 
type GenSegmentAnything2MultipartRequestBody = BodyGenSegmentAnything2 @@ -628,7 +637,7 @@ type ClientInterface interface { // GenLLMWithBody request with any body GenLLMWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) - GenLLMWithFormdataBody(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + GenLLM(ctx context.Context, body GenLLMJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) // GenSegmentAnything2WithBody request with any body GenSegmentAnything2WithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -767,8 +776,8 @@ func (c *Client) GenLLMWithBody(ctx context.Context, contentType string, body io return c.Client.Do(req) } -func (c *Client) GenLLMWithFormdataBody(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { - req, err := NewGenLLMRequestWithFormdataBody(c.Server, body) +func (c *Client) GenLLM(ctx context.Context, body GenLLMJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGenLLMRequest(c.Server, body) if err != nil { return nil, err } @@ -1088,15 +1097,15 @@ func NewGenLiveVideoToVideoRequestWithBody(server string, contentType string, bo return req, nil } -// NewGenLLMRequestWithFormdataBody calls the generic GenLLM builder with application/x-www-form-urlencoded body -func NewGenLLMRequestWithFormdataBody(server string, body GenLLMFormdataRequestBody) (*http.Request, error) { +// NewGenLLMRequest calls the generic GenLLM builder with application/json body +func NewGenLLMRequest(server string, body GenLLMJSONRequestBody) (*http.Request, error) { var bodyReader io.Reader - bodyStr, err := runtime.MarshalForm(body, nil) + buf, err := json.Marshal(body) if err != nil { return nil, err } - bodyReader = strings.NewReader(bodyStr.Encode()) - return NewGenLLMRequestWithBody(server, "application/x-www-form-urlencoded", bodyReader) + bodyReader = bytes.NewReader(buf) + return NewGenLLMRequestWithBody(server, "application/json", bodyReader) } // NewGenLLMRequestWithBody generates requests for GenLLM with any type of body @@ -1338,7 +1347,7 @@ type ClientWithResponsesInterface interface { // GenLLMWithBodyWithResponse request with any body GenLLMWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) - GenLLMWithFormdataBodyWithResponse(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) + GenLLMWithResponse(ctx context.Context, body GenLLMJSONRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) // GenSegmentAnything2WithBodyWithResponse request with any body GenSegmentAnything2WithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenSegmentAnything2Response, error) @@ -1775,8 +1784,8 @@ func (c *ClientWithResponses) GenLLMWithBodyWithResponse(ctx context.Context, co return ParseGenLLMResponse(rsp) } -func (c *ClientWithResponses) GenLLMWithFormdataBodyWithResponse(ctx context.Context, body GenLLMFormdataRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) { - rsp, err := c.GenLLMWithFormdataBody(ctx, body, reqEditors...) 
+func (c *ClientWithResponses) GenLLMWithResponse(ctx context.Context, body GenLLMJSONRequestBody, reqEditors ...RequestEditorFn) (*GenLLMResponse, error) { + rsp, err := c.GenLLM(ctx, body, reqEditors...) if err != nil { return nil, err } @@ -2982,86 +2991,87 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+x9a2/ctrb2XyH0voATYMa3Nu2Bgf3BSdPEOHZq+LLbIjVmc6Q1GsYSqU1SHk9z/N8P", - "uChpSImai2u7Pd3zKWOJl3V91iK5xHyNYpEXggPXKjr6Gql4CjnFn8fnJ++lFNL8TkDFkhWaCR4dmTcE", - "zCsiQRWCKyC5SCDbjQZRIUUBUjPAMXKVdrtfTaHqnoNSNAXTTzOdQXQUnanU/DUvzB9KS8bT6OFhEEn4", - "d8kkJNHRZxz1ZtGlIbTpJ8ZfINbRwyA6LhMmLioqu6RcePSTiZCEmh4kBQ6SmlZdprAF/siynybR0eev", - "0f+XMImOov+3t5DmXiXKvTNIGL2+OI0ebgYBSVQzQWJn3u1wa6dz+fV4CjD9ViTzUQocG16JK7jXhtwe", - "LnySrotM0KSmhkxYBkQLMgaiJeWm5RgSI5OJkDnV0VE0ZpzKedSir6vEQZSDpgnV1M46oWVm+n99iNpy", - "OU4SZn7SjHwRY8K4nYwJXtFSUKUgMX/oKZCCFZAx7ttRPVeIDqPsEUt8OjpUfCzTlPGU/Ejj2kBOfiCl", - "mdgYSi2PoraSZmrbNAlNLUGXko80y0FpmhfKp0HLEjp0XGAfsuhjp596KiEa7vUuuSyLQkhjTXc0K0Ed", - "kR0FXAOPYWdAdmZCJjsDYsycWKLIWIgMKCevdszkO+bdzoRmCnZe75IfLGWEKVK9frUY7/Vu3ZLkQLki", - "XDhE7lazVe/M7+GYotYWbRypVVxeLSSzCgY6jhGy+yXucZLTFK4E/tP1j7RkCeUxjFRMM/DU9P3um7aO", - "3vNYlJKmoCpL0Q2GAGE5vogzoSCbk4zx24XxGr2RQoq80OTVlKVTkJXuSE7nREJSxtUQ5N8lzZiev3bl", - "9qGik1winQ2/vMzHIA2/rGawx9Pt2FoYytlkTmZMTzt+1e/uVn4BW8dxR0vkeNCV4w+QSkBiZlMWWzIW", - "CGkpZYoUpZqiCGdUJgpbMc40o5lts9umj6wWUyYkVSsg4Ziciotj8upUzIYXlN+S44QWGpHpdaV4yhPC", - "tCKxkDY6JsbLZsDSqUbHtUw4AYa8v6d5kcER+Up+izKqgethLLhiyjjafC+L86GhbqiS++y36Igc7O4P", - "yG8RB8m+qL2C3UM2pFIP67eHD64ATpGxZ8PBDj9rQiGHlGp2ByNr/CuIuFq4ySv1Gt2rZAmQ2ZRq8xfc", - "x1mZAJlIkQdEfJJyIY0FTYhvkOS3cn//m5gcuGR/qkgj55a0EPVlPrJ+PSpAhng4aLPwCU2NiEkNCC5G", - "FCAr9jxCypyc2MbnIDvkMK4htdaL9PAJSEDWNLRCy8H+fj89CXDBlNExdtwlZ0KC/U1KVdLMoBZQxKwK", - "oiooqlkZl5qoTMxAkoYKM0xSZui547mJN8BTPe3wV7cnl0h1iDtXvOtYxTKb7NepohPQ81E8hfjWE54J", - "fW3pnYM0mGgCKXYj2A1NUWmWI+5P2thlYKHMEpPCiMkEuDJGJiSZUplPyswl89KO+g6JaYitojVSC5B0", - "JXIJlVtKyhORE4tvPaIwjYPyrnXlSWF/97964FpMbCqySNNoUWRsEeQk1Dq2mnm1b94ceIHssp6zg82t", - "uF/UCrSBLZAAeJF9dQYQTpDXDpsN608WOZ8wQW1Usi4s/yE07p+yz+taul2l0jVzun+yBERXpZMWKH4X", - "WpBNJM1BISAriAVP0Ly9POTODO9y92MPbk0x7Htzvvk+OKttSRgnGM7VGpN+tIOH5l3bdpv4Q+34GD//", - "VKu1ZGyeTuTCtB6Ny/gWdJuKg8Pv22Rc1xMaFeNq0xBlRE5zUXJtFGDHbJZbbkKBOrOh0LyqYNb8zE3s", - "rHrOWJYZsGccX3VUeGabvUWiPcbc0C6YghEt01EPLO8fdvLUhgXsTGiSLMDYY9imy+Sjt/CoFh0SFOTj", - "DNPm3r424eWxBKpqvr0QjwQclynpB/jV6cvhm//D2cs2r6glMWNJy3oP9g+/DeEhttwIDn/Gsbuzbhhh", - "bOhYEmJOT8+6kWXKlBZy7kPf5xsXrasWIeii9yMtboG3bf47BynoPbmybUKC7cXezUL+GjmylkBzbxrc", - "A/LzOJqHTWuuNOSjZlc4QOclNiHBbeBBpCEvjP5LCS0M/H4xxJXTaM1cMmAORs1LrOAS0hy4PuZzPWU8", - "PeyaxFjcB7bOSYYwQr4lVEo6Jym7A06oIpSMxX29EVShLWp1YLzgl19/+ZXYmOza/Ftx37vz0p38pI76", - "yhL/2DhP1e2I8aLUQf7EbChBiazE0GYaE2zcYkrPCxYjNuOSnZJCwh0TpTI/EhZjb6YrdBkscmtEx4P7", - "j/c/k1cf//HzPw7ffIfAdHl85q0nzszMJ0jmX27vIy8zg+XqdiRK3QhySVQ4MSusEgYLCdrcQlZ7w1Oz", - "DDMD2s1hmo9ZWhphWtFbs1IDIiYauPkzKWPc/QWtQVY99ZRyE3cYTzNw1OBxVVNOfrKUh/ycG6PK2O8w", - "ioWQidqMvUIwrgn2ZJxqUE0a1Yy7WFhSngL5vD84uKlMBHtX8xK4LyDWtvkYbAMJyjw0j6z6EpabiCm4", - "8vOWai7yzvIQYtSdrOsMn+4PKy8Xk4qrShEtX5hNQQIBGlfkE2YUR179Mvj19SIGesspbNamzIF0JCyj", - "Y8gChJ3i8yav9UirqTkgjCcsRvlT0xRSKUqeVK1N1rfvNRnT+NZt0iXXTrvkWCQTKdMbWIvtpkjJh8YD", - "1FRkJs9F87RjEcaVNrmfmBgSEePwfeDo4dTO3tXzuhlEJyYsiR/XRbMf/shthyferX8aQCwtW8njd4VX", - "LAS+f/MftI25ljS3+5mr1h0b7x/Wzhnw33fTkt+G8p7YvMBlilEmeiVdHHV2qwh0tenYXfrgANV6B0d1", - 
"WfQ3wJzMuJ6pZ8z6dWdgpiE3BD04czRjNRNhGOtIUrsNDWGOLK2gAhL8cH79TuRFqeGETwIVCGdNKUYC", - "mjJj/h/Or0ls+7jFAF2hWvhqsC6ce9EvtqhlkSh+cYtH3FUV5ELORxMJ4HXAx+RH83hJNy00zQL9rvB5", - "sCPjLdLwQXBTiOYeTZ/M3yt3V41AuG3pEemzWsuoJsjRakt5YfVea5ax31FFq1RsNFsumhOlqWZKs1g9", - "UrkvrLH11DCIHB5HlSW73RyJkUq+wencYSzNfaNYytfYDlnXJEIcBAnyraVtCwGL+Xh1dd5TaWZerVlq", - "ZsFi/bKspmqsW5b1Q407dmYPcdryq6Z1mF6w08PrP2nGEhyu4bqPlRqcl3LSHs9BcstJCMZdatsDhOim", - "MplRiV5fyWKtajrj30sROy1Kk+/Zerqm7Ovca7OM+RYgOZx9KErSZ3JusrvWDlmVW3t7ZPWzVahbLBo2", - "8w4WjLumE5DyEmVcaqrVWmqQQLOhidaokGUga8hS9biPVEjb51tKsWT/dbVi2Q+opZfwj0AzPX1X59m+", - "RM1wpQqnZVPsSGyTOjVzKANe5obYn/47GkTvLy5+uogG0ckPp+9d8i7tBKsYruhw+XLIDnCFS8WNamdD", - "y5TA6rZHGO21hZudrq6wda3Mlsysyl4rWm5a/ZZV1jqn0BsJBlP6ZXLpXw0spIIVniuXAu0UvZWWhzgI", - "MHp6euYy6BMrnTeLPZT2YM7aBM8tRqUCz7XtcQa5Vuss2qQzvjOcw5lLcogjdgd4tlMd8ZxTSa1p+czF", - "gmspslEpsxVbIdcXp6hcVY6xCJfxlNwxSq4ki29xA1VoEYus2hhJcJ+sOgnOzEIZj7CGWgzbZ96kQOJc", - "Zb+zZJFrmYVEDHfGIzYguijHGVNTQ7Pt2096jUc1TFGekEykHnnv7Rg91K25t2SS5zYIEi2ILHlXbuaF", - "/fFFjHfJJ6FZDERjOeCUKcIUMWltQurJ68P+uqzSrtGFnoIkUpQa1AD3f5gmiQBFuNC2bsnMRElw48vW", - "AcA9jbV99kq9JgkUwBNFBPc5YXmRQQ5cVyVSPCE5Fn2M8Vx9wtJS0nEGqAnT81/WDP5FqEzL+sBmrcDY", - "2HYjbbO+b+2iVgWz2Bg0SGdDOFBQX/lLwLMqU1oYn59LK8245djYXqVeUepU2J06CTQ3Iq6G8ea0j/rM", - "qnY8CE99KUoZgzsr47HI/VmbMYj2TpMvm+fBydux1aPEF4kLUWEMWgOtNoo2y8GlG3w2hzxEs7peZvl0", - "z4hhLvBqUePYMshaG7Ge1qaJFn+eVb+0US8LwmdU3aqNbNn2rU+sewzYPTdq50+Szgak5M7R4eJgU5FX", - "tuvrBvrwJNSvnvdPhfxz8JWpbGc8FEFQ77GQfakxymNH2TCR4CmDbY5048GhP6UHZXbglV/VVYSpunkl", - "1ZsW7Uv1i5l4YHc8Ny9qZRq8ocyWkDmfdtGxKHWrygf7dRXO1WTWnebnKei6IM9OOKOKTDKappAQqsin", - "yx9/9s5tzDDrn0UYTZg39rjLrZ5sZlyrCiro12Zw49T29HXBQky5SRBoHINS9tO7Zr9vDSe2rqssKSg2", - "V5+orj49Xl+chlSJ6CtFXn2h00ulr7GX5rnNpWEmwOjTr+TwYEWts5azZzDrL3PticpD64ilu8wdPPNq", - "clDzeOP3XgYM5n1VkN+38Pr7fHD3lFXfnc/ZllR9b79g237B9vf9gu3Nf/QHbOQSzEJdA8HawcJu0mAt", - "Ge5j7PzPjjEN1Xz/PZ4vKsy25SJ/Wpl6B7/XLFPvFiZ3Q2hvnL0sAOJpX6D1uHAh65jkBk9UAfQWJEnA", - "rOylMjrODPhncwL3hQSFejNhgnJUdWL6QDyty16M0aGtmscJtiyYjtFzOkvp+i8ju3pqs4TVAFW6Zf6y", - "44f16AzyjF/TrUPJsmixSMqWhwhbsYv7G8um6s3XfHvxTCFgMCsPozMReyfRlM+r0/U2h187Nn3z4Mbw", - "uHWg2aSr1fcAi3wdb4UJyhAfLJoizeTKPF2Vuho+7FRVS8e11jgA33wXbvW+m/3ScFWiXn+XZ9p6a4UN", - "z8baa4T600VLxIqzsopUV2bL93oQoeNSMj2/NKRYPj9eXZ2/BSpBNpcNIazbR80gU62L6MGMwYL1Q8fV", - "B8ZxcyeMLDk5Pmn2/dyNvlN2B4XBkuMTclFyjhMZXLNj7e/u7+4bgYgCOC1YdBR9s3uwu2+0RfUUyd7D", - "q0aGWgxrJy6ECkXz5j4W5/ocW+xdrbZEUVnDSWKWEu27SozIQem3IpnXG7PAcSIb9anUeybsDutrdKya", - "VxlB6GKUB1/FJsbjA6tQZPtwf79FhSP1vS/Kxo/1SPAWiDh3K3CXuNiflBlZNBtE3z4hCYvCnMD8b2lC", - "Lqz07bwHLzPvNaelngrJfocEJz745mUmrpgl77k2afCVEOSUytRK/eDNS3G/SFgRqSyWGxIOD5+UhE6R", - "VJeYRRPSFFK9eSn7O+EaJKcZuQR5B7KmwIFRjLkugH6+ebgZRKrMcyrn9b1b5EqQOjWgqTLYXYcSg973", - "Q5tiUTUfcprDUNyBlCxB5PfQYRDtTau6l70ahVNAEfgg5hYtRc+IIKHiqHWB5MGVUz2QrQ7zOW1Kn5ay", - "WhcCPTuvdqI/xmU9hmETC3762bOvn5Mvp+LocVxZEpEbXFqZoNx8MhOOysdFkc3r72a8CyqUPdovpDBJ", - "lrNY64Tp1o0izxynvdleOFD7NVDbSN0fqbcRatMIZT9AvhKk+QptwxDFfMdwQWCNzBw3rCwOrE7M/Qtn", - "Xsbh/4zEPFQQuPX6v3h+voWeR0PPI5Nj5nmoCzx3zV1TQeT5ELphaaOko76R5GUwyM72wiDkbyZt4Web", - "dDyD5zc3+zzO9WvHGER7GbuDoV/xuGr5EVx4ONXMtnbPvTFRl5JDQoAneJ2CCkJEu/huKUw8Xkc9hasv", - "jBK9lYZbwNgCxtMBhjEzCxZ/BDWytmda5MjyNVIFPGsssZ6BkozytDQQ1hzld1EA76paz/Hvh7PZbIh5", - "Qikz4LFI7EH6ZtmCmfKl3d/5wGfr8VuPf0KPt3e9berhWW6duipOH9Lq4p/hYb+PV3cEVaXQ+D0Y5UvW", - "AIE7hZ55HdCZ8YXd3C8y3zr61tGfztFr76uNmxw+wu9V10EG0Z6J2WscRnxo1SjjboBTkhxO853ar2fK", - "8LvVZdtzh63b/03cHuvq/sCxg3bcz3N2W6G31uaf38X977Xs/4pUfy1cbwvqRS0g5YlTlOn9n1M9SGGr", - 
"/p4VKrzCwhfGCv9/QNtixRYrnh4rGhd6HFhU3REtSueuzyBMVPcNNisBMp7XV+rjR5JakcWVykG3X9xY", - "+Myrg3qibXaw9fi/icc7t31u6Oql6wwKCVA4Xeu65boC+V0myoS8E3lecqbn5APVMKPzqPokGOue1dHe", - "XiKB5sPUvt3Nqu67semOhfY9419qzCr6hm0GUthujxZsbwya7jX8Ptw8/G8AAAD//5bIT20SdgAA", + "H4sIAAAAAAAC/+xde2/ctpb/KoR2ATvAjF9t2oWB+4eTpolx7dTw47ZFG8zlSGc0jCVSJSnb06y/+4Iv", + "iZSombFru93e+StjiY/z/J1D8lD5kqSsrBgFKkVy+CUR6RxKrH8enR2/45xx9TsDkXJSScJocqjeIFCv", + "EAdRMSoAlSyDYicZJRVnFXBJQI9Rirzf/XIOtnsJQuAcVD9JZAHJYXIqcvXXolJ/CMkJzZP7+1HC4bea", + "cMiSw1/0qJ/aLg2hTT82/QypTO5HyVGdEXZuqeyTch7Qj2aMI6x6oBwocKxa9ZnSLdSP/+YwSw6T/9pt", + "ZbhrBbh7ChnBV+cnPeJNb5/8gMQID29YtpjkQHXDS3YJd1LNPkBUyOBVVTCcQWbZmpECkGRoCkhyTFXL", + "KWSKxRnjJZbJYTIlFPNF0qGvr5NRUoLEGZbYzDrDdaH6f7lPRl17yTKifuICfWZTRKiZjDBqaamwEJCp", + "P+QcUEUqKAgNzcLNFaND6W5CspCOHhUf6jwnNEff49Tp+/g7VKuJld6dPCqn9GZq0zSLTc1B1pxOJClB", + "SFxWIqRB8hp6dJzrPqjtY6afBypBEu7kDrqoq4pxCRm6wUUN4hBtCaASaApbI7R1y3i2NULKapEhCk0Z", + "KwBTtL2lJt9S77ZmuBCw9WoHfWcoQ0Qg+3q7He/VjmuJSsBUIMo8InfsbPad+j2eYq21to0nNcvlZSuZ", + "VV7dc4yY3S9xj+MS53DJ9D99/8hrkmGawkSkuIBATd/uvO7q6B1NWc1xDsJaimwgARAp9Yu0YAKKBSoI", + "vW6NV+kNVZyVlUTbc5LPgVvdoRIvEIesTu0Q6LcaF0QuXvlye2/pRBeazoZfWpdT4Ipf4hgc8HQztmSK", + "cjJboFsi5z2/GnZ3I7+IretxJ0vkuN+X43eQc9DE3M5JashwcnSUEoGqWsy1CG8xz4RuRSiRBBemzU6X", + "PrRaTAXjWKyAhCN0ws6P0PYJux2fY3qNjjJcSY1Mr6ziMc0QkQKljJtglykvuwWSz6V2XMOEFy/Quztc", + "VgUcoi/o16TAEqgcp4wKIpSjLXaLtBwr6sYiuyt+TQ7R/s7eCP2aUODks9ityB0UY8zl2L09uPcFcKIZ", + "ezYc7PGzJhRSyLEkNzAxxr+CiMvWTbbFK+1eNckA3c6xVH/BXVrUGaAZZ2VExMc5ZVxZ0AyFBol+rff2", + "vkrRvk/2R0saOjOkxaivy4nx60kFPMbDfpeFj9rUEJs5QPAxogJu2QsIqUt0bBqfAe+RQ6iE3FivpofO", + "gINmTUIntOzv7Q3TkwFlRCgd64476JRxML9RLWpcKNQCrDHLQpSFIsfKtJZIFOwWOGqoUMNkdaE9d7pQ", + "8QZoLuc9/lx7dKGpjnHni3cdq1hmk8M6FXgGcjFJ55BeB8JToa8rvTPgChNVINXdkO6mTVFIUmrcn3Wx", + "S8FCXWQqhWGzGVChjIxxNMe8nNWFT+aFGfWtJqYh1kZrTS1A1pfIBVi35JhmrEQG3wZEoRpH5e10FUhh", + "b+d/BuCazUwq0qZpuKoK0gY5Dk7HRjPbe+rNfhDILtycPWzuxP3KKdAEtkgCEET21RlAPEFeO2w2rD9Z", + "5HzCBLVRybqw/IfQeHjKIa/r6HaVStfM6f5FMmB9lc46oPjNKLLSnHFcgtCALCBlNNPmHeQhN2p4n7vv", + "B3BrrsN+MOfrb6OzmpaIUKTDuVhj0g9m8Ni8a9tuE3+wGV/Hzz/Vag0ZD08nSqZaT6Z1eg2yS8X+wbdd", + "Mq7chErFerWpiFIixyWrqVQKMGM2yy0/odA6M6FQvbIwq36WKnbanrekKBTYE6pf9VR4apq90UQHjPmh", + "nREBE1znkwFY3jvo5akNC7ozwlnWgnHAsEmX0Ydg4WEXHRwElNNCp82DfU3CS1MOWDi+gxCvCTiqczQM", + "8KvTl4PX/4+zl01e4SRxS7KO9e7vHXwdw0Pd8kFw+KMeuz/rAyOMCR1LQswF5CVQeUQXck5oftAPM1N2", + "F9kDRYU2IPQ1wpzjBcrJDVCEBcJoyu7cFoD1M42LI8X/Tz//9DMyaOxz+4bdDa65+5MfO7wXhvjHIjwW", + "1xNCq1pG+WO3Yw6CFbUGNdUY6cYdpuSiIqn2Sr1Yw6jicENYLdSPjKS6N5HWrkZtVqX9Yv/uw92PaPvD", + "P378x8Hrb7RJXhydBpnkqZr5WJP5l1v1lnWhvFhcT1gtG0EuwYNjlVvXMGolaKIKt7uCc5WAqwHNtiAu", + "pySvlTCN6I1ZiRFiMwlU/ZnVqd73AymB255yjqlCHELzAjw1BFw5ytEPhvIYeFBlVAX5HSYpYzwTD2Ov", + "YoRKpHsSiiWIJoA247ZLCkxzQL/sjfY/WRPRve28CO4qSKVpPgXTgINQD9Ujo76MlAorGRVhxLJzobeG", + "hxij/mR9Z/h4d2C9nM0sV1YRHV+4nQMHBDi15COiFIe2fxr9/KpFvyCR1s26lHn5uyaswFMoIoSd6OdN", + "RhOQ5qjZR4RmJNXyx6op5JzVNLOtVbzfC5pMcXrtN+mTa6ZdsiFesJzIB1iL6SZQTcfKA8ScFSrD0eZp", + "xkKECqmiPpspEjXG6feRTecTM3tfz+vGjl5MWBI/rqpmJ/SRC84n3qd9GkCsDVvZ4/cDV6SA377+D9rA", + "Wkuam52sVRnng3eOnHNG/PftvKbXsbwnVS90gqqUqb0St4dc/eNgabeb+kmvHsBmunpUn8Vw66PVdTPT", + "wJjudW9gIqFUBN17czRjNRPpMNaTpPQbKsI8WRpBRST4/uzqLSurWsIxnUXOnk+bM/UMJCbK/N+fXaHU", + "9PGPgftCNfDVYF0898KfTXVCmyh+9qsAPH8toWR8MZlxgKCDfoy+V4+XdJNM4iLS71I/j3YktEOafhDd", + "DsBlQNNH9ffKfTUlEGpaBkSGrDoZOYI8rXaUF1fvlSQF+V2raJWKlWbrtjkSEksiJEnFI5X7whpbTw2j", + "xONxYi3Z7+ZJDFn5RqfzhzE0D41iKF9jIbyuScQ4iBIUWkvXFiIW8+Hy8mygZEi9WrNmyIDFqvqapuin", + 
"Kwfb3SO+JWuA5n/hgmSas4b6IZIcyC6lrTueh8jfmZEicOxT2x0gRjfm2S3m2nstiq5V3qT8dCny5lWt", + "8jbt7Lgp3DkL2ixjvgMsHmfvqxoNmY6ftK51rGFzZL/9mXu2Cj2rtmEz76hl3DediJSXKONCYinWUgMH", + "XIxV1NUKWQaWiizhxn2kQrq+21GKIfuvqxXDfkQtg4R/AFzI+VuXL4cSVcPVIp5ezXVHZJq4FMujDGhd", + "KmJ/+GcySt6dn/9wnoyS4+9O3vnkXZgJVjFs6fD58siOcKWXfA8qZowtNyKr1AFhdNcIfpa5uvjRtzJT", + "9LAqC7W0fOr0W1Yb6Z0jPkgwOjVfJpfhrL6Viq7RW5nSd1PtTnod4yDC6MnJ6amplu0bdMqoBCp9r3tr", + "H8W2aFgROOg5K1Y7JzeN3Ewe/R5hcbLP4bcaROREvsR3E8mugXbPhr7x937v0KVpE0/39Mxi7dDsUXvv", + "V5faYboGaqGot5USgGF0xS454DLop2srw/oIXEYX0BLKSplYzaFzNPitb2xto8gJnGTVJNwrGO97nVmF", + "/hmVqOpXdQuf/G5nKws6GqWEVuLsYMhKWu/tWDcH5W2BddtH0bPydRYXTqkr9cg9stpdxq6beqt3bauT", + "WoQEGxNGV2KdbQ3uje8NNzIpfmlJdWLpCnkJhJAb0Mdh9lTsDHNsPKYPJ5wVk5oXK/YQr85PNJqKeqrr", + "lgnN0Q3B6JKT9FqfPDDJUlbYHcVMbzDbw/OC3NgT9LFk426ZAKo0cT66vjVkoSse1RTcKEd/ANFVPS2I", + "mCuaTd9h0l0C4PICTDNUsDwg750ZY4C6NTdl1aqzm3UgyRCvaV9u6oX58ZlNd9BHJkkKSOoKyjkRiAik", + "1oMZcpO7+ghXiWo2t5icA0ec1RLESG+cEokyBgJRJk2pl5oJo+iOsSmdgDucSvNsW7xCGVRAM4EYDTkh", + "ZVVACVTaqjKaoVLXyUx1KcKM5DXH0wK0JlTPfxsz+DfCPK/dSedamWhj2420v9z3jh9sjbFuDBK4d5IS", + "uYNg/SXiWdaUWuMLi+qEJNRwrGzPqpfVMmdmi1sFAiViO0wwp3k0ZFbO8SA+9QWreQr+rISmrAxnbcZA", + "MjiAv2ieRyfvJrMBJaFIfIiKY9AaaPWg9G45uPSzvYdDnkYzV2K0fLpnxDAfeCVzOLYMstZGrKe1aSTZ", + "n2fVL23Uy4LwKRbX4kG2bPq6Uo8BA/YPXLsLFo5vR6im3pl7WxEg0Lbp+qqBPl1CEF44CI9TwwKSlWvH", + "3nhaBFG9p4wPrUW1PLaECROZPp4zzTXd+sQ9nDKAMjPwynuFljDhmlupfurQvlS/eukbOVYq1QunTIU3", + "mJiqO+82HJ6yWnYKo3S/vsKpmN32p/lxDtLVMJoJb7FAswLnOWQIC/Tx4vsfgwNPNcz6h3hKE+qNOSf2", + "C06bGdcqHIv6tRpcObUpW2hZSDFVCQJOUxDC3FZsNsrXcGLjusKQosXm61Ora0iPV+cnMVVq9OWstJea", + "BqkMNfbSPHe5VMxEGH36rRN9IinW2Twxh5fr7yuZo8j7ztlkbNn+vNs3I8fjp7D3MmBQ7+0dhqGF19/n", + "juJTFsr3bgAuKZTfXPrbXPr7+176e/0ffecPXYBaqEtAuui2Mps0ughT72Ns/e+WMg3RXJmfLtrSzE2d", + "1Z9W2d/D7zUr+63BdEJsGEIH4+xFBZDOhwJtwIUPWUeoVHgiKsDXwFEGamXPhdJxocC/WCC4qzgIrTcV", + "JjDVqs5UH0jnrl5MGZ22VfU40y0rIlPtOb2ltPtLyc5NrZawEsCmW+ovM35cj94gz3gBcR1KlkWLNilb", + "HiJMqbve31g21WC+FtpLYAoRg1lZ/VGwNDhfwnTxwyw5/OVLj8MvPZv+dO/H8LRTQdCeMpnv5HTOo6Iy", + "1A/apppmdKmerkpdFR9mKtvSc601Kk4evgu3et/NXM5clai7q4yqbbBWeOBhdHeN4G57GiJWHE5bUn2Z", + "Ld/r0Qid1pzIxYUixfD54fLy7A1gDrz53JKGdfOoGWQuZZXcqzFItPDuyN7JTpvP6PCaoqPjZt/P3+g7", + "ITdQKSw5OkbnNaV6IoVrZqy9nb2dPSUQVgHFFUkOk6929nf2lLawnGuyd/XXWcaSjZ0TV0zEonnzCRvv", + "i0PmloRdbbHKWsNxppYS3c+7cHNK+IZli87Rton6mMtdFXbH7stDRs2rjCD2LZn7UMUqxnsnfprtg729", + "DhWe1Hc/CxM/1iMhWCDquTuBu9aL/VldoLbZKPn6CUloK+Ei87/BGXJntHre/ZeZ94riWs4ZJ79Dpife", + "/+plJrbMondUqjT4kjF0grmpDPh6//VLcd8mrBqpDJYrEg4OnpSEXlVin5i2CWoqF1+/lP0dUwmc4gJd", + "AL8BjtryTgejOub6APrLp/tPo0TUZYn5wn2qDF0y5FIDnAuF3S6UKPS+G5sUC4vFmOISxuwGOCeZRv4A", + "HUbJ7twWmu06FM5BiyAEMb9KMHlGBIlVI64LJPe+nNxAphwz5LSpNVzKqqu8e3ZezUR/jEs3hmJTV9gN", + "s2dePydfXonf47gyJGpu9NJKBeXmrlk8Kh9VVbFwF86Cb3oIc7RfcaaSLG+x1gvTnY+wPHOcDmZ74UAd", + "Fh1uIvVwpN5EqIdGKHNz/5Kh5vrmA0MUCR3DB4E1MnO9YWVwYHViHn6j52Uc/s9IzGMVuBuv/4vn5xvo", + "eTT0PDI5JoGH+sBz03yeK4o872MfpXpQ0uE+4vIyGGRme2EQCjeTNvCzSTqewfObjyE9zvWdY4yS3YLc", + "wDiseFy1/IguPLxqZlO7539kUtacQoaAZvo7JCIKEd3iu6Uw8XgdDRSuvjBKDFYabgBjAxhPBxjKzAxY", + "/BHUKLqeaZCjKNdIFfRZY63rGTAqMM1rBWHNUX4fBU5On8vx25tLL+3s3nWejX9v/PsJ/Vt7y4P9uSiN", + "C9tS9DG238caHwx7tP2Uli181re/MF2S8Uc+vfXMWX9vxhd287CkfOPoG0d/Okd33ueMGx08wu9F30FG", + "ya6K0GscPbzvVCTrtb9XgBxP6r1Kr2cK6/1ass0pw8bt/yZur6vo/sAhg/TcL3B2U4+31lZf2MX//8fM", + "fxvl7ga7TUDZVv5hmnklmMF/yjWAFKbG71mhIigjfGGsCP+LuA1WbLDi6bGicaHHgYXtrtGi9j6JG4UJ", + "+1nOZiWApgv3fw7oK5FSoPbL41G3bz/s+cyrAzfRJjvYePzfxOO9j+I+0NVr3xmEJkDo6TpfJXf1xm8L", + 
"VmfoLSvLmhK5QO+xhFu8SOwFYF3lLA53dzMOuBzn5u1OYbvvpKq7LqsfGP9C6qxiaNhmIKHb7eKK7E5B", + "4t2G3/tP9/8XAAD//wP2rwUCdwAA", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index e381e4be2..6c5a2f4e5 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -394,10 +394,11 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq return resp.JSON200, nil } -func (w *Worker) LLM(ctx context.Context, req GenLLMFormdataRequestBody) (interface{}, error) { +func (w *Worker) LLM(ctx context.Context, req GenLLMJSONRequestBody) (interface{}, error) { isStreaming := req.Stream != nil && *req.Stream - borrowCtx, cancel := context.WithCancel(context.Background()) - c, err := w.borrowContainer(borrowCtx, "llm", *req.ModelId) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + c, err := w.borrowContainer(ctx, "llm", *req.Model) if err != nil { return nil, err } @@ -408,17 +409,10 @@ func (w *Worker) LLM(ctx context.Context, req GenLLMFormdataRequestBody) (interf return nil, errors.New("container client is nil") } - slog.Info("Container borrowed successfully", "model_id", *req.ModelId) - - var buf bytes.Buffer - mw, err := NewLLMMultipartWriter(&buf, req) - if err != nil { - cancel() - return nil, err - } + slog.Info("Container borrowed successfully", "model_id", *req.Model) if isStreaming { - resp, err := c.Client.GenLLMWithBody(ctx, mw.FormDataContentType(), &buf) + resp, err := c.Client.GenLLM(ctx, req) if err != nil { cancel() return nil, err @@ -427,7 +421,7 @@ func (w *Worker) LLM(ctx context.Context, req GenLLMFormdataRequestBody) (interf } defer cancel() - resp, err := c.Client.GenLLMWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + resp, err := c.Client.GenLLMWithResponse(ctx, req) if err != nil { return nil, err }