Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add count tool #216

Merged
merged 13 commits into from
Sep 4, 2024
7 changes: 7 additions & 0 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
name: CI

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

env:
LANDINGAI_DEV_API_KEY: ${{ secrets.LANDINGAI_DEV_API_KEY }}

jobs:
unit_test:
name: Test
Expand Down Expand Up @@ -79,6 +83,9 @@ jobs:
- name: Test with pytest
run: |
poetry run pytest -v tests/integ
- name: Test with pytest, dev env
run: |
LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev

release:
name: Release
Expand Down
Empty file.
21 changes: 21 additions & 0 deletions tests/integration_dev/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import skimage as ski

from vision_agent.tools import (
countgd_counting,
countgd_example_based_counting,
)


def test_countgd_counting() -> None:
    """Text-prompted counting on the skimage coins image should yield 24 results."""
    coins_image = ski.data.coins()
    detections = countgd_counting(image=coins_image, prompt="coin")
    assert len(detections) == 24


def test_countgd_example_based_counting() -> None:
    """Example-based counting with one visual prompt box should yield 24 results."""
    coins_image = ski.data.coins()
    # A single example box passed as the visual prompt.
    example_boxes = [[85, 106, 122, 145]]
    detections = countgd_example_based_counting(
        visual_prompts=example_boxes,
        image=coins_image,
    )
    assert len(detections) == 24
9 changes: 4 additions & 5 deletions vision_agent/agent/vision_agent_coder_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,19 @@
- Count the number of detected objects labeled as 'person'.
plan3:
- Load the image from the provided file path 'image.jpg'.
- Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.

```python
from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
image = load_image("image.jpg")
owl_v2_out = owl_v2("person", image)

gsam_out = grounding_sam("person", image)
gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]

loca_out = loca_zero_shot_counting(image)
loca_out = loca_out["count"]
cgd_out = countgd_counting(image)

final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
print(final_out)
```
"""
Expand Down
3 changes: 0 additions & 3 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,6 @@ def generate_segmentor(self, question: str) -> Callable:

return lambda x: T.grounding_sam(params["prompt"], x)

def generate_zero_shot_counter(self, question: str) -> Callable:
    """Return the `T.loca_zero_shot_counting` tool; `question` is not used."""
    counter: Callable = T.loca_zero_shot_counting
    return counter

def generate_image_qa_tool(self, question: str) -> Callable:
    """Return a one-argument callable that forwards `question` and its argument to `T.git_vqa_v2`."""
    qa_tool: Callable = lambda x: T.git_vqa_v2(question, x)
    return qa_tool

Expand Down
3 changes: 3 additions & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@
load_image,
loca_visual_prompt_counting,
loca_zero_shot_counting,
countgd_counting,
countgd_example_based_counting,
ocr,
overlay_bounding_boxes,
overlay_heat_map,
overlay_segmentation_masks,
overlay_counting_results,
owl_v2,
save_image,
save_json,
Expand Down
146 changes: 95 additions & 51 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import inspect
import logging
import os
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

import pandas as pd
Expand All @@ -13,6 +13,7 @@
from vision_agent.utils.exceptions import RemoteToolCallFailed
from vision_agent.utils.execute import Error, MimeType
from vision_agent.utils.type_defs import LandingaiAPIKey
from vision_agent.tools.tools_types import BoundingBoxes

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
Expand All @@ -34,61 +35,58 @@ def send_inference_request(
files: Optional[List[Tuple[Any, ...]]] = None,
v2: bool = False,
metadata_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
) -> Any:
# TODO: runtime_tag and function_name should be in metadata_payload and not included
# in the service payload
try:
dillonalaird marked this conversation as resolved.
Show resolved Hide resolved
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag

url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
if "TOOL_ENDPOINT_URL" in os.environ:
url = os.environ["TOOL_ENDPOINT_URL"]

headers = {"apikey": _LND_API_KEY}
if "TOOL_ENDPOINT_AUTH" in os.environ:
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
headers.pop("apikey")

session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)

url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
if "TOOL_ENDPOINT_URL" in os.environ:
url = os.environ["TOOL_ENDPOINT_URL"]
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]

tool_call_trace = ToolCallTrace(
endpoint_url=url,
request=payload,
response={},
error=None,
)
headers = {"apikey": _LND_API_KEY}
if "TOOL_ENDPOINT_AUTH" in os.environ:
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
headers.pop("apikey")

session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)
response = _call_post(url, payload, session, files, function_name)

if files is not None:
res = session.post(url, data=payload, files=files)
else:
res = session.post(url, json=payload)
if res.status_code != 200:
tool_call_trace.error = Error(
name="RemoteToolCallFailed",
value=f"{res.status_code} - {res.text}",
traceback_raw=[],
)
_LOGGER.error(f"Request failed: {res.status_code} {res.text}")
# TODO: function_name should be in metadata_payload
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]
raise RemoteToolCallFailed(function_name, res.status_code, res.text)

resp = res.json()
tool_call_trace.response = resp
# TODO: consider making the response schema the same between below two sources
return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore
finally:
trace = tool_call_trace.model_dump()
trace["type"] = "tool_call"
display({MimeType.APPLICATION_JSON: trace}, raw=True)
# TODO: consider making the response schema the same between below two sources
return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]


def send_task_inference_request(
    payload: Dict[str, Any],
    task_name: str,
    files: Optional[List[Tuple[Any, ...]]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Any:
    """Send a task request to the v2 API and return the response's ``"data"`` field.

    Args:
        payload: Request body forwarded to ``_call_post``.
        task_name: Path segment appended to ``_LND_API_URL_v2`` to form the URL.
        files: Optional files forwarded to ``_call_post``.
        metadata: Optional mapping; when it contains ``"function_name"`` that
            value is passed to ``_call_post`` as the function name, otherwise
            ``"unknown"`` is used.

    Returns:
        The value stored under ``"data"`` in the object returned by
        ``_call_post``.
    """
    task_url = f"{_LND_API_URL_v2}/{task_name}"
    http_session = _create_requests_session(
        url=task_url,
        num_retry=3,
        headers={"apikey": _LND_API_KEY},
    )

    if metadata is not None and "function_name" in metadata:
        caller_name = metadata["function_name"]
    else:
        caller_name = "unknown"

    body = _call_post(task_url, payload, http_session, files, caller_name)
    return body["data"]


def _create_requests_session(
Expand Down Expand Up @@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"

return data


def _call_post(
    url: str,
    payload: Dict[str, Any],
    session: Session,
    files: Optional[List[Tuple[Any, ...]]] = None,
    function_name: str = "unknown",
) -> Any:
    """POST ``payload`` to ``url`` and return the decoded JSON response.

    A ``ToolCallTrace`` describing the call is always emitted through
    ``display`` when the function exits, whether it returns or raises.

    Args:
        url: Endpoint to post to.
        payload: Request body. Sent as form data when ``files`` is given,
            otherwise as JSON.
        session: Session used to issue the request.
        files: Optional files to upload alongside ``payload``.
        function_name: Name reported in ``RemoteToolCallFailed`` on failure.

    Returns:
        The result of ``response.json()``.

    Raises:
        RemoteToolCallFailed: If the response status code is not 200.
    """
    # Build the trace before entering ``try`` so the ``finally`` clause can
    # never reference an unbound name (which would mask the real error).
    tool_call_trace = ToolCallTrace(
        endpoint_url=url,
        request=payload,
        response={},
        error=None,
    )
    try:
        if files is not None:
            response = session.post(url, data=payload, files=files)
        else:
            response = session.post(url, json=payload)

        if response.status_code != 200:
            # Record the failure on the trace before raising so that the
            # emitted trace carries the error details.
            tool_call_trace.error = Error(
                name="RemoteToolCallFailed",
                value=f"{response.status_code} - {response.text}",
                traceback_raw=[],
            )
            _LOGGER.error(f"Request failed: {response.status_code} {response.text}")
            raise RemoteToolCallFailed(
                function_name, response.status_code, response.text
            )

        result = response.json()
        tool_call_trace.response = result
        return result
    finally:
        # Always publish the trace, tagged so consumers can identify it.
        trace = tool_call_trace.model_dump()
        trace["type"] = "tool_call"
        display({MimeType.APPLICATION_JSON: trace}, raw=True)


def filter_bboxes_by_threshold(
    bboxes: BoundingBoxes, threshold: float
) -> BoundingBoxes:
    """Return the bounding boxes whose score is at least ``threshold``.

    Args:
        bboxes: Boxes to filter; each item must expose a ``score`` attribute.
        threshold: Minimum score (inclusive) a box needs to be kept.

    Returns:
        A new list with the qualifying boxes in their original order.
    """
    return [bbox for bbox in bboxes if bbox.score >= threshold]
Loading
Loading