Integrating new API endpoints #193

Merged: 5 commits, Aug 10, 2024
7 changes: 4 additions & 3 deletions tests/integ/test_tools.py
@@ -61,9 +61,10 @@ def test_object_detection():
     img = ski.data.coins()
     result = florencev2_object_detection(
         image=img,
+        prompt="coin",
     )
-    assert len(result) == 24
-    assert [res["label"] for res in result] == ["coin"] * 24
+    assert len(result) == 25
+    assert [res["label"] for res in result] == ["coin"] * 25


 def test_template_match():
@@ -118,7 +119,7 @@ def test_nsfw_classification():
     result = vit_nsfw_classification(
         image=img,
     )
-    assert result["labels"] == "normal"
+    assert result["label"] == "normal"


 def test_image_caption() -> None:
7 changes: 4 additions & 3 deletions vision_agent/tools/tool_utils.py
@@ -16,7 +16,8 @@

 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = LandingaiAPIKey().api_key
-_LND_API_URL = "https://api.landing.ai/v1/agent"
+_LND_API_URL = "https://api.landing.ai/v1/agent/model"
+_LND_API_URL_v2 = "https://api.landing.ai/v1/tools"


 class ToolCallTrace(BaseModel):
@@ -27,13 +28,13 @@ class ToolCallTrace(BaseModel):


 def send_inference_request(
-    payload: Dict[str, Any], endpoint_name: str
+    payload: Dict[str, Any], endpoint_name: str, v2: bool = False
 ) -> Dict[str, Any]:
     try:
         if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
             payload["runtime_tag"] = runtime_tag

-        url = f"{_LND_API_URL}/model/{endpoint_name}"
+        url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
         if "TOOL_ENDPOINT_URL" in os.environ:
             url = os.environ["TOOL_ENDPOINT_URL"]
108 changes: 51 additions & 57 deletions vision_agent/tools/tools.py
@@ -126,7 +126,6 @@ def owl_v2(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.10,
-    iou_threshold: float = 0.10,
 ) -> List[Dict[str, Any]]:
     """'owl_v2' is a tool that can detect and count multiple objects given a text
     prompt such as category names or referring expressions. The categories in text prompt
@@ -138,8 +137,6 @@ def owl_v2(
         image (np.ndarray): The image to ground the prompt to.
         box_threshold (float, optional): The threshold for the box detection. Defaults
             to 0.10.
-        iou_threshold (float, optional): The threshold for the Intersection over Union
-            (IoU). Defaults to 0.10.

     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -159,22 +156,22 @@
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(image)
     request_data = {
-        "prompt": prompt,
+        "prompts": prompt.split("."),
         "image": image_b64,
-        "tool": "open_vocab_detection",
-        "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
+        "confidence": box_threshold,
         "function_name": "owl_v2",
     }
-    data: Dict[str, Any] = send_inference_request(request_data, "tools")
+    data: Dict[str, Any] = send_inference_request(request_data, "owlv2", v2=True)
     return_data = []
-    for i in range(len(data["bboxes"])):
-        return_data.append(
-            {
-                "score": round(data["scores"][i], 2),
-                "label": data["labels"][i].strip(),
-                "bbox": normalize_bbox(data["bboxes"][i], image_size),
-            }
-        )
+    if data is not None:
+        for elt in data:
+            return_data.append(
+                {
+                    "bbox": normalize_bbox(elt["bbox"], image_size),  # type: ignore
+                    "label": elt["label"],  # type: ignore
+                    "score": round(elt["score"], 2),  # type: ignore
+                }
+            )
     return return_data
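
Note: the owl_v2 payload changes from a single prompt string plus a kwargs dict to a prompts list and a top-level confidence. A sketch with hypothetical values; only the fields shown in the diff are used.

prompt = "person.car.dog"
request_data = {
    "prompts": prompt.split("."),  # ["person", "car", "dog"]
    "image": "<image_b64>",        # base64 image string, elided here
    "confidence": 0.10,            # replaces kwargs["box_threshold"]
    "function_name": "owl_v2",
}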


@@ -367,11 +364,10 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "zero_shot_counting",
         "function_name": "loca_zero_shot_counting",
     }
-    resp_data = send_inference_request(data, "tools")
-    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data


@@ -397,17 +393,15 @@ def loca_visual_prompt_counting(

image_size = get_image_size(image)
bbox = visual_prompt["bbox"]
bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
image_b64 = convert_to_b64(image)

data = {
"image": image_b64,
"prompt": bbox_str,
"tool": "few_shot_counting",
"bbox": list(map(int, denormalize_bbox(bbox, image_size))),
"function_name": "loca_visual_prompt_counting",
}
resp_data = send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
resp_data = send_inference_request(data, "loca", v2=True)
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
return resp_data
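
Note: both loca_* functions now read the heat map straight out of the JSON response instead of decoding a base64 image with b64_to_pil. A sketch, assuming the v2 loca endpoint returns heat_map as a nested list of intensities (the response shape is an assumption, not shown in the diff):

import numpy as np

resp_data = {"count": 2, "heat_map": [[[0, 128], [255, 64]]]}  # hypothetical response
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
assert resp_data["heat_map"].dtype == np.uint8  # ready for overlay_heat_map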


@@ -432,13 +426,12 @@ def florencev2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "prompt": prompt,
-        "tool": "image_question_answering_with_context",
+        "question": prompt,
         "function_name": "florencev2_roberta_vqa",
     }

-    answer = send_inference_request(data, "tools")
-    return answer["text"][0]  # type: ignore
+    answer = send_inference_request(data, "florence2-qa", v2=True)
+    return answer  # type: ignore


 def git_vqa_v2(prompt: str, image: np.ndarray) -> str:
@@ -544,17 +537,16 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
     Example
     -------
     >>> vit_nsfw_classification(image)
-    {"labels": "normal", "scores": 0.68},
+    {"label": "normal", "scores": 0.68},
     """

     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "nsfw_image_classification",
         "function_name": "vit_nsfw_classification",
     }
-    resp_data = send_inference_request(data, "tools")
-    resp_data["scores"] = round(resp_data["scores"], 4)
+    resp_data = send_inference_request(data, "nsfw-classification", v2=True)
+    resp_data["score"] = round(resp_data["score"], 4)
     return resp_data


@@ -603,21 +595,21 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) ->
         'This image contains a cat sitting on a table with a bowl of milk.'
     """
     image_b64 = convert_to_b64(image)
+    task = "<MORE_DETAILED_CAPTION>" if detail_caption else "<DETAILED_CAPTION>"
     data = {
         "image": image_b64,
-        "tool": "florence2_image_captioning",
-        "detail_caption": detail_caption,
+        "task": task,
         "function_name": "florencev2_image_caption",
     }

-    answer = send_inference_request(data, "tools")
-    return answer["text"][0]  # type: ignore
+    answer = send_inference_request(data, "florence2", v2=True)
+    return answer[task]  # type: ignore


-def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]:
-    """'florencev2_object_detection' is a tool that can detect common objects in an
-    image without any text prompt or thresholding. It returns a list of detected objects
-    as labels and their location as bounding boxes.
+def florencev2_object_detection(image: np.ndarray, prompt: str) -> List[Dict[str, Any]]:
+    """'florencev2_object_detection' is a tool that can detect objects given a text
+    prompt such as a phrase or class names separated by commas. It returns a list of
+    detected objects as labels and their location as bounding boxes with score of 1.0.

     Parameters:
         image (np.ndarray): The image to used to detect objects
@@ -631,29 +623,30 @@ def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]:

     Example
     -------
-    >>> florencev2_object_detection(image)
+    >>> florencev2_object_detection(image, 'person looking at a coyote')
     [
-        {'score': 1.0, 'label': 'window', 'bbox': [0.1, 0.11, 0.35, 0.4]},
-        {'score': 1.0, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]},
-        {'score': 1.0, 'label': 'person', 'bbox': [0.34, 0.21, 0.85, 0.5]},
+        {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]},
     ]
     """
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "object_detection",
+        "task": "<CAPTION_TO_PHRASE_GROUNDING>",
+        "prompt": prompt,
         "function_name": "florencev2_object_detection",
     }

-    answer = send_inference_request(data, "tools")
+    detections = send_inference_request(data, "florence2", v2=True)
+    detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
     return_data = []
-    for i in range(len(answer["bboxes"])):
+    for i in range(len(detections["bboxes"])):
         return_data.append(
             {
-                "score": round(answer["scores"][i], 2),
-                "label": answer["labels"][i],
-                "bbox": normalize_bbox(answer["bboxes"][i], image_size),
+                "score": 1.0,
+                "label": detections["labels"][i],
+                "bbox": normalize_bbox(detections["bboxes"][i], image_size),
             }
         )
     return return_data
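
Note: both Florence-2 calls above now use the model's task-token protocol: the request carries a task token, and the response appears to be keyed by that same token, which is why florencev2_image_caption returns answer[task] and the detector indexes detections["<CAPTION_TO_PHRASE_GROUNDING>"]. A minimal sketch with hypothetical response payloads:

caption_task = "<MORE_DETAILED_CAPTION>"
caption_resp = {caption_task: "A cat sitting on a table with a bowl of milk."}
caption = caption_resp[caption_task]  # what florencev2_image_caption returns

grounding_task = "<CAPTION_TO_PHRASE_GROUNDING>"
grounding_resp = {grounding_task: {"bboxes": [[10, 20, 50, 80]], "labels": ["cat"]}}
detections = grounding_resp[grounding_task]  # then iterated over bboxes/labels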
@@ -742,13 +735,16 @@ def depth_anything_v2(image: np.ndarray) -> np.ndarray:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "generate_depth",
         "function_name": "depth_anything_v2",
     }

-    answer = send_inference_request(data, "tools")
-    return_data = np.array(b64_to_pil(answer["masks"][0]).convert("L"))
-    return return_data
+    depth_map = send_inference_request(data, "depth-anything-v2", v2=True)
+    depth_map_np = np.array(depth_map["map"])
+    depth_map_np = (depth_map_np - depth_map_np.min()) / (
+        depth_map_np.max() - depth_map_np.min()
+    )
+    depth_map_np = (255 * depth_map_np).astype(np.uint8)
+    return depth_map_np


def generate_soft_edge_image(image: np.ndarray) -> np.ndarray:
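
Note: depth_anything_v2 no longer decodes a base64 mask; the raw map is min-max rescaled to [0, 1] and then mapped onto the 0-255 uint8 range. A worked example with illustrative values:

import numpy as np

depth = np.array([[0.2, 0.5], [0.8, 1.1]])  # hypothetical raw depth values
depth = (depth - depth.min()) / (depth.max() - depth.min())  # rescale to [0, 1]
depth = (255 * depth).astype(np.uint8)  # [[0, 85], [170, 255]]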
@@ -839,12 +835,11 @@ def generate_pose_image(image: np.ndarray) -> np.ndarray:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "generate_pose",
         "function_name": "generate_pose_image",
     }

-    answer = send_inference_request(data, "tools")
-    return_data = np.array(b64_to_pil(answer["masks"][0]).convert("RGB"))
+    pos_img = send_inference_request(data, "pose-detector", v2=True)
+    return_data = np.array(b64_to_pil(pos_img["data"]).convert("RGB"))
     return return_data


@@ -1254,7 +1249,6 @@ def overlay_heat_map(
     loca_visual_prompt_counting,
     florencev2_roberta_vqa,
     florencev2_image_caption,
-    florencev2_object_detection,
     detr_segmentation,
     depth_anything_v2,
     generate_soft_edge_image,
2 changes: 1 addition & 1 deletion vision_agent/utils/type_defs.py
@@ -14,7 +14,7 @@ class LandingaiAPIKey(BaseSettings):
"""

api_key: str = Field(
default="land_sk_fnmSzD0ksknSfvhyD8UGu9R4ss3bKfLL1Im5gb6tDQTy2z1Oy5",
default="land_sk_zKvyPcPV2bVoq7q87KwduoerAxuQpx33DnqP8M1BliOCiZOSoI",
alias="LANDINGAI_API_KEY",
description="The API key of LandingAI.",
)