diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml
index d41e7592..3576e10c 100644
--- a/.github/workflows/ci_cd.yml
+++ b/.github/workflows/ci_cd.yml
@@ -1,10 +1,14 @@
 name: CI
+
 on:
   push:
     branches: [ main ]
   pull_request:
     branches: [ main ]
 
+env:
+  LANDINGAI_DEV_API_KEY: ${{ secrets.LANDINGAI_DEV_API_KEY }}
+
 jobs:
   unit_test:
     name: Test
@@ -79,6 +83,9 @@ jobs:
       - name: Test with pytest
         run: |
           poetry run pytest -v tests/integ
+      - name: Test with pytest, dev env
+        run: |
+          LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev
 
   release:
     name: Release
diff --git a/tests/integration_dev/__init__.py b/tests/integration_dev/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/integration_dev/test_tools.py b/tests/integration_dev/test_tools.py
new file mode 100644
index 00000000..29262245
--- /dev/null
+++ b/tests/integration_dev/test_tools.py
@@ -0,0 +1,21 @@
+import skimage as ski
+
+from vision_agent.tools import (
+    countgd_counting,
+    countgd_example_based_counting,
+)
+
+
+def test_countgd_counting() -> None:
+    img = ski.data.coins()
+    result = countgd_counting(image=img, prompt="coin")
+    assert len(result) == 24
+
+
+def test_countgd_example_based_counting() -> None:
+    img = ski.data.coins()
+    result = countgd_example_based_counting(
+        visual_prompts=[[85, 106, 122, 145]],
+        image=img,
+    )
+    assert len(result) == 24
diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py
index 185563a4..30ac659b 100644
--- a/vision_agent/tools/tool_utils.py
+++ b/vision_agent/tools/tool_utils.py
@@ -1,6 +1,6 @@
+import os
 import inspect
 import logging
-import os
 from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple
 
 import pandas as pd
@@ -13,6 +13,7 @@
 from vision_agent.utils.exceptions import RemoteToolCallFailed
 from vision_agent.utils.execute import Error, MimeType
 from vision_agent.utils.type_defs import LandingaiAPIKey
+from vision_agent.tools.tools_types import BoundingBoxes
 
 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
@@ -37,58 +38,55 @@ def send_inference_request(
 ) -> Dict[str, Any]:
     # TODO: runtime_tag and function_name should be metadata_payload and not included
     # in the service payload
-    try:
-        if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
-            payload["runtime_tag"] = runtime_tag
+    if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
+        payload["runtime_tag"] = runtime_tag
+
+    url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
+    if "TOOL_ENDPOINT_URL" in os.environ:
+        url = os.environ["TOOL_ENDPOINT_URL"]
+
+    headers = {"apikey": _LND_API_KEY}
+    if "TOOL_ENDPOINT_AUTH" in os.environ:
+        headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
+        headers.pop("apikey")
+
+    session = _create_requests_session(
+        url=url,
+        num_retry=3,
+        headers=headers,
+    )
 
-        url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
-        if "TOOL_ENDPOINT_URL" in os.environ:
-            url = os.environ["TOOL_ENDPOINT_URL"]
+    function_name = "unknown"
+    if "function_name" in payload:
+        function_name = payload["function_name"]
+    elif metadata_payload is not None and "function_name" in metadata_payload:
+        function_name = metadata_payload["function_name"]
 
-        tool_call_trace = ToolCallTrace(
-            endpoint_url=url,
-            request=payload,
-            response={},
-            error=None,
-        )
-        headers = {"apikey": _LND_API_KEY}
-        if "TOOL_ENDPOINT_AUTH" in os.environ:
-            headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
-            headers.pop("apikey")
-
-        session = _create_requests_session(
-            url=url,
-            num_retry=3,
-            headers=headers,
-        )
+    response = _call_post(url, payload, session, files, function_name)
 
-        if files is not None:
-            res = session.post(url, data=payload, files=files)
-        else:
-            res = session.post(url, json=payload)
-        if res.status_code != 200:
-            tool_call_trace.error = Error(
-                name="RemoteToolCallFailed",
-                value=f"{res.status_code} - {res.text}",
-                traceback_raw=[],
-            )
-            _LOGGER.error(f"Request failed: {res.status_code} {res.text}")
-            # TODO: function_name should be in metadata_payload
-            function_name = "unknown"
-            if "function_name" in payload:
-                function_name = payload["function_name"]
-            elif metadata_payload is not None and "function_name" in metadata_payload:
-                function_name = metadata_payload["function_name"]
-            raise RemoteToolCallFailed(function_name, res.status_code, res.text)
-
-        resp = res.json()
-        tool_call_trace.response = resp
-        # TODO: consider making the response schema the same between below two sources
-        return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"]  # type: ignore
-    finally:
-        trace = tool_call_trace.model_dump()
-        trace["type"] = "tool_call"
-        display({MimeType.APPLICATION_JSON: trace}, raw=True)
+    # TODO: consider making the response schema the same between below two sources
+    return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
+
+
+def send_task_inference_request(
+    payload: Dict[str, Any],
+    endpoint_name: str,
+    files: Optional[List[Tuple[Any, ...]]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Any]:
+    url = f"{_LND_API_URL_v2}/{endpoint_name}"
+    headers = {"apikey": _LND_API_KEY}
+    session = _create_requests_session(
+        url=url,
+        num_retry=3,
+        headers=headers,
+    )
+
+    function_name = "unknown"
+    if metadata is not None and "function_name" in metadata:
+        function_name = metadata["function_name"]
+    response = _call_post(url, payload, session, files, function_name)
+    return response["data"]
 
 
 def _create_requests_session(
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
         data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"
 
     return data
+
+
+def _call_post(
+    url: str,
+    payload: dict[str, Any],
+    session: Session,
+    files: Optional[List[Tuple[Any, ...]]] = None,
+    function_name: str = "unknown",
+) -> dict[str, Any]:
+    try:
+        tool_call_trace = ToolCallTrace(
+            endpoint_url=url,
+            request=payload,
+            response={},
+            error=None,
+        )
+
+        if files is not None:
+            response = session.post(url, data=payload, files=files)
+        else:
+            response = session.post(url, json=payload)
+
+        if response.status_code != 200:
+            tool_call_trace.error = Error(
+                name="RemoteToolCallFailed",
+                value=f"{response.status_code} - {response.text}",
+                traceback_raw=[],
+            )
+            _LOGGER.error(f"Request failed: {response.status_code} {response.text}")
+            raise RemoteToolCallFailed(
+                function_name, response.status_code, response.text
+            )
+
+        result = response.json()
+        tool_call_trace.response = result
+        return result
+    finally:
+        trace = tool_call_trace.model_dump()
+        trace["type"] = "tool_call"
+        display({MimeType.APPLICATION_JSON: trace}, raw=True)
+
+
+def filter_bboxes_by_threshold(
+    bboxes: BoundingBoxes, threshold: float
+) -> BoundingBoxes:
+    return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 1ad6ea11..add35f5c 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -22,6 +22,8 @@
     get_tools_df,
     get_tools_info,
     send_inference_request,
+    send_task_inference_request,
+    filter_bboxes_by_threshold,
 )
 from vision_agent.tools.tools_types import (
     BboxInput,
@@ -30,6 +32,7 @@
     Florencev2FtRequest,
     JobStatus,
     PromptTask,
+    ODResponseData,
 )
 from vision_agent.utils import extract_frames_from_video
 from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -527,24 +530,22 @@ def countgd_counting(
     -------
         >>> countgd_counting("flower", image)
         [
-            {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
-            {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5},
-            {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52},
-            {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58},
+            {'score': 0.49, 'label': 'flower', 'bounding_box': [0.1, 0.11, 0.35, 0.4]},
+            {'score': 0.68, 'label': 'flower', 'bounding_box': [0.2, 0.21, 0.45, 0.5]},
+            {'score': 0.78, 'label': 'flower', 'bounding_box': [0.3, 0.35, 0.48, 0.52]},
+            {'score': 0.98, 'label': 'flower', 'bounding_box': [0.44, 0.24, 0.49, 0.58]},
         ]
     """
-    image_b64 = convert_to_b64(image)
-    payload = {
-        "image": image_b64,
-        "prompt": prompt,
-        "box_threshold": box_threshold,
-    }
-    metadata_payload = {"function_name": "countgd_counting"}
-    resp_data: List[Dict[str, Any]] = send_inference_request(
-        payload, "countgd", v2=True, metadata_payload=metadata_payload
-    )  # type: ignore
-
-    return resp_data
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    payload = {"prompts": [prompt]}
+    metadata = {"function_name": "countgd_counting"}
+    resp_data: List[Dict[str, Any]] = send_task_inference_request(
+        payload, "text-to-object-detection", files=files, metadata=metadata
+    )
+    bboxes_per_frame = resp_data[0]
+    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+    return filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
 
 
 def countgd_example_based_counting(
@@ -577,27 +578,25 @@
             image=image
         )
         [
-            {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
-            {'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5},
-            {'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52},
-            {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58},
+            {'score': 0.49, 'label': 'object', 'bounding_box': [0.1, 0.11, 0.35, 0.4]},
+            {'score': 0.68, 'label': 'object', 'bounding_box': [0.2, 0.21, 0.45, 0.5]},
+            {'score': 0.78, 'label': 'object', 'bounding_box': [0.3, 0.35, 0.48, 0.52]},
+            {'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58]},
         ]
     """
-    image_b64 = convert_to_b64(image)
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
     visual_prompts = [
         denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
     ]
-    payload = {
-        "image": image_b64,
-        "visual_prompts": visual_prompts,
-        "box_threshold": box_threshold,
-    }
-    metadata_payload = {"function_name": "countgd_example_based_counting"}
-    resp_data: List[Dict[str, Any]] = send_inference_request(
-        payload, "countgd", v2=True, metadata_payload=metadata_payload
-    )  # type: ignore
-
-    return resp_data
+    payload = {"visual_prompts": json.dumps(visual_prompts)}
+    metadata = {"function_name": "countgd_example_based_counting"}
+    resp_data: List[Dict[str, Any]] = send_task_inference_request(
+        payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
+    )
+    bboxes_per_frame = resp_data[0]
+    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+    return filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
 
 
 def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
diff --git a/vision_agent/tools/tools_types.py b/vision_agent/tools/tools_types.py
index aeb45c95..aa6f5f68 100644
--- a/vision_agent/tools/tools_types.py
+++ b/vision_agent/tools/tools_types.py
@@ -1,7 +1,8 @@
 from uuid import UUID
 from enum import Enum
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Annotated
 
+from annotated_types import Len
 from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo
 
 
@@ -82,3 +83,19 @@ class JobStatus(str, Enum):
     SUCCEEDED = "SUCCEEDED"
     FAILED = "FAILED"
     STOPPED = "STOPPED"
+
+
+BoundingBox = Annotated[list[int | float], Len(min_length=4, max_length=4)]
+
+
+class ODResponseData(BaseModel):
+    label: str
+    score: float
+    bbox: BoundingBox = Field(alias="bounding_box")
+
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
+
+BoundingBoxes = list[ODResponseData]
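
For context, a minimal usage sketch (not part of the patch) of how the new ODResponseData model and the filter_bboxes_by_threshold helper fit together; the labels, scores, and boxes below are invented for illustration:

from vision_agent.tools.tool_utils import filter_bboxes_by_threshold
from vision_agent.tools.tools_types import ODResponseData

# Raw per-frame detections as returned by the task inference endpoint; the
# "bounding_box" key is accepted through the field alias on ODResponseData.
raw = [
    {"label": "coin", "score": 0.91, "bounding_box": [10, 20, 55, 70]},
    {"label": "coin", "score": 0.18, "bounding_box": [80, 15, 120, 60]},
]
bboxes = [ODResponseData(**d) for d in raw]

# Keep only detections whose score meets the threshold; here only the 0.91 box survives.
kept = filter_bboxes_by_threshold(bboxes, threshold=0.23)
print([b.model_dump() for b in kept])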