Commit
adapt countgd + dev integration test
Dayof committed Sep 2, 2024
1 parent 7dd6a37 commit d51344b
Showing 6 changed files with 171 additions and 83 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/ci_cd.yml
@@ -1,10 +1,14 @@
name: CI

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

env:
LANDINGAI_DEV_API_KEY: ${{ secrets.LANDINGAI_DEV_API_KEY }}

jobs:
unit_test:
name: Test
@@ -79,6 +83,9 @@ jobs:
- name: Test with pytest
run: |
poetry run pytest -v tests/integ
- name: Test with pytest, dev env
run: |
LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev
release:
name: Release
Empty file.
21 changes: 21 additions & 0 deletions tests/integration_dev/test_tools.py
@@ -0,0 +1,21 @@
import skimage as ski

from vision_agent.tools import (
countgd_counting,
countgd_example_based_counting,
)


def test_countgd_counting() -> None:
img = ski.data.coins()
result = countgd_counting(image=img, prompt="coin")
assert len(result) == 24


def test_countgd_example_based_counting() -> None:
img = ski.data.coins()
result = countgd_example_based_counting(
visual_prompts=[[85, 106, 122, 145]],
image=img,
)
assert len(result) == 24
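
A minimal sketch of running these new dev-environment tests locally, mirroring the CI step added above; it assumes the repository root as the working directory and a valid dev key exported as LANDINGAI_DEV_API_KEY:

import os
import pytest

# Point the tools at the dev environment, as the CI job does.
os.environ["LANDINGAI_API_KEY"] = os.environ["LANDINGAI_DEV_API_KEY"]
os.environ["LANDINGAI_URL"] = "https://api.dev.landing.ai"

# Equivalent to: poetry run pytest -v tests/integration_dev
raise SystemExit(pytest.main(["-v", "tests/integration_dev"]))

The variables must be set before the tools are imported, since tool_utils.py (below) reads LANDINGAI_API_KEY at module import time.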
144 changes: 94 additions & 50 deletions vision_agent/tools/tool_utils.py
@@ -1,6 +1,6 @@
import os
import inspect
import logging
import os
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

import pandas as pd
@@ -13,6 +13,7 @@
from vision_agent.utils.exceptions import RemoteToolCallFailed
from vision_agent.utils.execute import Error, MimeType
from vision_agent.utils.type_defs import LandingaiAPIKey
from vision_agent.tools.tools_types import BoundingBoxes

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
@@ -37,58 +38,55 @@ def send_inference_request(
) -> Dict[str, Any]:
# TODO: runtime_tag and function_name should be in metadata_payload and not
# included in the service payload
try:
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag

url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
if "TOOL_ENDPOINT_URL" in os.environ:
url = os.environ["TOOL_ENDPOINT_URL"]

headers = {"apikey": _LND_API_KEY}
if "TOOL_ENDPOINT_AUTH" in os.environ:
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
headers.pop("apikey")

session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)

url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
if "TOOL_ENDPOINT_URL" in os.environ:
url = os.environ["TOOL_ENDPOINT_URL"]
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]

tool_call_trace = ToolCallTrace(
endpoint_url=url,
request=payload,
response={},
error=None,
)
headers = {"apikey": _LND_API_KEY}
if "TOOL_ENDPOINT_AUTH" in os.environ:
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
headers.pop("apikey")

session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)
response = _call_post(url, payload, session, files, function_name)

if files is not None:
res = session.post(url, data=payload, files=files)
else:
res = session.post(url, json=payload)
if res.status_code != 200:
tool_call_trace.error = Error(
name="RemoteToolCallFailed",
value=f"{res.status_code} - {res.text}",
traceback_raw=[],
)
_LOGGER.error(f"Request failed: {res.status_code} {res.text}")
# TODO: function_name should be in metadata_payload
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]
raise RemoteToolCallFailed(function_name, res.status_code, res.text)

resp = res.json()
tool_call_trace.response = resp
# TODO: consider making the response schema the same between the two sources below
return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore
finally:
trace = tool_call_trace.model_dump()
trace["type"] = "tool_call"
display({MimeType.APPLICATION_JSON: trace}, raw=True)
# TODO: consider making the response schema the same between the two sources below
return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]


def send_task_inference_request(
payload: Dict[str, Any],
endpoint_name: str,
files: Optional[List[Tuple[Any, ...]]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
url = f"{_LND_API_URL_v2}/{endpoint_name}"
headers = {"apikey": _LND_API_KEY}
session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)

function_name = "unknown"
if metadata is not None and "function_name" in metadata:
function_name = metadata["function_name"]
response = _call_post(url, payload, session, files, function_name)
return response["data"]


def _create_requests_session(
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"

return data


def _call_post(
url: str,
payload: dict[str, Any],
session: Session,
files: Optional[List[Tuple[Any, ...]]] = None,
function_name: str = "unknown",
) -> dict[str, Any]:
try:
tool_call_trace = ToolCallTrace(
endpoint_url=url,
request=payload,
response={},
error=None,
)

if files is not None:
response = session.post(url, data=payload, files=files)
else:
response = session.post(url, json=payload)

if response.status_code != 200:
tool_call_trace.error = Error(
name="RemoteToolCallFailed",
value=f"{response.status_code} - {response.text}",
traceback_raw=[],
)
_LOGGER.error(f"Request failed: {response.status_code} {response.text}")
raise RemoteToolCallFailed(
function_name, response.status_code, response.text
)

result = response.json()
tool_call_trace.response = result
return result
finally:
trace = tool_call_trace.model_dump()
trace["type"] = "tool_call"
display({MimeType.APPLICATION_JSON: trace}, raw=True)


def filter_bboxes_by_threshold(
bboxes: BoundingBoxes, threshold: float
) -> BoundingBoxes:
# Keep only detections whose confidence meets or exceeds the threshold.
return [bbox for bbox in bboxes if bbox.score >= threshold]
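
For illustration, a small sketch of the post-processing that countgd_counting now relies on: raw detections are validated into ODResponseData (note the bounding_box alias) and then filtered by score. The sample detections and the threshold value below are made up.

from vision_agent.tools.tool_utils import filter_bboxes_by_threshold
from vision_agent.tools.tools_types import ODResponseData

# Hypothetical raw detections, shaped like one frame of the service response.
raw = [
    {"label": "coin", "score": 0.91, "bounding_box": [10, 12, 40, 44]},
    {"label": "coin", "score": 0.18, "bounding_box": [55, 60, 80, 90]},
]
detections = [ODResponseData(**d) for d in raw]
kept = filter_bboxes_by_threshold(detections, threshold=0.5)
print([(d.label, d.score) for d in kept])  # only the 0.91 detection survives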
63 changes: 31 additions & 32 deletions vision_agent/tools/tools.py
@@ -22,6 +22,8 @@
get_tools_df,
get_tools_info,
send_inference_request,
send_task_inference_request,
filter_bboxes_by_threshold,
)
from vision_agent.tools.tools_types import (
BboxInput,
@@ -30,6 +32,7 @@
Florencev2FtRequest,
JobStatus,
PromptTask,
ODResponseData,
)
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -527,24 +530,22 @@ def countgd_counting(
-------
>>> countgd_counting("flower", image)
[
{'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
{'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
{'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
{'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
{'score': 0.49, 'label': 'flower', 'bounding_box': [0.1, 0.11, 0.35, 0.4]},
{'score': 0.68, 'label': 'flower', 'bounding_box': [0.2, 0.21, 0.45, 0.5]},
{'score': 0.78, 'label': 'flower', 'bounding_box': [0.3, 0.35, 0.48, 0.52]},
{'score': 0.98, 'label': 'flower', 'bounding_box': [0.44, 0.24, 0.49, 0.58]},
]
"""
image_b64 = convert_to_b64(image)
payload = {
"image": image_b64,
"prompt": prompt,
"box_threshold": box_threshold,
}
metadata_payload = {"function_name": "countgd_counting"}
resp_data: List[Dict[str, Any]] = send_inference_request(
payload, "countgd", v2=True, metadata_payload=metadata_payload
) # type: ignore

return resp_data
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
payload = {"prompts": [prompt]}
metadata = {"function_name": "countgd_counting"}
resp_data: List[Dict[str, Any]] = send_task_inference_request(
payload, "text-to-object-detection", files=files, metadata=metadata
)
bboxes_per_frame = resp_data[0]
bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
return filter_bboxes_by_threshold(bboxes_formatted, box_threshold)


def countgd_example_based_counting(
@@ -577,27 +578,25 @@ def countgd_example_based_counting(
image=image
)
[
{'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
{'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5]},
{'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52]},
{'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
{'score': 0.49, 'label': 'object', 'bounding_box': [0.1, 0.11, 0.35, 0.4]},
{'score': 0.68, 'label': 'object', 'bounding_box': [0.2, 0.21, 0.45, 0.5]},
{'score': 0.78, 'label': 'object', 'bounding_box': [0.3, 0.35, 0.48, 0.52]},
{'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58]},
]
"""
image_b64 = convert_to_b64(image)
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
visual_prompts = [
denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
]
payload = {
"image": image_b64,
"visual_prompts": visual_prompts,
"box_threshold": box_threshold,
}
metadata_payload = {"function_name": "countgd_example_based_counting"}
resp_data: List[Dict[str, Any]] = send_inference_request(
payload, "countgd", v2=True, metadata_payload=metadata_payload
) # type: ignore

return resp_data
payload = {"visual_prompts": json.loads(visual_prompts)}
metadata = {"function_name": "countgd_example_based_counting"}
resp_data: List[Dict[str, Any]] = send_task_inference_request(
payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
)
bboxes_per_frame = resp_data[0]
bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
return filter_bboxes_by_threshold(bboxes_formatted, box_threshold)


def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
19 changes: 18 additions & 1 deletion vision_agent/tools/tools_types.py
@@ -1,7 +1,8 @@
from uuid import UUID
from enum import Enum
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Annotated

from annotated_types import Len
from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo


@@ -82,3 +83,19 @@ class JobStatus(str, Enum):
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
STOPPED = "STOPPED"


BoundingBox = Annotated[list[int | float], Len(min_length=4, max_length=4)]


class ODResponseData(BaseModel):
label: str
score: float
bbox: BoundingBox = Field(alias="bounding_box")

model_config = ConfigDict(
populate_by_name=True,
)


BoundingBoxes = list[ODResponseData]
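
A short sketch of what the new types enforce, assuming pydantic v2 with annotated-types installed: the Len(min_length=4, max_length=4) constraint rejects boxes that do not have exactly four coordinates, and populate_by_name=True lets callers construct the model with either the bbox field name or its bounding_box alias.

from pydantic import ValidationError
from vision_agent.tools.tools_types import ODResponseData

# Field name and alias are interchangeable thanks to populate_by_name=True.
a = ODResponseData(label="flower", score=0.9, bbox=[0.1, 0.2, 0.3, 0.4])
b = ODResponseData(label="flower", score=0.9, bounding_box=[0.1, 0.2, 0.3, 0.4])
assert a.bbox == b.bbox

# A three-element box violates the Len constraint and fails validation.
try:
    ODResponseData(label="flower", score=0.9, bounding_box=[0.1, 0.2, 0.3])
except ValidationError as err:
    print("rejected:", err)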
