From abd62ecfaccc54b8ad5fe37947c0c678bb92ddc2 Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Tue, 2 Jul 2024 18:33:54 -0700 Subject: [PATCH] Support custom tool endpoint (#161) * Allow custom tool endpoint * Fix lint errors * Fix format --- vision_agent/tools/tool_utils.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index de889e30..667ff722 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -18,20 +18,29 @@ def send_inference_request( ) -> Dict[str, Any]: if runtime_tag := os.environ.get("RUNTIME_TAG", ""): payload["runtime_tag"] = runtime_tag + url = f"{_LND_API_URL}/model/{endpoint_name}" + if "TOOL_ENDPOINT_URL" in os.environ: + url = os.environ["TOOL_ENDPOINT_URL"] + + headers = {"Content-Type": "application/json", "apikey": _LND_API_KEY} + if "TOOL_ENDPOINT_AUTH" in os.environ: + headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"] + headers.pop("apikey") + session = _create_requests_session( url=url, num_retry=3, - headers={ - "Content-Type": "application/json", - "apikey": _LND_API_KEY, - }, + headers=headers, ) res = session.post(url, json=payload) if res.status_code != 200: _LOGGER.error(f"Request failed: {res.status_code} {res.text}") raise ValueError(f"Request failed: {res.status_code} {res.text}") - return res.json()["data"] # type: ignore + + resp = res.json() + # TODO: consider making the response schema the same between below two sources + return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore def _create_requests_session(