feat: better LLM response format #387

Merged 2 commits on Jan 14, 2025
67 changes: 43 additions & 24 deletions runner/app/pipelines/llm.py
@@ -2,6 +2,7 @@
import logging
import os
import time
import uuid
from dataclasses import dataclass
from typing import Dict, Any, List, AsyncGenerator, Union, Optional

@@ -10,6 +11,7 @@
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from huggingface_hub import file_download
from transformers import AutoConfig
from app.routes.utils import LLMResponse, LLMChoice, LLMMessage, LLMTokenUsage

logger = logging.getLogger(__name__)

@@ -194,7 +196,7 @@ async def generate(
frequency_penalty=config.frequency_penalty,
)

request_id = f"chatcmpl-{int(time.time())}"
request_id = f"chatcmpl-{uuid.uuid4()}"

results_generator = self.engine.generate(
prompt=full_prompt, sampling_params=sampling_params, request_id=request_id)
@@ -219,15 +221,25 @@ async def generate(
current_response = generated_text
total_tokens += len(tokenizer.encode(delta))

yield {
"choices": [{
"delta": {"content": delta},
"finish_reason": None
}],
"created": int(time.time()),
"model": self.model_id,
"id": request_id
}
yield LLMResponse(
choices=[
LLMChoice(
delta=LLMMessage(
role="assistant",
content=delta
),
index=0
)
],
tokens_used=LLMTokenUsage(
prompt_tokens=input_tokens,
completion_tokens=total_tokens,
total_tokens=input_tokens + total_tokens
),
id=request_id,
model=self.model_id,
created=int(time.time())
)

await asyncio.sleep(0)

@@ -242,20 +254,27 @@ async def generate(
logger.info(f" Generated tokens: {total_tokens}")
generation_time = end_time - first_token_time if first_token_time else 0
logger.info(f" Tokens per second: {total_tokens / generation_time:.2f}")
yield {
"choices": [{
"delta": {"content": ""},
"finish_reason": "stop"
}],
"created": int(time.time()),
"model": self.model_id,
"id": request_id,
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": total_tokens,
"total_tokens": input_tokens + total_tokens
}
}

yield LLMResponse(
choices=[
LLMChoice(
delta=LLMMessage(
role="assistant",
content=""
),
index=0,
finish_reason="stop"
)
],
tokens_used=LLMTokenUsage(
prompt_tokens=input_tokens,
completion_tokens=total_tokens,
total_tokens=input_tokens + total_tokens
),
id=request_id,
model=self.model_id,
created=int(time.time())
)

except Exception as e:
if "CUDA out of memory" in str(e):
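The pipeline now yields typed `LLMResponse` objects instead of ad-hoc dicts. A minimal consumer sketch, assuming the `LLMResponse`/`LLMChoice`/`LLMTokenUsage` models defined in `app/routes/utils.py` by this PR (the function itself is illustrative, not part of the diff):

```python
# Illustrative consumer of the pipeline's async generator (not part of this PR).
# Each yielded chunk is an LLMResponse; the final chunk has finish_reason == "stop"
# and carries the aggregated token counts in tokens_used.
async def collect_text(generator) -> str:
    text = ""
    async for chunk in generator:
        choice = chunk.choices[0]
        if choice.delta and choice.delta.content:
            text += choice.delta.content
        if choice.finish_reason == "stop":
            print(f"total tokens: {chunk.tokens_used.total_tokens}")
            break
    return text
```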
62 changes: 35 additions & 27 deletions runner/app/routes/llm.py
@@ -1,16 +1,15 @@
import logging
import os
import time
from typing import Union
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import HTTPError, LLMRequest, LLMResponse, http_error
from app.routes.utils import HTTPError, LLMRequest, LLMChoice, LLMMessage, LLMResponse, http_error
import json

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
@@ -20,23 +19,24 @@
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post(
"/llm",
response_model=LLMResponse,
responses=RESPONSES,
operation_id="genLLM",
description="Generate text using a language model.",
summary="LLM",
tags=["generate"],
openapi_extra={"x-speakeasy-name-override": "llm"},
)
@router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False)
@router.post("/llm/", response_model=LLMResponse
, responses=RESPONSES, include_in_schema=False)
async def llm(
request: LLMRequest,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
) -> Union[LLMResponse, JSONResponse, StreamingResponse]:
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
@@ -71,24 +71,31 @@ async def llm(
else:
full_response = ""
last_chunk = None

async for chunk in generator:
if isinstance(chunk, dict):
if "choices" in chunk:
if "delta" in chunk["choices"][0]:
full_response += chunk["choices"][0]["delta"].get(
"content", "")
last_chunk = chunk
if chunk.choices and chunk.choices[0].delta.content:
full_response += chunk.choices[0].delta.content
last_chunk = chunk

usage = last_chunk.get("usage", {})

return LLMResponse(
response=full_response,
tokens_used=usage.get("total_tokens", 0),
id=last_chunk.get("id", ""),
model=last_chunk.get("model", pipeline.model_id),
created=last_chunk.get("created", int(time.time()))
)
if last_chunk:
# Return the final response with accumulated text
return LLMResponse(
choices=[
LLMChoice(
message=LLMMessage(
role="assistant",
content=full_response
),
index=0,
finish_reason="stop"
)
],
tokens_used=last_chunk.tokens_used,
id=last_chunk.id,
model=last_chunk.model,
created=last_chunk.created
)
else:
raise ValueError("No response generated")

except Exception as e:
logger.error(f"LLM processing error: {str(e)}")
@@ -101,12 +108,13 @@ async def llm(
async def stream_generator(generator):
try:
async for chunk in generator:
if isinstance(chunk, dict):
if "choices" in chunk:
if isinstance(chunk, LLMResponse):
if len(chunk.choices) > 0:
# Regular streaming chunk or final chunk
yield f"data: {json.dumps(chunk)}\n\n"
if chunk["choices"][0].get("finish_reason") == "stop":
yield f"data: {chunk.model_dump_json()}\n\n"
if chunk.choices[0].finish_reason == "stop":
break
# Signal end of stream
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Streaming error: {str(e)}")
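On the wire, `stream_generator` now emits each chunk as a serialized `LLMResponse` SSE event and closes with a `[DONE]` sentinel. A rough client-side sketch; `httpx`, the local base URL, and disabled auth are assumptions, not part of this PR:

```python
# Hypothetical SSE client for POST /llm with stream=True; base URL and the
# absence of an Authorization header are assumptions.
import json
import httpx

def stream_llm(prompt: str, base_url: str = "http://localhost:8000") -> str:
    payload = {"messages": [{"role": "user", "content": prompt}], "stream": True}
    text = ""
    with httpx.stream("POST", f"{base_url}/llm", json=payload, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separator lines between events
            data = line[len("data: "):]
            if data == "[DONE]":  # end-of-stream sentinel from stream_generator
                break
            chunk = json.loads(data)  # a serialized LLMResponse
            delta = chunk["choices"][0].get("delta") or {}
            text += delta.get("content") or ""
    return text
```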
43 changes: 35 additions & 8 deletions runner/app/routes/utils.py
@@ -77,6 +77,41 @@ class LLMMessage(BaseModel):
content: str


class LLMBaseChoice(BaseModel):
index: int
finish_reason: str = "" # Needs OpenAPI 3.1 support to make optional


class LLMTokenUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int

class LLMChoice(LLMBaseChoice):
delta: LLMMessage = None
message: LLMMessage = None

class LLMResponse(BaseModel):
id: str
model: str
created: int
tokens_used: LLMTokenUsage
choices: List[LLMChoice]


# class LLMStreamChoice(LLMBaseChoice):
# delta: LLMMessage

# class LLMNonStreamChoice(LLMBaseChoice):
# message: LLMMessage

# class LLMStreamResponse(LLMBaseResponse):
# choices: List[LLMStreamChoice]

# class LLMNonStreamResponse(LLMBaseResponse):
# choices: List[LLMNonStreamChoice]


class LLMRequest(BaseModel):
messages: List[LLMMessage]
model: str = ""
@@ -87,14 +122,6 @@ class LLMRequest(BaseModel):
stream: bool = False


class LLMResponse(BaseModel):
response: str
tokens_used: int
id: str
model: str
created: int


class ImageToTextResponse(BaseModel):
"""Response model for text generation."""

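A minimal construction sketch for the reshaped models above; the id, model name, and token counts are placeholder values:

```python
# Illustrative only: builds a non-streaming LLMResponse in the new shape.
from app.routes.utils import LLMChoice, LLMMessage, LLMResponse, LLMTokenUsage

resp = LLMResponse(
    id="chatcmpl-<uuid>",       # placeholder id
    model="example/model-id",   # placeholder model name
    created=1736812800,
    tokens_used=LLMTokenUsage(prompt_tokens=12, completion_tokens=34, total_tokens=46),
    choices=[
        LLMChoice(
            index=0,
            finish_reason="stop",
            message=LLMMessage(role="assistant", content="Hello!"),
        )
    ],
)
print(resp.model_dump_json(indent=2))
```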
57 changes: 45 additions & 12 deletions runner/gateway.openapi.yaml
@@ -525,8 +524,7 @@ components:
AudioResponse:
properties:
audio:
allOf:
- $ref: '#/components/schemas/MediaURL'
$ref: '#/components/schemas/MediaURL'
description: The generated audio.
type: object
required:
@@ -827,8 +826,7 @@
HTTPError:
properties:
detail:
allOf:
- $ref: '#/components/schemas/APIError'
$ref: '#/components/schemas/APIError'
description: Detailed error information.
type: object
required:
@@ -868,6 +866,23 @@
- text
title: ImageToTextResponse
description: Response model for text generation.
LLMChoice:
properties:
index:
type: integer
title: Index
finish_reason:
type: string
title: Finish Reason
default: ''
delta:
$ref: '#/components/schemas/LLMMessage'
message:
$ref: '#/components/schemas/LLMMessage'
type: object
required:
- index
title: LLMChoice
LLMMessage:
properties:
role:
@@ -918,12 +933,6 @@
title: LLMRequest
LLMResponse:
properties:
response:
type: string
title: Response
tokens_used:
type: integer
title: Tokens Used
id:
type: string
title: Id
@@ -933,14 +942,38 @@
created:
type: integer
title: Created
tokens_used:
$ref: '#/components/schemas/LLMTokenUsage'
choices:
items:
$ref: '#/components/schemas/LLMChoice'
type: array
title: Choices
type: object
required:
- response
- tokens_used
- id
- model
- created
- tokens_used
- choices
title: LLMResponse
LLMTokenUsage:
properties:
prompt_tokens:
type: integer
title: Prompt Tokens
completion_tokens:
type: integer
title: Completion Tokens
total_tokens:
type: integer
title: Total Tokens
type: object
required:
- prompt_tokens
- completion_tokens
- total_tokens
title: LLMTokenUsage
LiveVideoToVideoParams:
properties:
subscribe_url:
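For a quick check of the regenerated schema, a sketch validating a streaming-style payload (the `delta` variant) against the new `LLMResponse` model; the payload values are made up and Pydantic v2 is assumed:

```python
# Illustrative wire-shaped streaming chunk validated against LLMResponse.
from app.routes.utils import LLMResponse

chunk_payload = {
    "id": "chatcmpl-<uuid>",
    "model": "example/model-id",
    "created": 1736812800,
    "tokens_used": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15},
    "choices": [
        {"index": 0, "finish_reason": "", "delta": {"role": "assistant", "content": "Hel"}}
    ],
}
chunk = LLMResponse.model_validate(chunk_payload)
print(chunk.choices[0].delta.content)  # "Hel"
```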