Skip to content

Commit

Permalink
feat: add count tool (#216)
Browse files Browse the repository at this point in the history
* Adding countgd as default counting tool

* fix mypy errors

* added viz for counting tool

* adjust call

* fix bbox coords outside the image, countgd return types

* fix return values from countgd endpoint

* correct output format for cgd

* adapt countgd + dev integration test

* add model

* linter

* linter

* fixed keys in the example string, add support for multi-class

---------

Co-authored-by: Dayanne Fernandes <[email protected]>
  • Loading branch information
shankar-vision-eng and Dayof authored Sep 4, 2024
1 parent e699553 commit 103832d
Show file tree
Hide file tree
Showing 10 changed files with 328 additions and 70 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
name: CI

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

env:
LANDINGAI_DEV_API_KEY: ${{ secrets.LANDINGAI_DEV_API_KEY }}

jobs:
unit_test:
name: Test
Expand Down Expand Up @@ -79,6 +83,9 @@ jobs:
- name: Test with pytest
run: |
poetry run pytest -v tests/integ
- name: Test with pytest, dev env
run: |
LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev
release:
name: Release
Expand Down
Empty file.
21 changes: 21 additions & 0 deletions tests/integration_dev/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import skimage as ski

from vision_agent.tools import (
countgd_counting,
countgd_example_based_counting,
)


def test_countgd_counting() -> None:
    """Check that text-prompted counting finds 24 items in the coins sample image."""
    coins_image = ski.data.coins()
    detections = countgd_counting(image=coins_image, prompt="coin")
    expected_count = 24
    assert len(detections) == expected_count


def test_countgd_example_based_counting() -> None:
    """Check that counting from one example box finds 24 items in the coins image."""
    coins_image = ski.data.coins()
    # A single visual prompt given as four integers.
    # NOTE(review): presumably [x_min, y_min, x_max, y_max] in pixels — confirm
    # against the tool's documentation.
    example_boxes = [[85, 106, 122, 145]]
    detections = countgd_example_based_counting(
        visual_prompts=example_boxes,
        image=coins_image,
    )
    expected_count = 24
    assert len(detections) == expected_count
9 changes: 4 additions & 5 deletions vision_agent/agent/vision_agent_coder_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,19 @@
- Count the number of detected objects labeled as 'person'.
plan3:
- Load the image from the provided file path 'image.jpg'.
- Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
```python
from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
image = load_image("image.jpg")
owl_v2_out = owl_v2("person", image)
gsam_out = grounding_sam("person", image)
gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
loca_out = loca_zero_shot_counting(image)
loca_out = loca_out["count"]
cgd_out = countgd_counting(image)
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
print(final_out)
```
"""
Expand Down
3 changes: 0 additions & 3 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,6 @@ def generate_segmentor(self, question: str) -> Callable:

return lambda x: T.grounding_sam(params["prompt"], x)

def generate_zero_shot_counter(self, question: str) -> Callable:
return T.loca_zero_shot_counting

def generate_image_qa_tool(self, question: str) -> Callable:
return lambda x: T.git_vqa_v2(question, x)

Expand Down
3 changes: 3 additions & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@
load_image,
loca_visual_prompt_counting,
loca_zero_shot_counting,
countgd_counting,
countgd_example_based_counting,
ocr,
overlay_bounding_boxes,
overlay_heat_map,
overlay_segmentation_masks,
overlay_counting_results,
owl_v2,
save_image,
save_json,
Expand Down
146 changes: 95 additions & 51 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import inspect
import logging
import os
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

import pandas as pd
Expand All @@ -13,6 +13,7 @@
from vision_agent.utils.exceptions import RemoteToolCallFailed
from vision_agent.utils.execute import Error, MimeType
from vision_agent.utils.type_defs import LandingaiAPIKey
from vision_agent.tools.tools_types import BoundingBoxes

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
Expand All @@ -34,61 +35,58 @@ def send_inference_request(
files: Optional[List[Tuple[Any, ...]]] = None,
v2: bool = False,
metadata_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
) -> Any:
# TODO: runtime_tag and function_name should be in metadata_payload and not included
# in the service payload
try:
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag

url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
if "TOOL_ENDPOINT_URL" in os.environ:
url = os.environ["TOOL_ENDPOINT_URL"]

headers = {"apikey": _LND_API_KEY}
if "TOOL_ENDPOINT_AUTH" in os.environ:
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
headers.pop("apikey")

session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)

url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
if "TOOL_ENDPOINT_URL" in os.environ:
url = os.environ["TOOL_ENDPOINT_URL"]
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]

tool_call_trace = ToolCallTrace(
endpoint_url=url,
request=payload,
response={},
error=None,
)
headers = {"apikey": _LND_API_KEY}
if "TOOL_ENDPOINT_AUTH" in os.environ:
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
headers.pop("apikey")

session = _create_requests_session(
url=url,
num_retry=3,
headers=headers,
)
response = _call_post(url, payload, session, files, function_name)

if files is not None:
res = session.post(url, data=payload, files=files)
else:
res = session.post(url, json=payload)
if res.status_code != 200:
tool_call_trace.error = Error(
name="RemoteToolCallFailed",
value=f"{res.status_code} - {res.text}",
traceback_raw=[],
)
_LOGGER.error(f"Request failed: {res.status_code} {res.text}")
# TODO: function_name should be in metadata_payload
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]
raise RemoteToolCallFailed(function_name, res.status_code, res.text)

resp = res.json()
tool_call_trace.response = resp
# TODO: consider making the response schema the same between below two sources
return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore
finally:
trace = tool_call_trace.model_dump()
trace["type"] = "tool_call"
display({MimeType.APPLICATION_JSON: trace}, raw=True)
# TODO: consider making the response schema the same between below two sources
return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]


def send_task_inference_request(
    payload: Dict[str, Any],
    task_name: str,
    files: Optional[List[Tuple[Any, ...]]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Any:
    """POST a task request to the v2 API and return the response's "data" entry.

    Args:
        payload: Request body forwarded unchanged to `_call_post`.
        task_name: Path segment appended to `_LND_API_URL_v2` to build the URL.
        files: Optional files forwarded unchanged to `_call_post`.
        metadata: Optional mapping; when it contains "function_name", that value
            is passed to `_call_post`, otherwise "unknown" is used.

    Returns:
        The value stored under "data" in the result returned by `_call_post`.
    """
    endpoint = f"{_LND_API_URL_v2}/{task_name}"
    http_session = _create_requests_session(
        url=endpoint,
        num_retry=3,
        headers={"apikey": _LND_API_KEY},
    )

    has_name = metadata is not None and "function_name" in metadata
    caller_name = metadata["function_name"] if has_name else "unknown"

    result = _call_post(endpoint, payload, http_session, files, caller_name)
    return result["data"]


def _create_requests_session(
Expand Down Expand Up @@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"

return data


def _call_post(
    url: str,
    payload: Dict[str, Any],
    session: Session,
    files: Optional[List[Tuple[Any, ...]]] = None,
    function_name: str = "unknown",
) -> Any:
    """POST `payload` to `url`, display a tool-call trace, and return the JSON body.

    Args:
        url: Endpoint the request is sent to.
        payload: Request body; sent as JSON, or as form data when `files` is given.
        session: Session object used to issue the POST.
        files: Optional files to send alongside `payload`.
        function_name: Name reported in `RemoteToolCallFailed` on failure.

    Returns:
        The decoded JSON response.

    Raises:
        RemoteToolCallFailed: If the response status code is not 200.
    """
    # Build the trace before entering `try`: the `finally` clause reads it, so
    # constructing it inside `try` would turn a construction failure into an
    # unrelated UnboundLocalError that hides the real exception.
    tool_call_trace = ToolCallTrace(
        endpoint_url=url,
        request=payload,
        response={},
        error=None,
    )
    try:
        if files is not None:
            response = session.post(url, data=payload, files=files)
        else:
            response = session.post(url, json=payload)

        if response.status_code != 200:
            tool_call_trace.error = Error(
                name="RemoteToolCallFailed",
                value=f"{response.status_code} - {response.text}",
                traceback_raw=[],
            )
            _LOGGER.error(
                "Request failed: %s %s", response.status_code, response.text
            )
            raise RemoteToolCallFailed(
                function_name, response.status_code, response.text
            )

        result = response.json()
        tool_call_trace.response = result
        return result
    finally:
        # Always display the trace, whether the call succeeded or raised.
        trace = tool_call_trace.model_dump()
        trace["type"] = "tool_call"
        display({MimeType.APPLICATION_JSON: trace}, raw=True)


def filter_bboxes_by_threshold(
    bboxes: BoundingBoxes, threshold: float
) -> BoundingBoxes:
    """Return, in order, the boxes whose `score` is at least `threshold`."""
    return [box for box in bboxes if box.score >= threshold]
Loading

0 comments on commit 103832d

Please sign in to comment.