fixup! llm: use vLLM
kyriediculous committed Oct 17, 2024
1 parent 8a5596d commit 43b919d
Showing 8 changed files with 555 additions and 59 deletions.
6 changes: 4 additions & 2 deletions runner/Dockerfile
@@ -29,9 +29,9 @@ RUN pyenv install $PYTHON_VERSION && \
pyenv rehash

# Upgrade pip and install your desired packages
ARG PIP_VERSION=23.3.2
ARG PIP_VERSION=24.2
RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0 && \
pip install --no-cache-dir torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
pip install --no-cache-dir torch==2.4.0 torchvision torchaudio

WORKDIR /app
COPY ./requirements.txt /app
@@ -48,6 +48,8 @@ ENV MAX_WORKERS=$max_workers
ENV HUGGINGFACE_HUB_CACHE=/models
ENV DIFFUSERS_CACHE=/models
ENV MODEL_DIR=/models
# This ensures compatibility with how GPUs are addressed within go-livepeer
ENV CUDA_DEVICE_ORDER=PCI_BUS_ID

COPY app/ /app/app
COPY images/ /app/images
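The new CUDA_DEVICE_ORDER=PCI_BUS_ID setting makes the CUDA runtime enumerate GPUs in PCI bus order instead of "fastest first", so device index N inside the runner container refers to the same card that nvidia-smi and go-livepeer call GPU N. A minimal sketch (not part of this commit) of how to check the resulting ordering, assuming torch is installed and at least one CUDA GPU is visible:

import os

# Must be set before CUDA is initialized; the Dockerfile bakes it in via ENV.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

import torch

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"cuda:{i} -> {props.name}, {props.total_memory // 1024**3} GiB")
else:
    print("No CUDA devices visible")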
91 changes: 55 additions & 36 deletions runner/app/pipelines/llm.py
@@ -2,49 +2,59 @@
import logging
import os
from typing import Dict, Any, List, Optional, AsyncGenerator, Union

from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_max_memory
from torch import cuda
from vllm import LLM, SamplingParams
from vllm.utils import InferenceRequest
from vllm.model_executor.parallel_utils import get_gpu_memory
from vllm.outputs import RequestOutput
from huggingface_hub import file_download

logger = logging.getLogger(__name__)


class LLMPipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
self.local_model_path = os.path.join(get_model_dir(), model_id)
folder_name = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
base_path = os.path.join(get_model_dir(), folder_name)

# Find the actual model path
self.local_model_path = self._find_model_path(base_path)

if not self.local_model_path:
raise ValueError(f"Could not find model files for {model_id}")

use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"
max_batch_size = os.getenv("MAX_BATCH_SIZE", "4096")
max_num_seqs = os.getenv("MAX_NUM_SEQS", "256")
mem_utilization = os.getenv("GPU_MEMORY_UTILIZATION", "0.90")
max_batch_size = int(os.getenv("MAX_BATCH_SIZE", "4096"))
max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "256"))
mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.90"))

# Get available GPU memory
max_memory = get_max_memory()
logger.info(f"Available GPU memory: {max_memory.gpu_memory}")

llm_kwargs = {
"model": self.local_model_path,
"tokenizer": self.local_model_path,
"load_format": "auto",
"trust_remote_code": True,
"dtype": "Bfloat16", # This specifies FP16 precision, TODO: Check GPU capabilities to set best type
"tensor_parallel_size": max_memory.num_gpus,
"max_num_batched_tokens": max_batch_size,
"gpu_memory_utilization": mem_utilization,
"max_num_seqs": max_num_seqs,
}

if use_8bit:
quantization = "int8"
llm_kwargs["quantization"] = "bitsandbytes" # or another supported 8-bit quantization method
llm_kwargs["load_format"] = "bitsandbytes"
logger.info("Using 8-bit quantization")
else:
quantization = "float16" # Default to FP16
logger.info("Using default FP16 precision")
logger.info("Using FP16 precision")

# Get available GPU memory
gpu_memory = get_gpu_memory()
logger.info(f"Available GPU memory: {gpu_memory}")

# Initialize vLLM with more specific parameters
self.llm = LLM(
model=self.local_model_path,
quantization=quantization,
trust_remote_code=True,
dtype="float16",
tensor_parallel_size=len(gpu_memory), # Use all available GPUs
max_num_batched_tokens=max_batch_size, # Adjust based on your needs
max_num_seqs=max_num_seqs, # Adjust based on your needs
gpu_memory_utilization=mem_utilization, # Adjust GPU memory utilization
)
self.llm = LLM(**llm_kwargs)

logger.info(f"Model loaded: {self.model_id}")
logger.info(f"Using tensor parallelism across {len(gpu_memory)} GPUs")
logger.info(f"Using GPU memory utilization: {mem_utilization}")

async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
conversation = []
@@ -67,19 +77,28 @@ async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, sys
top_k=kwargs.get("top_k", -1),
)

request_id = 0
request = InferenceRequest(request_id, full_prompt, sampling_params)

total_tokens = 0
async for output in self.llm.generate_stream(request):
if output.outputs:
async for output in self.llm.generate(prompt=full_prompt, sampling_params=sampling_params, stream=True):
if isinstance(output, RequestOutput):
generated_text = output.outputs[0].text
total_tokens += len(generated_text)
yield generated_text
await asyncio.sleep(0) # Allow other tasks to run

input_length = len(self.llm.get_tokenizer().encode(full_prompt))
yield {"tokens_used": input_length + total_tokens}
# Get the final output to calculate total tokens
final_output = await self.llm.generate(prompt=full_prompt, sampling_params=sampling_params)
if isinstance(final_output, RequestOutput):
total_tokens = len(final_output.prompt_token_ids) + len(final_output.outputs[0].token_ids)
yield {"tokens_used": total_tokens}

def __str__(self):
return f"LLMPipeline(model_id={self.model_id})"
def _find_model_path(self, base_path):
# Check if the model files are directly in the base path
if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)):
return base_path

# If not, look in subdirectories
for root, dirs, files in os.walk(base_path):
if any(file.endswith('.bin') or file.endswith('.safetensors') for file in files):
return root

return None
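For context, a rough sketch (not part of the diff) of how a caller might consume the rewritten pipeline, which yields text chunks and then a final dict with token usage. The model id and running it outside the FastAPI route are assumptions; prompt, system_msg and the top_k kwarg come from the signature above. It also assumes the weights are already under MODEL_DIR and a CUDA GPU is available:

import asyncio

from app.pipelines.llm import LLMPipeline


async def main():
    # Hypothetical model id for illustration only.
    pipeline = LLMPipeline(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")
    async for chunk in pipeline(
        prompt="What is Livepeer?",
        system_msg="You are a helpful assistant.",
        top_k=-1,
    ):
        if isinstance(chunk, dict):
            # Final item carries the token accounting.
            print(f"\n[tokens_used={chunk['tokens_used']}]")
        else:
            print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())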
1 change: 1 addition & 0 deletions runner/app/pipelines/utils/__init__.py
@@ -14,4 +14,5 @@
is_numeric,
split_prompt,
validate_torch_device,
get_max_memory
)
20 changes: 19 additions & 1 deletion runner/app/pipelines/utils/utils.py
@@ -6,6 +6,7 @@
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
import psutil

import numpy as np
import torch
@@ -37,7 +38,24 @@ def get_torch_device():
return torch.device("mps")
else:
return torch.device("cpu")


class MemoryInfo:
def __init__(self, gpu_memory, cpu_memory, num_gpus):
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.num_gpus = num_gpus

def __repr__(self):
return f"<MemoryInfo: GPUs={self.num_gpus}, CPU Memory={self.cpu_memory}, GPU Memory={self.gpu_memory}>"

def get_max_memory() -> MemoryInfo:
num_gpus = torch.cuda.device_count()
gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"

memory_info = MemoryInfo(gpu_memory=gpu_memory, cpu_memory=cpu_memory, num_gpus=num_gpus)

return memory_info

def validate_torch_device(device_name: str) -> bool:
"""Checks if the given PyTorch device name is valid and available.
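The new get_max_memory() helper is what llm.py now uses to size tensor parallelism. A short usage sketch (not from the diff; the printed output is illustrative, and the max(..., 1) fallback for CPU-only hosts is an addition, since the pipeline itself just uses num_gpus):

from app.pipelines.utils import get_max_memory

mem = get_max_memory()
print(mem)  # e.g. <MemoryInfo: GPUs=2, CPU Memory=96GiB, GPU Memory={0: '24GiB', 1: '24GiB'}>

# One tensor-parallel rank per detected GPU, falling back to 1 without GPUs.
tensor_parallel_size = max(mem.num_gpus, 1)
print(f"tensor_parallel_size={tensor_parallel_size}")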
10 changes: 9 additions & 1 deletion runner/gateway.openapi.yaml
@@ -2,7 +2,7 @@ openapi: 3.1.0
info:
title: Livepeer AI Runner
description: An application to run AI pipelines
version: v0.7.0
version: ''
servers:
- url: https://dream-gateway.livepeer.cloud
description: Livepeer Cloud Community Gateway
@@ -541,6 +541,14 @@ components:
type: integer
title: Max Tokens
default: 256
top_p:
type: number
title: Top P
default: 1.0
top_k:
type: integer
title: Top K
default: -1
history:
type: string
title: History
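The spec now exposes top_p and top_k on the LLM request schema. A hedged example (not part of the spec) of exercising them against the gateway; the /llm path, the prompt and model_id fields, JSON encoding, and any auth header are assumptions, since only max_tokens, top_p, top_k and history are visible in the fragment above:

import requests

resp = requests.post(
    "https://dream-gateway.livepeer.cloud/llm",  # assumed path
    json={
        "model_id": "meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed field
        "prompt": "What is Livepeer?",  # assumed field
        "max_tokens": 256,
        "top_p": 1.0,
        "top_k": -1,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json())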
10 changes: 9 additions & 1 deletion runner/openapi.yaml
@@ -2,7 +2,7 @@ openapi: 3.1.0
info:
title: Livepeer AI Runner
description: An application to run AI pipelines
version: v0.7.0
version: ''
servers:
- url: https://dream-gateway.livepeer.cloud
description: Livepeer Cloud Community Gateway
@@ -549,6 +549,14 @@ components:
type: integer
title: Max Tokens
default: 256
top_p:
type: number
title: Top P
default: 1.0
top_k:
type: integer
title: Top K
default: -1
history:
type: string
title: History
22 changes: 22 additions & 0 deletions runner/requirements.in
@@ -0,0 +1,22 @@
vllm==0.6.3
diffusers
accelerate
transformers
fastapi
pydantic
Pillow
python-multipart
uvicorn
huggingface_hub
xformers
triton
peft
deepcache
safetensors
scipy
numpy
av
sentencepiece
protobuf
bitsandbytes
psutil