Skip to content

Commit 1d711a9

Browse files
committed
fix: applied langchain testing standards
1 parent d8b9018 commit 1d711a9

File tree

5 files changed

+729
-57
lines changed

5 files changed

+729
-57
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
run: poetry run mypy .
4949

5050
- name: Run tests
51-
run: poetry run pytest --maxfail=1 --disable-warnings --tb=short
51+
run: poetry run pytest --maxfail=1 --disable-warnings --tb=short --ignore=tests/integration_tests/
5252

5353
- name: Upload pytest results (optional)
5454
if: always()

langchain_heroku/chat_models.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,14 @@ def _api_to_ai_message(self, resp: dict) -> AIMessage:
150150
# Store usage metadata in additional_kwargs since AIMessage doesn't have usage_metadata parameter
151151
if usage_metadata:
152152
additional_kwargs["usage_metadata"] = usage_metadata
153+
# Add model_name to response_metadata for usage tracking
154+
response_metadata = resp.copy()
155+
response_metadata["model_name"] = resp.get("model", self._get_model())
153156
return AIMessage(
154157
content=content,
155158
additional_kwargs=additional_kwargs,
156-
response_metadata=resp,
159+
response_metadata=response_metadata,
160+
usage_metadata=usage_metadata if usage_metadata else None,
157161
)
158162

159163
def _validate_config(self) -> None:
@@ -280,8 +284,8 @@ def response_generator() -> Generator[bytes, None, None]:
280284
else:
281285
raise RuntimeError(f"Heroku Inference API stream call failed after {max_retries} retries: {last_exc}")
282286

283-
def _parse_sse_event(self, event: sseclient.Event) -> Optional[str]:
284-
"""Parse a single SSE event and extract content."""
287+
def _parse_sse_event(self, event: sseclient.Event) -> Optional[Dict[str, Any]]:
288+
"""Parse a single SSE event and extract content and metadata."""
285289
try:
286290
# Handle the special "[DONE]" message
287291
if event.data == "[DONE]":
@@ -291,7 +295,23 @@ def _parse_sse_event(self, event: sseclient.Event) -> Optional[str]:
291295
# For streaming, use 'delta' instead of 'message'
292296
choice = data["choices"][0]
293297
delta = choice.get("delta", {})
294-
return delta.get("content", "")
298+
content = delta.get("content", "")
299+
300+
# Extract usage metadata if available
301+
usage = data.get("usage", {})
302+
usage_metadata = None
303+
if usage:
304+
usage_metadata = {
305+
"input_tokens": usage.get("prompt_tokens"),
306+
"output_tokens": usage.get("completion_tokens"),
307+
"total_tokens": usage.get("total_tokens"),
308+
}
309+
310+
return {
311+
"content": content,
312+
"usage_metadata": usage_metadata,
313+
"response_metadata": data,
314+
}
295315
except (json.JSONDecodeError, KeyError, IndexError):
296316
# Skip malformed JSON lines or missing data
297317
return None
@@ -309,9 +329,21 @@ def _stream(
309329
client = self._make_streaming_request(payload)
310330
try:
311331
for event in client.events():
312-
content = self._parse_sse_event(event)
313-
if content is not None:
314-
ai_msg_chunk = AIMessageChunk(content=content)
332+
parsed_event = self._parse_sse_event(event)
333+
if parsed_event is not None:
334+
content = parsed_event["content"]
335+
usage_metadata = parsed_event.get("usage_metadata")
336+
response_metadata = parsed_event.get("response_metadata", {})
337+
338+
# Add model_name to response_metadata for usage tracking
339+
if response_metadata:
340+
response_metadata["model_name"] = response_metadata.get("model", self._get_model())
341+
342+
ai_msg_chunk = AIMessageChunk(
343+
content=content,
344+
usage_metadata=usage_metadata,
345+
response_metadata={"model_name": response_metadata.get("model", self._get_model())} if response_metadata else {},
346+
)
315347
chunk = ChatGenerationChunk(message=ai_msg_chunk)
316348
if run_manager:
317349
run_manager.on_llm_new_token(content, chunk=chunk)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ addopts = "--strict-markers --strict-config --durations=5"
3838
markers = [
3939
"compile: mark placeholder test used to compile integration tests without running them",
4040
"slow: mark test as slow running",
41+
"production: marks tests as production endpoint tests",
4142
]
4243

4344
[tool.poetry.group.test]

0 commit comments

Comments (0)