From b79a3d2c4ad184cf93f8af21dcdaf3b4400670ba Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Thu, 25 Jul 2024 20:36:22 +0800 Subject: [PATCH] Capture model tool call trace in the notebook execution result (#178) * Capture model tool call trace in the notebook execution result * Add missing code --- vision_agent/tools/tool_utils.py | 72 ++++++++++++++++++++++---------- vision_agent/utils/exceptions.py | 9 ++++ vision_agent/utils/execute.py | 11 +++++ 3 files changed, 70 insertions(+), 22 deletions(-) diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 667ff722..c65019fb 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -1,11 +1,15 @@ import logging import os -from typing import Any, Dict +from typing import Any, Dict, MutableMapping, Optional +from IPython.display import display +from pydantic import BaseModel from requests import Session from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry +from vision_agent.utils.exceptions import RemoteToolCallFailed +from vision_agent.utils.execute import Error, MimeType from vision_agent.utils.type_defs import LandingaiAPIKey _LOGGER = logging.getLogger(__name__) @@ -13,34 +17,58 @@ _LND_API_URL = "https://api.staging.landing.ai/v1/agent" +class ToolCallTrace(BaseModel): + endpoint_url: str + request: MutableMapping[str, Any] + response: MutableMapping[str, Any] + error: Optional[Error] + + def send_inference_request( payload: Dict[str, Any], endpoint_name: str ) -> Dict[str, Any]: - if runtime_tag := os.environ.get("RUNTIME_TAG", ""): - payload["runtime_tag"] = runtime_tag + try: + if runtime_tag := os.environ.get("RUNTIME_TAG", ""): + payload["runtime_tag"] = runtime_tag - url = f"{_LND_API_URL}/model/{endpoint_name}" - if "TOOL_ENDPOINT_URL" in os.environ: - url = os.environ["TOOL_ENDPOINT_URL"] + url = f"{_LND_API_URL}/model/{endpoint_name}" + if "TOOL_ENDPOINT_URL" in os.environ: + url = os.environ["TOOL_ENDPOINT_URL"] - headers = {"Content-Type": "application/json", "apikey": _LND_API_KEY} - if "TOOL_ENDPOINT_AUTH" in os.environ: - headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"] - headers.pop("apikey") + tool_call_trace = ToolCallTrace( + endpoint_url=url, + request=payload, + response={}, + error=None, + ) + headers = {"Content-Type": "application/json", "apikey": _LND_API_KEY} + if "TOOL_ENDPOINT_AUTH" in os.environ: + headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"] + headers.pop("apikey") - session = _create_requests_session( - url=url, - num_retry=3, - headers=headers, - ) - res = session.post(url, json=payload) - if res.status_code != 200: - _LOGGER.error(f"Request failed: {res.status_code} {res.text}") - raise ValueError(f"Request failed: {res.status_code} {res.text}") + session = _create_requests_session( + url=url, + num_retry=3, + headers=headers, + ) + res = session.post(url, json=payload) + if res.status_code != 200: + tool_call_trace.error = Error( + name="RemoteToolCallFailed", + value=f"{res.status_code} - {res.text}", + traceback_raw=[], + ) + _LOGGER.error(f"Request failed: {res.status_code} {res.text}") + raise RemoteToolCallFailed(payload["tool"], res.status_code, res.text) - resp = res.json() - # TODO: consider making the response schema the same between below two sources - return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore + resp = res.json() + tool_call_trace.response = resp + # TODO: consider making the response schema the same between below two sources + return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore + finally: + trace = tool_call_trace.model_dump() + trace["type"] = "tool_call" + display({MimeType.APPLICATION_JSON: trace}, raw=True) def _create_requests_session( diff --git a/vision_agent/utils/exceptions.py b/vision_agent/utils/exceptions.py index 1ffd3e11..41f81dad 100644 --- a/vision_agent/utils/exceptions.py +++ b/vision_agent/utils/exceptions.py @@ -13,6 +13,15 @@ def __str__(self) -> str: return self.message +class RemoteToolCallFailed(Exception): + """Exception raised when an error occurs during a tool call.""" + + def __init__(self, tool_name: str, status_code: int, message: str): + self.message = ( + f"""Tool call ({tool_name}) failed due to {status_code} - {message}""" + ) + + class RemoteSandboxError(Exception): """Exception related to remote sandbox.""" diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 821064c8..aaea19cc 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -277,6 +277,17 @@ def traceback(self, return_clean_text: bool = True) -> str: text = "\n".join(self.traceback_raw) return _remove_escape_and_color_codes(text) if return_clean_text else text + @staticmethod + def from_exception(e: Exception) -> "Error": + """ + Creates an Error object from an exception. + """ + return Error( + name=e.__class__.__name__, + value=str(e), + traceback_raw=traceback.format_exception(type(e), e, e.__traceback__), + ) + class Execution(BaseModel): """