
Commit 1ede01e

feat: better LLM response format (#387)
* feat: use uuid for vLLM request generation
* feat: better LLM response format

1 parent: 94e1054

8 files changed: +767 -202 lines changed

runner/app/pipelines/llm.py

Lines changed: 43 additions & 24 deletions
@@ -2,6 +2,7 @@
 import logging
 import os
 import time
+import uuid
 from dataclasses import dataclass
 from typing import Dict, Any, List, AsyncGenerator, Union, Optional
 
@@ -10,6 +11,7 @@
 from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
 from huggingface_hub import file_download
 from transformers import AutoConfig
+from app.routes.utils import LLMResponse, LLMChoice, LLMMessage, LLMTokenUsage
 
 logger = logging.getLogger(__name__)
 
@@ -194,7 +196,7 @@ async def generate(
             frequency_penalty=config.frequency_penalty,
         )
 
-        request_id = f"chatcmpl-{int(time.time())}"
+        request_id = f"chatcmpl-{uuid.uuid4()}"
 
         results_generator = self.engine.generate(
             prompt=full_prompt, sampling_params=sampling_params, request_id=request_id)
@@ -219,15 +221,25 @@ async def generate(
                 current_response = generated_text
                 total_tokens += len(tokenizer.encode(delta))
 
-                yield {
-                    "choices": [{
-                        "delta": {"content": delta},
-                        "finish_reason": None
-                    }],
-                    "created": int(time.time()),
-                    "model": self.model_id,
-                    "id": request_id
-                }
+                yield LLMResponse(
+                    choices=[
+                        LLMChoice(
+                            delta=LLMMessage(
+                                role="assistant",
+                                content=delta
+                            ),
+                            index=0
+                        )
+                    ],
+                    tokens_used=LLMTokenUsage(
+                        prompt_tokens=input_tokens,
+                        completion_tokens=total_tokens,
+                        total_tokens=input_tokens + total_tokens
+                    ),
+                    id=request_id,
+                    model=self.model_id,
+                    created=int(time.time())
+                )
 
                 await asyncio.sleep(0)
 
@@ -242,20 +254,27 @@ async def generate(
             logger.info(f" Generated tokens: {total_tokens}")
             generation_time = end_time - first_token_time if first_token_time else 0
             logger.info(f" Tokens per second: {total_tokens / generation_time:.2f}")
-            yield {
-                "choices": [{
-                    "delta": {"content": ""},
-                    "finish_reason": "stop"
-                }],
-                "created": int(time.time()),
-                "model": self.model_id,
-                "id": request_id,
-                "usage": {
-                    "prompt_tokens": input_tokens,
-                    "completion_tokens": total_tokens,
-                    "total_tokens": input_tokens + total_tokens
-                }
-            }
+
+            yield LLMResponse(
+                choices=[
+                    LLMChoice(
+                        delta=LLMMessage(
+                            role="assistant",
+                            content=""
+                        ),
+                        index=0,
+                        finish_reason="stop"
+                    )
+                ],
+                tokens_used=LLMTokenUsage(
+                    prompt_tokens=input_tokens,
+                    completion_tokens=total_tokens,
+                    total_tokens=input_tokens + total_tokens
+                ),
+                id=request_id,
+                model=self.model_id,
+                created=int(time.time())
+            )
 
         except Exception as e:
             if "CUDA out of memory" in str(e):

runner/app/routes/llm.py

Lines changed: 35 additions & 27 deletions
@@ -1,16 +1,15 @@
 import logging
 import os
-import time
+from typing import Union
 from fastapi import APIRouter, Depends, status
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from app.dependencies import get_pipeline
 from app.pipelines.base import Pipeline
-from app.routes.utils import HTTPError, LLMRequest, LLMResponse, http_error
+from app.routes.utils import HTTPError, LLMRequest, LLMChoice, LLMMessage, LLMResponse, http_error
 import json
 
 router = APIRouter()
-
 logger = logging.getLogger(__name__)
 
 RESPONSES = {
@@ -20,23 +19,24 @@
     status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
 }
 
-
 @router.post(
     "/llm",
-    response_model=LLMResponse,
+    response_model=LLMResponse
+    ,
     responses=RESPONSES,
     operation_id="genLLM",
     description="Generate text using a language model.",
     summary="LLM",
     tags=["generate"],
     openapi_extra={"x-speakeasy-name-override": "llm"},
 )
-@router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False)
+@router.post("/llm/", response_model=LLMResponse
+             , responses=RESPONSES, include_in_schema=False)
 async def llm(
     request: LLMRequest,
     pipeline: Pipeline = Depends(get_pipeline),
     token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
-):
+) -> Union[LLMResponse, JSONResponse, StreamingResponse]:
     auth_token = os.environ.get("AUTH_TOKEN")
     if auth_token:
         if not token or token.credentials != auth_token:
@@ -71,24 +71,31 @@ async def llm(
         else:
             full_response = ""
             last_chunk = None
-
             async for chunk in generator:
-                if isinstance(chunk, dict):
-                    if "choices" in chunk:
-                        if "delta" in chunk["choices"][0]:
-                            full_response += chunk["choices"][0]["delta"].get(
-                                "content", "")
-                last_chunk = chunk
+                if chunk.choices and chunk.choices[0].delta.content:
+                    full_response += chunk.choices[0].delta.content
+                last_chunk = chunk
 
-            usage = last_chunk.get("usage", {})
-
-            return LLMResponse(
-                response=full_response,
-                tokens_used=usage.get("total_tokens", 0),
-                id=last_chunk.get("id", ""),
-                model=last_chunk.get("model", pipeline.model_id),
-                created=last_chunk.get("created", int(time.time()))
-            )
+            if last_chunk:
+                # Return the final response with accumulated text
+                return LLMResponse(
+                    choices=[
+                        LLMChoice(
+                            message=LLMMessage(
+                                role="assistant",
+                                content=full_response
+                            ),
+                            index=0,
+                            finish_reason="stop"
+                        )
+                    ],
+                    tokens_used=last_chunk.tokens_used,
+                    id=last_chunk.id,
+                    model=last_chunk.model,
+                    created=last_chunk.created
+                )
+            else:
+                raise ValueError("No response generated")
 
     except Exception as e:
         logger.error(f"LLM processing error: {str(e)}")
@@ -101,12 +108,13 @@ async def llm(
 async def stream_generator(generator):
     try:
         async for chunk in generator:
-            if isinstance(chunk, dict):
-                if "choices" in chunk:
+            if isinstance(chunk, LLMResponse):
+                if len(chunk.choices) > 0:
                     # Regular streaming chunk or final chunk
-                    yield f"data: {json.dumps(chunk)}\n\n"
-                    if chunk["choices"][0].get("finish_reason") == "stop":
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+                    if chunk.choices[0].finish_reason == "stop":
                         break
+        # Signal end of stream
        yield "data: [DONE]\n\n"
     except Exception as e:
         logger.error(f"Streaming error: {str(e)}")

runner/app/routes/utils.py

Lines changed: 35 additions & 8 deletions
@@ -77,6 +77,41 @@ class LLMMessage(BaseModel):
     content: str
 
 
+class LLMBaseChoice(BaseModel):
+    index: int
+    finish_reason: str = ""  # Needs OpenAPI 3.1 support to make optional
+
+
+class LLMTokenUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+class LLMChoice(LLMBaseChoice):
+    delta: LLMMessage = None
+    message: LLMMessage = None
+
+class LLMResponse(BaseModel):
+    id: str
+    model: str
+    created: int
+    tokens_used: LLMTokenUsage
+    choices: List[LLMChoice]
+
+
+# class LLMStreamChoice(LLMBaseChoice):
+#     delta: LLMMessage
+
+# class LLMNonStreamChoice(LLMBaseChoice):
+#     message: LLMMessage
+
+# class LLMStreamResponse(LLMBaseResponse):
+#     choices: List[LLMStreamChoice]
+
+# class LLMNonStreamResponse(LLMBaseResponse):
+#     choices: List[LLMNonStreamChoice]
+
+
 class LLMRequest(BaseModel):
     messages: List[LLMMessage]
     model: str = ""
@@ -87,14 +122,6 @@ class LLMRequest(BaseModel):
     stream: bool = False
 
 
-class LLMResponse(BaseModel):
-    response: str
-    tokens_used: int
-    id: str
-    model: str
-    created: int
-
-
 class ImageToTextResponse(BaseModel):
     """Response model for text generation."""
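A small sketch of how the single LLMChoice model covers both shapes: streaming chunks populate delta, while the aggregated non-streaming response populates message and sets finish_reason. The import path assumes the runner package layout; the content strings are illustrative.

```python
from app.routes.utils import LLMChoice, LLMMessage

# Streaming chunk: only `delta` is populated, finish_reason stays "".
stream_choice = LLMChoice(index=0, delta=LLMMessage(role="assistant", content="Hel"))

# Final / non-streaming result: only `message` is populated.
final_choice = LLMChoice(
    index=0,
    message=LLMMessage(role="assistant", content="Hello!"),
    finish_reason="stop",
)

print(stream_choice.model_dump_json())
print(final_choice.model_dump_json())
```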

runner/gateway.openapi.yaml

Lines changed: 45 additions & 12 deletions
@@ -525,8 +525,7 @@ components:
     AudioResponse:
       properties:
         audio:
-          allOf:
-          - $ref: '#/components/schemas/MediaURL'
+          $ref: '#/components/schemas/MediaURL'
           description: The generated audio.
       type: object
       required:
@@ -827,8 +826,7 @@ components:
     HTTPError:
       properties:
         detail:
-          allOf:
-          - $ref: '#/components/schemas/APIError'
+          $ref: '#/components/schemas/APIError'
           description: Detailed error information.
       type: object
       required:
@@ -868,6 +866,23 @@ components:
       - text
       title: ImageToTextResponse
       description: Response model for text generation.
+    LLMChoice:
+      properties:
+        index:
+          type: integer
+          title: Index
+        finish_reason:
+          type: string
+          title: Finish Reason
+          default: ''
+        delta:
+          $ref: '#/components/schemas/LLMMessage'
+        message:
+          $ref: '#/components/schemas/LLMMessage'
+      type: object
+      required:
+      - index
+      title: LLMChoice
     LLMMessage:
       properties:
         role:
@@ -918,12 +933,6 @@ components:
       title: LLMRequest
     LLMResponse:
       properties:
-        response:
-          type: string
-          title: Response
-        tokens_used:
-          type: integer
-          title: Tokens Used
         id:
           type: string
           title: Id
@@ -933,14 +942,38 @@ components:
         created:
           type: integer
           title: Created
+        tokens_used:
+          $ref: '#/components/schemas/LLMTokenUsage'
+        choices:
+          items:
+            $ref: '#/components/schemas/LLMChoice'
+          type: array
+          title: Choices
       type: object
       required:
-      - response
-      - tokens_used
       - id
       - model
       - created
+      - tokens_used
+      - choices
       title: LLMResponse
+    LLMTokenUsage:
+      properties:
+        prompt_tokens:
+          type: integer
+          title: Prompt Tokens
+        completion_tokens:
+          type: integer
+          title: Completion Tokens
+        total_tokens:
+          type: integer
+          title: Total Tokens
+      type: object
+      required:
+      - prompt_tokens
+      - completion_tokens
+      - total_tokens
+      title: LLMTokenUsage
     LiveVideoToVideoParams:
       properties:
         subscribe_url:
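As a rough illustration of the updated LLMResponse schema from a client's point of view, a non-streaming request could look like the sketch below. The host, port, and prompt are assumptions, and an Authorization header would be needed when AUTH_TOKEN is configured.

```python
import httpx

payload = {
    "messages": [{"role": "user", "content": "Say hello"}],
    "stream": False,
}
# Assumed local runner address; adjust for your deployment.
resp = httpx.post("http://localhost:8000/llm", json=payload, timeout=None)
resp.raise_for_status()
body = resp.json()

# Fields follow the updated LLMResponse schema.
print(body["choices"][0]["message"]["content"])
print(body["tokens_used"]["total_tokens"])
```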
