feat: better LLM response format (#387)
* feat: use uuid for vLLM request generation

* feat: better LLM response format
kyriediculous authored Jan 14, 2025
1 parent 94e1054 commit 1ede01e
Showing 8 changed files with 767 additions and 202 deletions.
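In short, every chunk the pipeline yields, and every non-streaming response, is now a typed LLMResponse instead of an ad-hoc dict. Below is a minimal sketch of the new streaming chunk shape, assuming the models added in runner/app/routes/utils.py further down; all field values here are illustrative:

from app.routes.utils import LLMResponse, LLMChoice, LLMMessage, LLMTokenUsage

chunk = LLMResponse(
    id="chatcmpl-8f3b2c1e-0d4a-4b6f-9a2e-5c7d1e9f0a3b",  # request ids are now UUID-based
    model="example/model-id",                            # illustrative model id
    created=1736860000,
    tokens_used=LLMTokenUsage(prompt_tokens=42, completion_tokens=7, total_tokens=49),
    choices=[
        LLMChoice(index=0, delta=LLMMessage(role="assistant", content="Hello")),
    ],
)
# stream_generator() in runner/app/routes/llm.py emits each chunk as a server-sent event:
print(f"data: {chunk.model_dump_json()}\n\n")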
67 changes: 43 additions & 24 deletions runner/app/pipelines/llm.py
@@ -2,6 +2,7 @@
import logging
import os
import time
import uuid
from dataclasses import dataclass
from typing import Dict, Any, List, AsyncGenerator, Union, Optional

@@ -10,6 +11,7 @@
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from huggingface_hub import file_download
from transformers import AutoConfig
from app.routes.utils import LLMResponse, LLMChoice, LLMMessage, LLMTokenUsage

logger = logging.getLogger(__name__)

@@ -194,7 +196,7 @@ async def generate(
frequency_penalty=config.frequency_penalty,
)

request_id = f"chatcmpl-{int(time.time())}"
request_id = f"chatcmpl-{uuid.uuid4()}"

results_generator = self.engine.generate(
prompt=full_prompt, sampling_params=sampling_params, request_id=request_id)
@@ -219,15 +221,25 @@ async def generate(
current_response = generated_text
total_tokens += len(tokenizer.encode(delta))

yield {
"choices": [{
"delta": {"content": delta},
"finish_reason": None
}],
"created": int(time.time()),
"model": self.model_id,
"id": request_id
}
yield LLMResponse(
choices=[
LLMChoice(
delta=LLMMessage(
role="assistant",
content=delta
),
index=0
)
],
tokens_used=LLMTokenUsage(
prompt_tokens=input_tokens,
completion_tokens=total_tokens,
total_tokens=input_tokens + total_tokens
),
id=request_id,
model=self.model_id,
created=int(time.time())
)

await asyncio.sleep(0)

@@ -242,20 +254,27 @@ async def generate(
logger.info(f" Generated tokens: {total_tokens}")
generation_time = end_time - first_token_time if first_token_time else 0
logger.info(f" Tokens per second: {total_tokens / generation_time:.2f}")
yield {
"choices": [{
"delta": {"content": ""},
"finish_reason": "stop"
}],
"created": int(time.time()),
"model": self.model_id,
"id": request_id,
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": total_tokens,
"total_tokens": input_tokens + total_tokens
}
}

yield LLMResponse(
choices=[
LLMChoice(
delta=LLMMessage(
role="assistant",
content=""
),
index=0,
finish_reason="stop"
)
],
tokens_used=LLMTokenUsage(
prompt_tokens=input_tokens,
completion_tokens=total_tokens,
total_tokens=input_tokens + total_tokens
),
id=request_id,
model=self.model_id,
created=int(time.time())
)

except Exception as e:
if "CUDA out of memory" in str(e):
62 changes: 35 additions & 27 deletions runner/app/routes/llm.py
@@ -1,16 +1,15 @@
import logging
import os
import time
from typing import Union
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import HTTPError, LLMRequest, LLMResponse, http_error
from app.routes.utils import HTTPError, LLMRequest, LLMChoice, LLMMessage, LLMResponse, http_error
import json

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
@@ -20,23 +19,24 @@
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post(
"/llm",
response_model=LLMResponse,
response_model=LLMResponse
,
responses=RESPONSES,
operation_id="genLLM",
description="Generate text using a language model.",
summary="LLM",
tags=["generate"],
openapi_extra={"x-speakeasy-name-override": "llm"},
)
@router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False)
@router.post("/llm/", response_model=LLMResponse
, responses=RESPONSES, include_in_schema=False)
async def llm(
request: LLMRequest,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
) -> Union[LLMResponse, JSONResponse, StreamingResponse]:
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
@@ -71,24 +71,31 @@ async def llm(
else:
full_response = ""
last_chunk = None

async for chunk in generator:
if isinstance(chunk, dict):
if "choices" in chunk:
if "delta" in chunk["choices"][0]:
full_response += chunk["choices"][0]["delta"].get(
"content", "")
last_chunk = chunk
if chunk.choices and chunk.choices[0].delta.content:
full_response += chunk.choices[0].delta.content
last_chunk = chunk

usage = last_chunk.get("usage", {})

return LLMResponse(
response=full_response,
tokens_used=usage.get("total_tokens", 0),
id=last_chunk.get("id", ""),
model=last_chunk.get("model", pipeline.model_id),
created=last_chunk.get("created", int(time.time()))
)
if last_chunk:
# Return the final response with accumulated text
return LLMResponse(
choices=[
LLMChoice(
message=LLMMessage(
role="assistant",
content=full_response
),
index=0,
finish_reason="stop"
)
],
tokens_used=last_chunk.tokens_used,
id=last_chunk.id,
model=last_chunk.model,
created=last_chunk.created
)
else:
raise ValueError("No response generated")

except Exception as e:
logger.error(f"LLM processing error: {str(e)}")
@@ -101,12 +108,13 @@ async def llm(
async def stream_generator(generator):
try:
async for chunk in generator:
if isinstance(chunk, dict):
if "choices" in chunk:
if isinstance(chunk, LLMResponse):
if len(chunk.choices) > 0:
# Regular streaming chunk or final chunk
yield f"data: {json.dumps(chunk)}\n\n"
if chunk["choices"][0].get("finish_reason") == "stop":
yield f"data: {chunk.model_dump_json()}\n\n"
if chunk.choices[0].finish_reason == "stop":
break
# Signal end of stream
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Streaming error: {str(e)}")
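For reference, a client-side sketch of consuming the streaming endpoint above. httpx, the localhost base URL, and the payload values are assumptions for illustration, not part of this commit; if AUTH_TOKEN is set on the server, an Authorization: Bearer header is also required:

import json
import httpx  # assumed HTTP client; any client with streaming support works

async def stream_llm(prompt: str, base_url: str = "http://localhost:8000"):
    # Post an LLMRequest with stream=True and read the SSE lines emitted by stream_generator().
    payload = {"messages": [{"role": "user", "content": prompt}], "stream": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{base_url}/llm", json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":  # end-of-stream sentinel
                    break
                chunk = json.loads(data)  # a serialized LLMResponse
                delta = chunk["choices"][0].get("delta") or {}
                print(delta.get("content", ""), end="", flush=True)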
43 changes: 35 additions & 8 deletions runner/app/routes/utils.py
@@ -77,6 +77,41 @@ class LLMMessage(BaseModel):
content: str


class LLMBaseChoice(BaseModel):
index: int
finish_reason: str = "" # Needs OpenAPI 3.1 support to make optional


class LLMTokenUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int

class LLMChoice(LLMBaseChoice):
delta: LLMMessage = None
message: LLMMessage = None

class LLMResponse(BaseModel):
id: str
model: str
created: int
tokens_used: LLMTokenUsage
choices: List[LLMChoice]


# class LLMStreamChoice(LLMBaseChoice):
# delta: LLMMessage

# class LLMNonStreamChoice(LLMBaseChoice):
# message: LLMMessage

# class LLMStreamResponse(LLMBaseResponse):
# choices: List[LLMStreamChoice]

# class LLMNonStreamResponse(LLMBaseResponse):
# choices: List[LLMNonStreamChoice]


class LLMRequest(BaseModel):
messages: List[LLMMessage]
model: str = ""
@@ -87,14 +122,6 @@ class LLMRequest(BaseModel):
stream: bool = False


class LLMResponse(BaseModel):
response: str
tokens_used: int
id: str
model: str
created: int


class ImageToTextResponse(BaseModel):
"""Response model for text generation."""

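A companion sketch of the aggregated, non-streaming body the route builds from these models once generation finishes: message instead of delta, finish_reason set to "stop", and the summed token usage. Values are illustrative:

from app.routes.utils import LLMResponse, LLMChoice, LLMMessage, LLMTokenUsage

response = LLMResponse(
    id="chatcmpl-8f3b2c1e-0d4a-4b6f-9a2e-5c7d1e9f0a3b",
    model="example/model-id",
    created=1736860000,
    tokens_used=LLMTokenUsage(prompt_tokens=42, completion_tokens=128, total_tokens=170),
    choices=[
        LLMChoice(
            index=0,
            finish_reason="stop",
            message=LLMMessage(role="assistant", content="<full accumulated text>"),
        ),
    ],
)
print(response.model_dump_json(indent=2))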
57 changes: 45 additions & 12 deletions runner/gateway.openapi.yaml
@@ -525,8 +525,7 @@ components:
AudioResponse:
properties:
audio:
allOf:
- $ref: '#/components/schemas/MediaURL'
$ref: '#/components/schemas/MediaURL'
description: The generated audio.
type: object
required:
@@ -827,8 +826,7 @@ components:
HTTPError:
properties:
detail:
allOf:
- $ref: '#/components/schemas/APIError'
$ref: '#/components/schemas/APIError'
description: Detailed error information.
type: object
required:
@@ -868,6 +866,23 @@ components:
- text
title: ImageToTextResponse
description: Response model for text generation.
LLMChoice:
properties:
index:
type: integer
title: Index
finish_reason:
type: string
title: Finish Reason
default: ''
delta:
$ref: '#/components/schemas/LLMMessage'
message:
$ref: '#/components/schemas/LLMMessage'
type: object
required:
- index
title: LLMChoice
LLMMessage:
properties:
role:
@@ -918,12 +933,38 @@ components:
title: LLMRequest
LLMResponse:
properties:
response:
type: string
title: Response
tokens_used:
type: integer
title: Tokens Used
id:
type: string
title: Id
@@ -933,14 +942,38 @@
created:
type: integer
title: Created
tokens_used:
$ref: '#/components/schemas/LLMTokenUsage'
choices:
items:
$ref: '#/components/schemas/LLMChoice'
type: array
title: Choices
type: object
required:
- response
- tokens_used
- id
- model
- created
- tokens_used
- choices
title: LLMResponse
LLMTokenUsage:
properties:
prompt_tokens:
type: integer
title: Prompt Tokens
completion_tokens:
type: integer
title: Completion Tokens
total_tokens:
type: integer
title: Total Tokens
type: object
required:
- prompt_tokens
- completion_tokens
- total_tokens
title: LLMTokenUsage
LiveVideoToVideoParams:
properties:
subscribe_url: