llm: use vLLM
kyriediculous committed Oct 16, 2024
1 parent 40fa0c2 commit 8a5596d
Showing 3 changed files with 50 additions and 161 deletions.
204 changes: 44 additions & 160 deletions runner/app/pipelines/llm.py
@@ -1,145 +1,50 @@
import asyncio
import logging
import os
import psutil
from typing import Dict, Any, List, Optional, AsyncGenerator, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download, snapshot_download
from threading import Thread
from vllm import LLM, SamplingParams
from vllm.utils import InferenceRequest
from vllm.model_executor.parallel_utils import get_gpu_memory

logger = logging.getLogger(__name__)


def get_max_memory():
num_gpus = torch.cuda.device_count()
gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"
max_memory = {**gpu_memory, "cpu": cpu_memory}

logger.info(f"Max memory configuration: {max_memory}")
return max_memory


def load_model_8bit(model_id: str, **kwargs):
max_memory = get_max_memory()

quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=quantization_config,
device_map="auto",
max_memory=max_memory,
offload_folder="offload",
low_cpu_mem_usage=True,
**kwargs
)

return tokenizer, model


def load_model_fp16(model_id: str, **kwargs):
device = get_torch_device()
max_memory = get_max_memory()

# Check for fp16 variant
local_model_path = os.path.join(
get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
has_fp16_variant = any(".fp16.safetensors" in fname for _, _,
files in os.walk(local_model_path) for fname in files)

if device != "cpu" and has_fp16_variant:
logger.info("Loading fp16 variant for %s", model_id)
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"
elif device != "cpu":
kwargs["torch_dtype"] = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config

with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

checkpoint_dir = snapshot_download(
model_id, cache_dir=get_model_dir(), local_files_only=True)

model = load_checkpoint_and_dispatch(
model,
checkpoint_dir,
device_map="auto",
max_memory=max_memory,
# Adjust based on your model architecture
no_split_module_classes=["LlamaDecoderLayer"],
dtype=kwargs.get("torch_dtype", torch.float32),
offload_folder="offload",
offload_state_dict=True,
)

return tokenizer, model


class LLMPipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {
"cache_dir": get_model_dir(),
"local_files_only": True,
}
self.device = get_torch_device()

# Generate the correct folder name
folder_path = file_download.repo_folder_name(
repo_id=model_id, repo_type="model")
self.local_model_path = os.path.join(get_model_dir(), folder_path)
self.checkpoint_dir = snapshot_download(
model_id, cache_dir=get_model_dir(), local_files_only=True)

logger.info(f"Local model path: {self.local_model_path}")
logger.info(f"Directory contents: {os.listdir(self.local_model_path)}")
self.local_model_path = os.path.join(get_model_dir(), model_id)

use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"
max_batch_size = int(os.getenv("MAX_BATCH_SIZE", "4096"))
max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "256"))
mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.90"))

if use_8bit:
quantization = "int8"
logger.info("Using 8-bit quantization")
self.tokenizer, self.model = load_model_8bit(model_id, **kwargs)
else:
logger.info("Using fp16/bf16 precision")
self.tokenizer, self.model = load_model_fp16(model_id, **kwargs)

logger.info(
f"Model loaded and distributed. Device map: {self.model.hf_device_map}"
quantization = "float16" # Default to FP16
logger.info("Using default FP16 precision")

# Get available GPU memory
gpu_memory = get_gpu_memory()
logger.info(f"Available GPU memory: {gpu_memory}")

# Initialize vLLM with more specific parameters
self.llm = LLM(
model=self.local_model_path,
quantization=quantization,
trust_remote_code=True,
dtype="float16",
tensor_parallel_size=len(gpu_memory), # Use all available GPUs
max_num_batched_tokens=max_batch_size, # Adjust based on your needs
max_num_seqs=max_num_seqs, # Adjust based on your needs
gpu_memory_utilization=mem_utilization, # Adjust GPU memory utilization
)

# Set up generation config
self.generation_config = self.model.generation_config

self.terminators = [
self.tokenizer.eos_token_id,
self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# Optional: Add optimizations
sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
if sfast_enabled:
logger.info(
"LLMPipeline will be dynamically compiled with stable-fast for %s",
model_id,
)
from app.pipelines.optim.sfast import compile_model
self.model = compile_model(self.model)
logger.info(f"Model loaded: {self.model_id}")
logger.info(f"Using tensor parallelism across {len(gpu_memory)} GPUs")

async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
conversation = []
@@ -149,53 +54,32 @@ async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, sys
conversation.extend(history)
conversation.append({"role": "user", "content": prompt})

input_ids = self.tokenizer.apply_chat_template(
conversation, return_tensors="pt").to(self.model.device)
attention_mask = torch.ones_like(input_ids)
# Apply chat template
full_prompt = self.llm.get_tokenizer().apply_chat_template(conversation, tokenize=False)

max_new_tokens = kwargs.get("max_tokens", 256)
max_tokens = kwargs.get("max_tokens", 256)
temperature = kwargs.get("temperature", 0.7)

streamer = TextIteratorStreamer(
self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

generate_kwargs = self.generation_config.to_dict()
generate_kwargs.update({
"input_ids": input_ids,
"attention_mask": attention_mask,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"do_sample": temperature > 0,
"temperature": temperature,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.eos_token_id,
})
sampling_params = SamplingParams(
temperature=temperature,
max_tokens=max_tokens,
top_p=kwargs.get("top_p", 1.0),
top_k=kwargs.get("top_k", -1),
)

thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs)
thread.start()
request_id = 0
request = InferenceRequest(request_id, full_prompt, sampling_params)

total_tokens = 0
try:
for text in streamer:
total_tokens += 1
yield text
async for output in self.llm.generate_stream(request):
if output.outputs:
generated_text = output.outputs[0].text
total_tokens += len(generated_text)
yield generated_text
await asyncio.sleep(0) # Allow other tasks to run
except Exception as e:
logger.error(f"Error during streaming: {str(e)}")
raise

input_length = input_ids.size(1)
input_length = len(self.llm.get_tokenizer().encode(full_prompt))
yield {"tokens_used": input_length + total_tokens}

def model_generate_wrapper(self, **kwargs):
try:
logger.debug("Entering model.generate")
with torch.cuda.amp.autocast(): # Use automatic mixed precision
self.model.generate(**kwargs)
logger.debug("Exiting model.generate")
except Exception as e:
logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
raise

def __str__(self):
return f"LLMPipeline(model_id={self.model_id})"
6 changes: 5 additions & 1 deletion runner/app/routes/llm.py
@@ -38,6 +38,8 @@ async def llm(
system_msg: Annotated[str, Form()] = "",
temperature: Annotated[float, Form()] = 0.7,
max_tokens: Annotated[int, Form()] = 256,
top_p: Annotated[float, Form()] = 1.0,
top_k: Annotated[int, Form()] = -1,
history: Annotated[str, Form()] = "[]", # We'll parse this as JSON
stream: Annotated[bool, Form()] = False,
pipeline: Pipeline = Depends(get_pipeline),
@@ -71,7 +73,9 @@ async def llm(
history=history_list,
system_msg=system_msg if system_msg else None,
temperature=temperature,
max_tokens=max_tokens
max_tokens=max_tokens,
top_p=top_p,
top_k=top_k
)

if stream:
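The updated route accepts top_p and top_k form fields alongside temperature and max_tokens. A hedged client-side sketch of exercising them; the endpoint address and the prompt field name sit outside the visible hunk and are assumptions:

# Hypothetical client call against the runner's LLM route; the URL and the
# "prompt" field are assumptions, the remaining field names come from the diff.
import requests

resp = requests.post(
    "http://localhost:8000/llm",
    data={
        "prompt": "Explain tensor parallelism in one paragraph.",
        "system_msg": "You are a concise assistant.",
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 0.9,
        "top_k": 40,
        "history": "[]",
        "stream": "false",
    },
)
print(resp.json())
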
1 change: 1 addition & 0 deletions runner/requirements.txt
@@ -19,3 +19,4 @@ sentencepiece== 0.2.0
protobuf==5.27.2
bitsandbytes==0.43.3
psutil==6.0.0
vllm==0.6.3
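
The runner now pins vLLM in its requirements. A short, hedged post-install check that the environment actually resolves to the pinned release (not part of the commit):

# Hypothetical sanity check that the installed vLLM matches the pinned version.
import vllm

assert vllm.__version__ == "0.6.3", f"unexpected vLLM version: {vllm.__version__}"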
