diff --git a/runner/app/pipelines/llm.py b/runner/app/pipelines/llm.py
index 6cd7780b..597e7745 100644
--- a/runner/app/pipelines/llm.py
+++ b/runner/app/pipelines/llm.py
@@ -11,7 +11,7 @@
 from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
 from huggingface_hub import file_download
 from transformers import AutoConfig
-from app.routes.utils import LLMResponse, LLMChoice, LLMMessage
+from app.routes.utils import LLMResponse, LLMChoice, LLMMessage, LLMTokenUsage
 
 logger = logging.getLogger(__name__)
 
@@ -228,11 +228,14 @@ async def generate(
                         role="assistant",
                         content=delta
                     ),
-                    index=0,
-                    finish_reason=None
+                    index=0
                 )
             ],
-            tokens_used=total_tokens,
+            tokens_used=LLMTokenUsage(
+                prompt_tokens=input_tokens,
+                completion_tokens=total_tokens,
+                total_tokens=input_tokens + total_tokens
+            ),
             id=request_id,
             model=self.model_id,
             created=int(time.time())
@@ -263,7 +266,11 @@ async def generate(
                     finish_reason="stop"
                 )
             ],
-            tokens_used=input_tokens + total_tokens,
+            tokens_used=LLMTokenUsage(
+                prompt_tokens=input_tokens,
+                completion_tokens=total_tokens,
+                total_tokens=input_tokens + total_tokens
+            ),
             id=request_id,
             model=self.model_id,
             created=int(time.time())
diff --git a/runner/app/routes/llm.py b/runner/app/routes/llm.py
index 0620cc29..6cc7a860 100644
--- a/runner/app/routes/llm.py
+++ b/runner/app/routes/llm.py
@@ -107,11 +107,11 @@ async def llm(
     async def stream_generator(generator):
         try:
             async for chunk in generator:
-                if isinstance(chunk, dict):
-                    if "choices" in chunk:
+                if isinstance(chunk, LLMResponse):
+                    if len(chunk.choices) > 0:
                         # Regular streaming chunk or final chunk
-                        yield f"data: {json.dumps(chunk)}\n\n"
-                        if chunk["choices"][0].get("finish_reason") == "stop":
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+                        if chunk.choices[0].finish_reason == "stop":
                             break
             # Signal end of stream
             yield "data: [DONE]\n\n"
diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py
index 3723433c..2df4fded 100644
--- a/runner/app/routes/utils.py
+++ b/runner/app/routes/utils.py
@@ -80,10 +80,12 @@ class LLMMessage(BaseModel):
 class LLMChoice(BaseModel):
     delta: LLMMessage
     index: int
-    finish_reason: str = Field(
-        default=None,
-        nullable=True,
-    )
+    finish_reason: str = ""
+
+class LLMTokenUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
 
 class LLMRequest(BaseModel):
     messages: List[LLMMessage]
@@ -97,7 +99,7 @@ class LLMRequest(BaseModel):
 
 class LLMResponse(BaseModel):
     choices: List[LLMChoice]
-    tokens_used: int
+    tokens_used: LLMTokenUsage
     id: str
     model: str
     created: int