Skip to content

Commit

Permalink
Capture model tool call trace in the notebook execution result (#178)
Browse files Browse the repository at this point in the history
* Capture model tool call trace in the notebook execution result

* Add missing code
  • Loading branch information
humpydonkey authored Jul 25, 2024
1 parent 915c895 commit b79a3d2
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 22 deletions.
72 changes: 50 additions & 22 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,74 @@
import logging
import os
from typing import Any, Dict
from typing import Any, Dict, MutableMapping, Optional

from IPython.display import display
from pydantic import BaseModel
from requests import Session
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from vision_agent.utils.exceptions import RemoteToolCallFailed
from vision_agent.utils.execute import Error, MimeType
from vision_agent.utils.type_defs import LandingaiAPIKey

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = LandingaiAPIKey().api_key
_LND_API_URL = "https://api.staging.landing.ai/v1/agent"


class ToolCallTrace(BaseModel):
endpoint_url: str
request: MutableMapping[str, Any]
response: MutableMapping[str, Any]
error: Optional[Error]


def send_inference_request(
payload: Dict[str, Any], endpoint_name: str
) -> Dict[str, Any]:
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag
try:
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag

url = f"{_LND_API_URL}/model/{endpoint_name}"
if "TOOL_ENDPOINT_URL" in os.environ:
url = os.environ["TOOL_ENDPOINT_URL"]
url = f"{_LND_API_URL}/model/{endpoint_name}"
if "TOOL_ENDPOINT_URL" in os.environ:
url = os.environ["TOOL_ENDPOINT_URL"]

headers = {"Content-Type": "application/json", "apikey": _LND_API_KEY}
if "TOOL_ENDPOINT_AUTH" in os.environ:
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
headers.pop("apikey")
tool_call_trace = ToolCallTrace(
endpoint_url=url,
request=payload,
response={},
error=None,
)
headers = {"Content-Type": "application/json", "apikey": _LND_API_KEY}
if "TOOL_ENDPOINT_AUTH" in os.environ:
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
headers.pop("apikey")

session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)
res = session.post(url, json=payload)
if res.status_code != 200:
_LOGGER.error(f"Request failed: {res.status_code} {res.text}")
raise ValueError(f"Request failed: {res.status_code} {res.text}")
session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)
res = session.post(url, json=payload)
if res.status_code != 200:
tool_call_trace.error = Error(
name="RemoteToolCallFailed",
value=f"{res.status_code} - {res.text}",
traceback_raw=[],
)
_LOGGER.error(f"Request failed: {res.status_code} {res.text}")
raise RemoteToolCallFailed(payload["tool"], res.status_code, res.text)

resp = res.json()
# TODO: consider making the response schema the same between below two sources
return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore
resp = res.json()
tool_call_trace.response = resp
# TODO: consider making the response schema the same between below two sources
return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore
finally:
trace = tool_call_trace.model_dump()
trace["type"] = "tool_call"
display({MimeType.APPLICATION_JSON: trace}, raw=True)


def _create_requests_session(
Expand Down
9 changes: 9 additions & 0 deletions vision_agent/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ def __str__(self) -> str:
return self.message


class RemoteToolCallFailed(Exception):
"""Exception raised when an error occurs during a tool call."""

def __init__(self, tool_name: str, status_code: int, message: str):
self.message = (
f"""Tool call ({tool_name}) failed due to {status_code} - {message}"""
)


class RemoteSandboxError(Exception):
"""Exception related to remote sandbox."""

Expand Down
11 changes: 11 additions & 0 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,17 @@ def traceback(self, return_clean_text: bool = True) -> str:
text = "\n".join(self.traceback_raw)
return _remove_escape_and_color_codes(text) if return_clean_text else text

@staticmethod
def from_exception(e: Exception) -> "Error":
"""
Creates an Error object from an exception.
"""
return Error(
name=e.__class__.__name__,
value=str(e),
traceback_raw=traceback.format_exception(type(e), e, e.__traceback__),
)


class Execution(BaseModel):
"""
Expand Down

0 comments on commit b79a3d2

Please sign in to comment.