Skip to content

Commit

Permalink
updated docs
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 22, 2024
1 parent 37f0e75 commit d5286d1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 49 deletions.
10 changes: 9 additions & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool
from .tools import (
CLIP,
TOOLS,
Counter,
Crop,
GroundingDINO,
GroundingSAM,
Tool,
)
72 changes: 24 additions & 48 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class CLIP(Tool):
Examples::
>>> import vision_agent as va
>>> clip = va.tools.CLIP()
>>> clip(["red line", "yellow dot"], "examples/img/ct_scan1.jpg"))
>>> [[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]]
>>> clip(["red line", "yellow dot"], "ct_scan1.jpg"))
>>> [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}]
"""

_ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws"
Expand Down Expand Up @@ -115,6 +115,18 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict


class GroundingDINO(Tool):
r"""Grounding DINO is a tool that can detect arbitrary objects with inputs such as
category names or referring expressions.
Examples::
>>> import vision_agent as va
>>> t = va.tools.GroundingDINO()
>>> t("red line. yellow dot", "ct_scan1.jpg")
>>> [{'labels': ['red line', 'yellow dot'],
>>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]],
>>> 'scores': [0.98, 0.02]}]
"""

_ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws"

name = "grounding_dino_"
Expand Down Expand Up @@ -177,18 +189,21 @@ class GroundingSAM(Tool):
inputs such as category names or referring expressions.
Examples::
>>> from vision_agent.tools import tools
>>> t = tools.GroundingSAM(["red line", "yellow dot", "none"])
>>> t("examples/img/ct_scan1.jpg")
>>> [{'label': 'yellow dot', 'mask': array([[0, 0, 0, ..., 0, 0, 0],
>>> import vision_agent as va
>>> t = va.tools.GroundingSAM()
>>> t(["red line", "yellow dot"], ct_scan1.jpg"])
>>> [{'labels': ['yellow dot', 'red line'],
>>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]],
>>> 'masks': [array([[0, 0, 0, ..., 0, 0, 0],
>>> [0, 0, 0, ..., 0, 0, 0],
>>> ...,
>>> [0, 0, 0, ..., 0, 0, 0],
>>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, {'label': 'red line', 'mask': array([[0, 0, 0, ..., 0, 0, 0],
>>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)},
>>> array([[0, 0, 0, ..., 0, 0, 0],
>>> [0, 0, 0, ..., 0, 0, 0],
>>> ...,
>>> [1, 1, 1, ..., 1, 1, 1],
>>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)}]
>>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}]
"""

_ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws"
Expand Down Expand Up @@ -287,7 +302,6 @@ class Counter(Tool):

def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
resp = GroundingDINO()(prompt, image)
__import__("ipdb").set_trace()
return dict(CounterClass(resp[0]["labels"]))


Expand Down Expand Up @@ -320,51 +334,13 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
int(bbox[2] * width),
int(bbox[3] * height),
]
cropped_image = pil_image.crop(bbox)
cropped_image = pil_image.crop(bbox) # type: ignore
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
cropped_image.save(tmp.name)

return tmp.name


class ImageSearch(Tool):
name = "image_search_"
description = "'image_search_' searches for images similar to the input image."
usage = {
"required_parameters": [{"name": "image", "type": "str"}],
"examples": [
{
"scenario": "Can you find images similar to the image? Image name: image.jpg",
"parameters": {"image": "image.jpg"},
}
],
}

def __call__(self, image: Union[str, Path]) -> List[str]:
assert isinstance(image, str), "The input image must be a string url."
image = "https://popmenucloud.com/cdn-cgi/image/width%3D1920%2Cheight%3D1920%2Cfit%3Dscale-down%2Cformat%3Dauto%2Cquality%3D60/vpylarnm/a6ad1671-8938-457f-b4cd-3215caa122cb.png"
url = "https://www.googleapis.com/customsearch/v1"
api_key = os.getenv("GOOGLE_API_KEY")
search_engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
assert api_key and search_engine_id, "Please set the GOOGLE_API_KEY and GOOGLE_SEARCH_ENGINE_ID environment variable. See https://developers.google.com/custom-search/v1/using_rest for more information."
params = {
"key": api_key,
"cx": search_engine_id,
"q": image,
"num": 10,
"searchType":"image",
}
response = requests.get(url, params=params)

# Check if the request was successful
if response.status_code != 200:
raise RuntimeError(f"Failed to fetch data: {response.status_code} {response.reason}")

resp = response.json()
items = resp.get("items", [])
return [item["link"] for item in items]


class Add(Tool):
name = "add_"
description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places."
Expand Down

0 comments on commit d5286d1

Please sign in to comment.