@@ -150,10 +150,14 @@ def _api_to_ai_message(self, resp: dict) -> AIMessage:
         # Store usage metadata in additional_kwargs since AIMessage doesn't have usage_metadata parameter
         if usage_metadata:
             additional_kwargs["usage_metadata"] = usage_metadata
+        # Add model_name to response_metadata for usage tracking
+        response_metadata = resp.copy()
+        response_metadata["model_name"] = resp.get("model", self._get_model())
         return AIMessage(
             content=content,
             additional_kwargs=additional_kwargs,
-            response_metadata=resp,
+            response_metadata=response_metadata,
+            usage_metadata=usage_metadata if usage_metadata else None,
         )

     def _validate_config(self) -> None:
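For context, a rough standalone sketch of what the updated `_api_to_ai_message` now produces, assuming `langchain-core >= 0.2` (where `AIMessage` accepts a `usage_metadata` kwarg). The sample response dict, the model name, and the `"unknown"` fallback standing in for `self._get_model()` are all invented for illustration:

```python
# Minimal sketch, not the connector's code: build an AIMessage that carries
# token usage both as first-class usage_metadata and in response_metadata.
from langchain_core.messages import AIMessage

resp = {
    "model": "claude-3-5-sonnet",  # hypothetical model id
    "choices": [{"message": {"content": "Hello!"}}],
    "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15},
}

usage = resp.get("usage", {})
usage_metadata = {
    "input_tokens": usage.get("prompt_tokens"),
    "output_tokens": usage.get("completion_tokens"),
    "total_tokens": usage.get("total_tokens"),
}

response_metadata = resp.copy()
response_metadata["model_name"] = resp.get("model", "unknown")  # stand-in for self._get_model()

msg = AIMessage(
    content=resp["choices"][0]["message"]["content"],
    response_metadata=response_metadata,
    usage_metadata=usage_metadata,
)
print(msg.usage_metadata)  # {'input_tokens': 12, 'output_tokens': 3, 'total_tokens': 15}
```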
@@ -280,8 +284,8 @@ def response_generator() -> Generator[bytes, None, None]:
280284 else :
281285 raise RuntimeError (f"Heroku Inference API stream call failed after { max_retries } retries: { last_exc } " )
282286
283- def _parse_sse_event (self , event : sseclient .Event ) -> Optional [str ]:
284- """Parse a single SSE event and extract content."""
287+ def _parse_sse_event (self , event : sseclient .Event ) -> Optional [Dict [ str , Any ] ]:
288+ """Parse a single SSE event and extract content and metadata ."""
285289 try :
286290 # Handle the special "[DONE]" message
287291 if event .data == "[DONE]" :
@@ -291,7 +295,23 @@ def _parse_sse_event(self, event: sseclient.Event) -> Optional[str]:
             # For streaming, use 'delta' instead of 'message'
             choice = data["choices"][0]
             delta = choice.get("delta", {})
-            return delta.get("content", "")
+            content = delta.get("content", "")
+
+            # Extract usage metadata if available
+            usage = data.get("usage", {})
+            usage_metadata = None
+            if usage:
+                usage_metadata = {
+                    "input_tokens": usage.get("prompt_tokens"),
+                    "output_tokens": usage.get("completion_tokens"),
+                    "total_tokens": usage.get("total_tokens"),
+                }
+
+            return {
+                "content": content,
+                "usage_metadata": usage_metadata,
+                "response_metadata": data,
+            }
         except (json.JSONDecodeError, KeyError, IndexError):
             # Skip malformed JSON lines or missing data
             return None
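To sanity-check the new return shape, here is a minimal stdlib-only sketch of the same parsing logic, run against a handcrafted Chat Completions-style streaming payload, with a plain string standing in for `sseclient.Event.data`:

```python
# Sketch of the parser's contract: "[DONE]" and malformed input yield None,
# everything else yields a dict of content, usage_metadata, response_metadata.
import json

data_line = json.dumps({
    "model": "claude-3-5-sonnet",  # illustrative model id
    "choices": [{"delta": {"content": "Hel"}}],
    "usage": {"prompt_tokens": 12, "completion_tokens": 1, "total_tokens": 13},
})

def parse(data: str):
    if data == "[DONE]":
        return None
    try:
        payload = json.loads(data)
        delta = payload["choices"][0].get("delta", {})
        usage = payload.get("usage", {})
        usage_metadata = None
        if usage:
            usage_metadata = {
                "input_tokens": usage.get("prompt_tokens"),
                "output_tokens": usage.get("completion_tokens"),
                "total_tokens": usage.get("total_tokens"),
            }
        return {
            "content": delta.get("content", ""),
            "usage_metadata": usage_metadata,
            "response_metadata": payload,
        }
    except (json.JSONDecodeError, KeyError, IndexError):
        return None

print(parse(data_line)["content"])  # Hel
print(parse("[DONE]"))              # None
print(parse("not json"))            # None
```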
@@ -309,9 +329,21 @@ def _stream(
         client = self._make_streaming_request(payload)
         try:
             for event in client.events():
-                content = self._parse_sse_event(event)
-                if content is not None:
-                    ai_msg_chunk = AIMessageChunk(content=content)
+                parsed_event = self._parse_sse_event(event)
+                if parsed_event is not None:
+                    content = parsed_event["content"]
+                    usage_metadata = parsed_event.get("usage_metadata")
+                    response_metadata = parsed_event.get("response_metadata", {})
+
+                    # Add model_name to response_metadata for usage tracking
+                    if response_metadata:
+                        response_metadata["model_name"] = response_metadata.get("model", self._get_model())
+
+                    ai_msg_chunk = AIMessageChunk(
+                        content=content,
+                        usage_metadata=usage_metadata,
+                        response_metadata={"model_name": response_metadata.get("model", self._get_model())} if response_metadata else {},
+                    )
                     chunk = ChatGenerationChunk(message=ai_msg_chunk)
                     if run_manager:
                         run_manager.on_llm_new_token(content, chunk=chunk)
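One reason to attach `usage_metadata` to each chunk: `AIMessageChunk` objects merge with `+`, and `langchain-core` sums usage across chunks, so stream consumers can report whole-stream token totals. A small sketch with made-up token counts, again assuming `langchain-core >= 0.2`:

```python
# Sketch of downstream chunk merging: content concatenates and usage
# metadata accumulates, even when only some chunks carry usage.
from langchain_core.messages import AIMessageChunk

chunks = [
    AIMessageChunk(content="Hel", response_metadata={"model_name": "claude-3-5-sonnet"}),
    AIMessageChunk(
        content="lo!",
        usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
    ),
]

merged = chunks[0]
for c in chunks[1:]:
    merged = merged + c

print(merged.content)         # Hello!
print(merged.usage_metadata)  # {'input_tokens': 12, 'output_tokens': 3, 'total_tokens': 15}
```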