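"""Serve Mixtral-8x7B-Instruct with vLLM on Modal, streaming completions back token by token."""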
import os
import time

from modal import Image, Stub, gpu, enter, method, exit

MODEL_DIR = "/model"
BASE_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
GPU_CONFIG = gpu.A100(memory=80, count=2)
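

# Download the model weights at image-build time (see .run_function below) so they
# are baked into the container image instead of being fetched on every cold start.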
def download_model_to_folder():
    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(
        BASE_MODEL,
        local_dir=MODEL_DIR,
        ignore_patterns=["*.pt"],  # Using safetensors
    )
    move_cache()
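

# Container image: CUDA 12.1 base with Python 3.10, a pinned vLLM stack,
# and the model weights pre-downloaded via the build step above.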
vllm_image = (
    Image.from_registry(
        "nvidia/cuda:12.1.1-devel-ubuntu22.04", add_python="3.10"
    )
    .pip_install(
        "vllm==0.3.2",
        "huggingface_hub==0.19.4",
        "hf-transfer==0.1.4",
        "torch==2.1.2",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(download_model_to_folder, timeout=60 * 20)
)

stub = Stub(name="secondbrain")
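

# Each Model container loads the engine once, stays warm for 10 minutes of idleness,
# and can serve up to 10 requests concurrently.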
@stub.cls(
    gpu=GPU_CONFIG,
    timeout=60 * 10,
    container_idle_timeout=60 * 10,
    allow_concurrent_inputs=10,
    image=vllm_image,
)
class Model:
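    # Runs once per container start, before any requests are handled.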
    @enter()
    def start_engine(self):
        import time

        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

        print("🥶 cold starting inference")
        start = time.monotonic_ns()

        if GPU_CONFIG.count > 1:
            # Patch issue from https://github.com/vllm-project/vllm/issues/1116
            import ray

            ray.shutdown()
            ray.init(num_gpus=GPU_CONFIG.count)

        engine_args = AsyncEngineArgs(
            model=MODEL_DIR,
            tensor_parallel_size=GPU_CONFIG.count,
            gpu_memory_utilization=0.90,
            enforce_eager=False,  # capture the graph for faster inference, but slower cold starts
            disable_log_stats=True,  # disable logging so we can stream tokens
            disable_log_requests=True,
        )

        # Mixtral instruction-tuned prompt format.
        self.template = "<s> [INST] {user} [/INST] "

        # this can take some time!
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        duration_s = (time.monotonic_ns() - start) / 1e9
        print(f"🏎️ engine started in {duration_s:.0f}s")
    @method()
    async def completion_stream(self, user_question):
        from vllm import SamplingParams
        from vllm.utils import random_uuid

        sampling_params = SamplingParams(
            temperature=0.75,
            max_tokens=1024,
            repetition_penalty=1.1,
        )

        request_id = random_uuid()
        result_generator = self.engine.generate(
            self.template.format(user=user_question),
            sampling_params,
            request_id,
        )
        index = 0
        async for output in result_generator:
            # Skip partial outputs that end in the Unicode replacement character:
            # the detokenizer has not yet produced a complete multi-byte character.
            if (
                output.outputs[0].text
                and "\ufffd" == output.outputs[0].text[-1]
            ):
                continue
            text_delta = output.outputs[0].text[index:]
            index = len(output.outputs[0].text)
            yield text_delta
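
    # Tear down Ray when the container exits; only needed when tensor parallelism spans multiple GPUs.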
    @exit()
    def stop_engine(self, *args, **kwargs):
        if GPU_CONFIG.count > 1:
            import ray

            ray.shutdown()
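

# `modal run llm.py` invokes this entrypoint and streams the model's answer to stdout.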
@stub.local_entrypoint()
def main():
    model = Model()
    questions = [
        "What is the square root of 49?",
    ]
    for question in questions:
        print("Sending new request:", question, "\n\n")
        for text in model.completion_stream.remote_gen(question):
            print(text, end="", flush=True)