From 18cce4fc31052ae5085abbccc414a6979468946f Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Sun, 1 Sep 2024 06:52:15 -0700 Subject: [PATCH] fix return values from countgd endpoint --- vision_agent/tools/tool_utils.py | 4 ++-- vision_agent/tools/tools.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 2a260c41..185563a4 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -1,7 +1,7 @@ import inspect import logging import os -from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple import pandas as pd from IPython.display import display @@ -34,7 +34,7 @@ def send_inference_request( files: Optional[List[Tuple[Any, ...]]] = None, v2: bool = False, metadata_payload: Optional[Dict[str, Any]] = None, -) -> Union[Dict[str, Any], List[Dict[str, Any]]]: +) -> Dict[str, Any]: # TODO: runtime_tag and function_name should be metadata_payload and now included # in the service payload try: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index a4eee6ac..08bf0370 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -540,10 +540,10 @@ def countgd_counting( "box_threshold": box_threshold, } metadata_payload = {"function_name": "countgd_counting"} - data: List[Dict[str, Any]] = send_inference_request( + resp: List[Dict[str, Any]] = send_inference_request( payload, "countgd", v2=True, metadata_payload=metadata_payload - ) - return data + ) # type: ignore + return resp["data"] def countgd_example_based_counting( @@ -589,10 +589,10 @@ def countgd_example_based_counting( "box_threshold": box_threshold, } metadata_payload = {"function_name": "countgd_example_based_counting"} - data: List[Dict[str, Any]] = send_inference_request( + resp: List[Dict[str, Any]] = send_inference_request( payload, "countgd", v2=True, metadata_payload=metadata_payload - ) - return data + ) # type: ignore + return resp["data"] def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: