
Commit 5ad803b

Create openai_api.py
add openai_api for qwen_cpp
1 parent 9648bf6 commit 5ad803b

File tree

1 file changed: +185 −0 lines changed


qwen_cpp/openai_api.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
import asyncio
import logging
import time
from typing import List, Literal, Optional, Union

import qwen_cpp
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from sse_starlette.sse import EventSourceResponse

logging.basicConfig(level=logging.INFO, format=r"%(asctime)s - %(module)s - %(levelname)s - %(message)s")


class Settings(BaseSettings):
    model: str = "qwen14b-ggml.bin"
    tiktoken: str = "Qwen-14B-Chat/qwen.tiktoken"
    num_threads: int = 0


class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["system", "user", "assistant"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str = "default-model"
    messages: List[ChatMessage]
    temperature: float = Field(default=0.95, ge=0.0, le=2.0)
    top_p: float = Field(default=0.7, ge=0.0, le=1.0)
    stream: bool = False
    max_tokens: int = Field(default=2048, ge=0)

    model_config = {
        "json_schema_extra": {"examples": [{"model": "default-model", "messages": [{"role": "user", "content": "你好"}]}]}
    }


class ChatCompletionResponseChoice(BaseModel):
    index: int = 0
    message: ChatMessage
    finish_reason: Literal["stop", "length"] = "stop"


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int = 0
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = "chatcmpl"
    model: str = "default-model"
    object: Literal["chat.completion", "chat.completion.chunk"]
    created: int = Field(default_factory=lambda: int(time.time()))
    choices: Union[List[ChatCompletionResponseChoice], List[ChatCompletionResponseStreamChoice]]

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "id": "chatcmpl",
                    "model": "default-model",
                    "object": "chat.completion",
                    "created": 1691166146,
                    "choices": [
                        {
                            "index": 0,
                            "message": {"role": "assistant", "content": "你好!"},
                            "finish_reason": "stop",
                        }
                    ],
                }
            ]
        }
    }


settings = Settings()
app = FastAPI()
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)
pipeline = qwen_cpp.Pipeline(settings.model, settings.tiktoken)
lock = asyncio.Lock()


def stream_chat(history, body):
    yield ChatCompletionResponse(
        object="chat.completion.chunk",
        choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(role="assistant"))],
    )

    for piece in pipeline.chat(
        history,
        max_length=body.max_tokens,
        do_sample=body.temperature > 0,
        top_p=body.top_p,
        temperature=body.temperature,
        num_threads=settings.num_threads,
        stream=True,
    ):
        yield ChatCompletionResponse(
            object="chat.completion.chunk",
            choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(content=piece))],
        )

    yield ChatCompletionResponse(
        object="chat.completion.chunk",
        choices=[ChatCompletionResponseStreamChoice(delta=DeltaMessage(), finish_reason="stop")],
    )


async def stream_chat_event_publisher(history, body):
    output = ""
    try:
        async with lock:
            for chunk in stream_chat(history, body):
                await asyncio.sleep(0)  # yield control back to event loop for cancellation check
                output += chunk.choices[0].delta.content or ""
                yield chunk.model_dump_json(exclude_unset=True)
        logging.info(f'prompt: "{history[-1]}", stream response: "{output}"')
    except asyncio.CancelledError as e:
        logging.info(f'prompt: "{history[-1]}", stream response (partial): "{output}"')
        raise e


@app.post("/v1/chat/completions")
async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
    # ignore system messages
    history = [msg.content for msg in body.messages if msg.role != "system"]
    if len(history) % 2 != 1:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "invalid history size")

    if body.stream:
        generator = stream_chat_event_publisher(history, body)
        return EventSourceResponse(generator)

    output = pipeline.chat(
        history=history,
        max_length=body.max_tokens,
        do_sample=body.temperature > 0,
        top_p=body.top_p,
        temperature=body.temperature,
    )
    logging.info(f'prompt: "{history[-1]}", sync response: "{output}"')

    return ChatCompletionResponse(
        object="chat.completion",
        choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output))],
    )


class ModelCard(BaseModel):
    id: str
    object: Literal["model"] = "model"
    owned_by: str = "owner"
    permission: List = []


class ModelList(BaseModel):
    object: Literal["list"] = "list"
    data: List[ModelCard] = []

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "object": "list",
                    "data": [{"id": "gpt-3.5-turbo", "object": "model", "owned_by": "owner", "permission": []}],
                }
            ]
        }
    }


@app.get("/v1/models")
async def list_models() -> ModelList:
    return ModelList(data=[ModelCard(id="gpt-3.5-turbo")])
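
The module builds everything at import time (Settings(), the FastAPI app, and the qwen_cpp.Pipeline), so one way to try it locally is to point the settings at real files before importing and then serve the app. The sketch below is not part of the commit: the paths, host, and port are placeholders, and it assumes pydantic-settings' default behaviour of reading the MODEL / TIKTOKEN / NUM_THREADS environment variables for the Settings fields.

# Sketch only: run the server locally (placeholder paths; set the environment
# before importing, since Settings() and the Pipeline are created at import time).
import os

os.environ["MODEL"] = "./qwen14b-ggml.bin"                # placeholder model path
os.environ["TIKTOKEN"] = "./Qwen-14B-Chat/qwen.tiktoken"  # placeholder tokenizer path

import uvicorn
from qwen_cpp.openai_api import app

uvicorn.run(app, host="127.0.0.1", port=8000)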

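With the server up, the non-streaming path of /v1/chat/completions can be exercised with any HTTP client. The sketch below sets only messages and leaves the other ChatCompletionRequest fields at their defaults; it assumes the host/port used above and the requests package.

# Sketch only: one-shot (non-streaming) chat completion request.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",  # assumed host/port
    json={"messages": [{"role": "user", "content": "你好"}]},
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

Note that create_chat_completion drops system messages and requires an odd number of remaining messages, so a single user message is the minimal valid history.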
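When "stream": true is set, the endpoint returns the chunks from stream_chat_event_publisher as server-sent events, one JSON-encoded ChatCompletionResponse per data: line. A rough consumer, again assuming the same host/port and the requests package:

# Sketch only: consume the SSE stream and print assistant deltas as they arrive.
import json
import requests

with requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",  # assumed host/port
    json={"messages": [{"role": "user", "content": "你好"}], "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data:"):
            chunk = json.loads(line[len(b"data:"):])
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)
print()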