Commit 261392f (1 parent: 40fa0c2)
Showing 10 changed files with 668 additions and 243 deletions.
@@ -1,201 +1,142 @@
import asyncio
import logging
import os
import psutil
import time
from typing import Dict, Any, List, Optional, AsyncGenerator, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download, snapshot_download
from threading import Thread
from app.pipelines.utils import get_model_dir, get_max_memory
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from vllm.outputs import RequestOutput
from huggingface_hub import file_download

logger = logging.getLogger(__name__)


def get_max_memory():
    num_gpus = torch.cuda.device_count()
    gpu_memory = {i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
    cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"
    max_memory = {**gpu_memory, "cpu": cpu_memory}

    logger.info(f"Max memory configuration: {max_memory}")
    return max_memory


def load_model_8bit(model_id: str, **kwargs):
    max_memory = get_max_memory()

    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        max_memory=max_memory,
        offload_folder="offload",
        low_cpu_mem_usage=True,
        **kwargs
    )

    return tokenizer, model


def load_model_fp16(model_id: str, **kwargs):
    device = get_torch_device()
    max_memory = get_max_memory()

    # Check for fp16 variant
    local_model_path = os.path.join(
        get_model_dir(), file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
    has_fp16_variant = any(".fp16.safetensors" in fname for _, _,
                           files in os.walk(local_model_path) for fname in files)

    if device != "cpu" and has_fp16_variant:
        logger.info("Loading fp16 variant for %s", model_id)
        kwargs["torch_dtype"] = torch.float16
        kwargs["variant"] = "fp16"
    elif device != "cpu":
        kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    config = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).config

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    checkpoint_dir = snapshot_download(
        model_id, cache_dir=get_model_dir(), local_files_only=True)

    model = load_checkpoint_and_dispatch(
        model,
        checkpoint_dir,
        device_map="auto",
        max_memory=max_memory,
        # Adjust based on your model architecture
        no_split_module_classes=["LlamaDecoderLayer"],
        dtype=kwargs.get("torch_dtype", torch.float32),
        offload_folder="offload",
        offload_state_dict=True,
    )

    return tokenizer, model


class LLMPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {
            "cache_dir": get_model_dir(),
            "local_files_only": True,
        }
        self.device = get_torch_device()

        # Generate the correct folder name
        folder_path = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model")
        self.local_model_path = os.path.join(get_model_dir(), folder_path)
        self.checkpoint_dir = snapshot_download(
            model_id, cache_dir=get_model_dir(), local_files_only=True)

        logger.info(f"Local model path: {self.local_model_path}")
        logger.info(f"Directory contents: {os.listdir(self.local_model_path)}")
        folder_name = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
        base_path = os.path.join(get_model_dir(), folder_name)

        # Find the actual model path
        self.local_model_path = self._find_model_path(base_path)

        if not self.local_model_path:
            raise ValueError(f"Could not find model files for {model_id}")

        use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"
        max_batch_size = int(os.getenv("MAX_BATCH_SIZE", "4096"))
        max_num_seqs = int(os.getenv("MAX_NUM_SEQS", "128"))
        mem_utilization = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.80"))

        # Get available GPU memory
        max_memory = get_max_memory()
        logger.info(f"Available GPU memory: {max_memory.gpu_memory}")

        engine_args = AsyncEngineArgs(
            model=self.local_model_path,
            tokenizer=self.local_model_path,
            trust_remote_code=True,
            dtype="auto",  # Let vLLM pick the dtype from the model config; TODO: check GPU capabilities to set the best type
            kv_cache_dtype="auto",  # or "fp16" if you want to force it
            tensor_parallel_size=max_memory.num_gpus,
            max_num_batched_tokens=max_batch_size,
            gpu_memory_utilization=mem_utilization,
            max_num_seqs=max_num_seqs,
            enforce_eager=False,
            enable_prefix_caching=True,
            seed=42,
            swap_space_bytes=4 * 1024 * 1024 * 1024,  # 4 GiB instead of 8 GiB
        )

        if use_8bit:
            engine_args.quantization = "bitsandbytes"
            logger.info("Using 8-bit quantization")
            self.tokenizer, self.model = load_model_8bit(model_id, **kwargs)
        else:
            logger.info("Using fp16/bf16 precision")
            self.tokenizer, self.model = load_model_fp16(model_id, **kwargs)
            logger.info("Using BFloat16 precision")

        logger.info(
            f"Model loaded and distributed. Device map: {self.model.hf_device_map}"
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        # Set up generation config
        self.generation_config = self.model.generation_config

        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        # Optional: Add optimizations
        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        if sfast_enabled:
            logger.info(
                "LLMPipeline will be dynamically compiled with stable-fast for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model
            self.model = compile_model(self.model)
        logger.info(f"Model loaded: {self.model_id}")
        logger.info(f"Using GPU memory utilization: {mem_utilization}")
        self.engine.start_background_loop()

    async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
        start_time = time.time()

        conversation = []
        if system_msg:
            conversation.append({"role": "system", "content": system_msg})
        if history:
            conversation.extend(history)
            for user_msg, assistant_msg in history:
                conversation.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    conversation.append({"role": "assistant", "content": assistant_msg})
        conversation.append({"role": "user", "content": prompt})

        input_ids = self.tokenizer.apply_chat_template(
            conversation, return_tensors="pt").to(self.model.device)
        attention_mask = torch.ones_like(input_ids)
        tokenizer = await self.engine.get_tokenizer()
        full_prompt = tokenizer.apply_chat_template(conversation, tokenize=False)

        max_new_tokens = kwargs.get("max_tokens", 256)
        temperature = kwargs.get("temperature", 0.7)

        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        sampling_params = SamplingParams(
            temperature=kwargs.get("temperature", 0.7),
            max_tokens=kwargs.get("max_tokens", 256),
            top_p=kwargs.get("top_p", 1.0),
            top_k=kwargs.get("top_k", -1),
        )

        generate_kwargs = self.generation_config.to_dict()
        generate_kwargs.update({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.eos_token_id,
        })
        request_id = str(time.monotonic())
        results_generator = self.engine.generate(prompt=full_prompt, sampling_params=sampling_params, request_id=request_id)

        thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs)
        thread.start()
        generated_tokens = 0
        first_token_time = None
        previous_text = ""

        total_tokens = 0
        try:
            for text in streamer:
                total_tokens += 1
                yield text
            async for request_output in results_generator:
                if first_token_time is None:
                    first_token_time = time.time()

                text = request_output.outputs[0].text
                new_text = text[len(previous_text):]
                generated_tokens += len(tokenizer.encode(new_text))

                yield new_text
                previous_text = text
                await asyncio.sleep(0)  # Allow other tasks to run
        except Exception as e:
            logger.error(f"Error during streaming: {str(e)}")
            logger.error(f"Error during generation: {e}")
            raise

        input_length = input_ids.size(1)
        yield {"tokens_used": input_length + total_tokens}
        end_time = time.time()

    def model_generate_wrapper(self, **kwargs):
        try:
            logger.debug("Entering model.generate")
            with torch.cuda.amp.autocast():  # Use automatic mixed precision
                self.model.generate(**kwargs)
            logger.debug("Exiting model.generate")
        except Exception as e:
            logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
            raise
        # Calculate total tokens and timing
        prompt_tokens = len(tokenizer.encode(full_prompt))
        total_tokens = prompt_tokens + generated_tokens
        total_time = end_time - start_time
        generation_time = end_time - first_token_time if first_token_time else 0

        # Log benchmarking information
        logger.info("Generation completed:")
        logger.info(f"  Total tokens: {total_tokens}")
        logger.info(f"  Prompt tokens: {prompt_tokens}")
        logger.info(f"  Generated tokens: {generated_tokens}")
        logger.info(f"  Total time: {total_time:.2f} seconds")
        if first_token_time:
            logger.info(f"  Time to first token: {(first_token_time - start_time):.2f} seconds")
        logger.info(f"  Generation time: {generation_time:.2f} seconds")
        if generation_time > 0:
            logger.info(f"  Tokens per second: {total_tokens / generation_time:.2f}")

        yield {"tokens_used": total_tokens}

    def __str__(self):
        return f"LLMPipeline(model_id={self.model_id})"

    def _find_model_path(self, base_path):
        # Check if the model files are directly in the base path
        if any(file.endswith('.bin') or file.endswith('.safetensors') for file in os.listdir(base_path)):
            return base_path

        # If not, look in subdirectories
        for root, dirs, files in os.walk(base_path):
            if any(file.endswith('.bin') or file.endswith('.safetensors') for file in files):
                return root

        return None
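
For reference, a minimal usage sketch of the new vLLM-backed pipeline, not part of this commit: the module path and model ID below are assumptions, and the snippet only illustrates how the async generator defined above streams text chunks and then yields a final token-usage dict.

import asyncio

from app.pipelines.llm import LLMPipeline  # module path assumed, not shown in this diff


async def main():
    # Hypothetical model ID; any locally downloaded chat model would do.
    pipeline = LLMPipeline("meta-llama/Meta-Llama-3.1-8B-Instruct")
    async for chunk in pipeline("Explain KV-cache paging in one sentence.",
                                system_msg="You are a concise assistant.",
                                max_tokens=128, temperature=0.7):
        if isinstance(chunk, dict):
            # The final item carries usage metadata, per the generator above.
            print(f"\n[tokens used: {chunk['tokens_used']}]")
        else:
            print(chunk, end="", flush=True)


asyncio.run(main())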
@@ -14,4 +14,5 @@
    is_numeric,
    split_prompt,
    validate_torch_device,
    get_max_memory
)
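
The body of get_max_memory is not shown in this excerpt; only its re-export from app.pipelines.utils is. As a hedged sketch of what the helper might look like, based on the dict-returning version removed from the pipeline module and the max_memory.num_gpus / max_memory.gpu_memory attribute access in the new engine setup, it could return a small container such as:

# Hypothetical sketch only; the actual implementation is not part of this diff.
from dataclasses import dataclass
from typing import Dict

import psutil
import torch


@dataclass
class MaxMemory:
    num_gpus: int
    gpu_memory: Dict[int, str]   # per-GPU totals, e.g. {0: "24GiB"}
    cpu_memory: str              # available system RAM, e.g. "128GiB"


def get_max_memory() -> MaxMemory:
    num_gpus = torch.cuda.device_count()
    gpu_memory = {
        i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB"
        for i in range(num_gpus)
    }
    cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"
    return MaxMemory(num_gpus=num_gpus, gpu_memory=gpu_memory, cpu_memory=cpu_memory)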