Integrating new API endpoints #193

Merged: 5 commits, Aug 10, 2024
Changes from 1 commit
7 changes: 4 additions & 3 deletions vision_agent/tools/tool_utils.py
@@ -16,7 +16,8 @@

 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = LandingaiAPIKey().api_key
-_LND_API_URL = "https://api.landing.ai/v1/agent"
+_LND_API_URL = "https://api.landing.ai/v1/agent/model"
+_LND_API_URL_v2 = "https://api.landing.ai/v1/tools"
 
 
 class ToolCallTrace(BaseModel):
@@ -27,13 +28,13 @@ class ToolCallTrace(BaseModel):


 def send_inference_request(
-    payload: Dict[str, Any], endpoint_name: str
+    payload: Dict[str, Any], endpoint_name: str, v2: bool = False
 ) -> Dict[str, Any]:
     try:
         if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
             payload["runtime_tag"] = runtime_tag
 
-        url = f"{_LND_API_URL}/model/{endpoint_name}"
+        url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
         if "TOOL_ENDPOINT_URL" in os.environ:
             url = os.environ["TOOL_ENDPOINT_URL"]
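To make the new routing concrete, here is a minimal sketch assuming only the two module constants above; `build_url` is a hypothetical helper for illustration, since the real code inlines this expression in `send_inference_request`:

```python
# Sketch of the endpoint routing introduced in this commit (illustrative only).
_LND_API_URL = "https://api.landing.ai/v1/agent/model"
_LND_API_URL_v2 = "https://api.landing.ai/v1/tools"


def build_url(endpoint_name: str, v2: bool = False) -> str:
    # v2=False keeps the legacy agent/model route; v2=True targets the new tools API.
    return f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"


assert build_url("some-model") == "https://api.landing.ai/v1/agent/model/some-model"
assert build_url("owlv2", v2=True) == "https://api.landing.ai/v1/tools/owlv2"
```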
47 changes: 21 additions & 26 deletions vision_agent/tools/tools.py
@@ -126,7 +126,6 @@ def owl_v2(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.10,
-    iou_threshold: float = 0.10,
 ) -> List[Dict[str, Any]]:
     """'owl_v2' is a tool that can detect and count multiple objects given a text
     prompt such as category names or referring expressions. The categories in text prompt
@@ -138,8 +137,6 @@
         image (np.ndarray): The image to ground the prompt to.
         box_threshold (float, optional): The threshold for the box detection. Defaults
             to 0.10.
-        iou_threshold (float, optional): The threshold for the Intersection over Union
-            (IoU). Defaults to 0.10.
 
     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -159,22 +156,22 @@
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(image)
     request_data = {
-        "prompt": prompt,
+        "prompts": prompt.split("."),
         "image": image_b64,
-        "tool": "open_vocab_detection",
-        "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
+        "confidence": box_threshold,
         "function_name": "owl_v2",
     }
-    data: Dict[str, Any] = send_inference_request(request_data, "tools")
+    data: Dict[str, Any] = send_inference_request(request_data, "owlv2", v2=True)
     return_data = []
-    for i in range(len(data["bboxes"])):
-        return_data.append(
-            {
-                "score": round(data["scores"][i], 2),
-                "label": data["labels"][i].strip(),
-                "bbox": normalize_bbox(data["bboxes"][i], image_size),
-            }
-        )
+    if data is not None:
+        for elt in data:
+            return_data.append(
+                {
+                    "bbox": normalize_bbox(elt["bbox"], image_size),  # type: ignore
+                    "label": elt["label"],  # type: ignore
+                    "score": round(elt["score"], 2),  # type: ignore
+                }
+            )
     return return_data
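For reviewers tracing the behavior change, a hedged sketch of the new request and response shapes (all values illustrative, not part of the diff): the prompt string is now split on periods into a list of per-class prompts, and the box threshold moves from `kwargs` to a top-level `confidence` field.

```python
# Illustrative payload for the new v2 "owlv2" endpoint.
prompt = "red car. license plate"
request_data = {
    "prompts": prompt.split("."),  # -> ["red car", " license plate"] (no strip)
    "image": "<base64-encoded image>",  # placeholder for convert_to_b64(image)
    "confidence": 0.10,
    "function_name": "owl_v2",
}
# The handling code above implies the v2 response is a list of detections,
# e.g. [{"bbox": [x1, y1, x2, y2], "label": "red car", "score": 0.83}, ...],
# replacing the old parallel "bboxes"/"labels"/"scores" arrays.
```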


@@ -367,11 +364,10 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "zero_shot_counting",
         "function_name": "loca_zero_shot_counting",
     }
-    resp_data = send_inference_request(data, "tools")
-    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
    return resp_data
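One change worth calling out: the old code decoded `heat_map` from a base64-encoded image via `b64_to_pil`, while the new code treats it as raw pixel values. A small sketch under that assumption (the response dict is illustrative):

```python
import numpy as np

# Assumed v2-style response: "heat_map" holds a nested list of pixel
# values rather than a base64-encoded PNG string.
resp_data = {"count": 3, "heat_map": [[[0, 12], [255, 128]]]}
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
print(resp_data["heat_map"].shape, resp_data["heat_map"].dtype)  # (2, 2) uint8
```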


@@ -397,17 +393,15 @@ def loca_visual_prompt_counting(

     image_size = get_image_size(image)
     bbox = visual_prompt["bbox"]
-    bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
     image_b64 = convert_to_b64(image)
 
     data = {
         "image": image_b64,
-        "prompt": bbox_str,
-        "tool": "few_shot_counting",
+        "bbox": list(map(int, denormalize_bbox(bbox, image_size))),
         "function_name": "loca_visual_prompt_counting",
     }
-    resp_data = send_inference_request(data, "tools")
-    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data
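Similarly, the visual prompt moves from a comma-joined string to a list of integer pixel coordinates. A self-contained sketch of that conversion; the `denormalize_bbox` below is a hypothetical stand-in for the helper in `vision_agent`, assuming `image_size` is `(h, w)` and the bbox is a normalized `[x1, y1, x2, y2]`:

```python
# Hypothetical stand-in for denormalize_bbox (assumptions stated above).
def denormalize_bbox(bbox, image_size):
    h, w = image_size
    x1, y1, x2, y2 = bbox
    return [x1 * w, y1 * h, x2 * w, y2 * h]


image_size = (480, 640)
bbox = [0.1, 0.1, 0.5, 0.5]
print(list(map(int, denormalize_bbox(bbox, image_size))))  # [64, 48, 320, 240]
```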


@@ -603,14 +597,15 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) ->
         'This image contains a cat sitting on a table with a bowl of milk.'
     """
     image_b64 = convert_to_b64(image)
+    task = "<MORE_DETAILED_CAPTION>" if detail_caption else "<DETAILED_CAPTION>"
     data = {
         "image": image_b64,
-        "tool": "florence2_image_captioning",
-        "detail_caption": detail_caption,
+        "task": task,
         "function_name": "florencev2_image_caption",
     }

-    answer = send_inference_request(data, "tools")
+    __import__("ipdb").set_trace()
dillonalaird marked this conversation as resolved.
+    answer = send_inference_request(data, "florence2", v2=True)
     return answer["text"][0]  # type: ignore
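The captioning change swaps a boolean flag for Florence-2's task-token prompting. A hedged sketch of the payload this produces (the task tokens are taken from the diff above; everything else is illustrative):

```python
# Map the old detail_caption flag onto Florence-2 task tokens, as in the diff.
def caption_task(detail_caption: bool) -> str:
    return "<MORE_DETAILED_CAPTION>" if detail_caption else "<DETAILED_CAPTION>"


payload = {
    "image": "<base64-encoded image>",
    "task": caption_task(True),
    "function_name": "florencev2_image_caption",
}
# Per the return statement above, the v2 "florence2" endpoint is assumed to
# respond with {"text": ["<generated caption>"]}.
```

The `__import__("ipdb").set_trace()` call in the diff is a leftover debug breakpoint; the resolved review thread suggests it was addressed in a later commit.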


6 changes: 6 additions & 0 deletions vision_agent/utils/type_defs.py
@@ -18,6 +18,12 @@ class LandingaiAPIKey(BaseSettings):
         alias="LANDINGAI_API_KEY",
         description="The API key of LandingAI.",
     )
+
dillonalaird marked this conversation as resolved.
+    api_key_v2: str = Field(
+        default="land_sk_fnmSzD0ksknSfvhyD8UGu9R4ss3bKfLL1Im5gb6tDQTy2z1Oy5",
+        alias="LANDINGAI_API_KEY",
+        description="The API key of LandingAI.",
+    )
 
     @field_validator("api_key")
     @classmethod
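Since both fields alias the same environment variable, a brief sketch of the intended resolution may help (pydantic-settings behavior assumed; the demo class and values are hypothetical):

```python
# Hypothetical demo of the alias pattern above: both fields read the same
# LANDINGAI_API_KEY environment variable and differ only in their defaults.
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class DemoKeys(BaseSettings):
    api_key: str = Field(default="v1-default", alias="LANDINGAI_API_KEY")
    api_key_v2: str = Field(default="v2-default", alias="LANDINGAI_API_KEY")


os.environ["LANDINGAI_API_KEY"] = "land_sk_user_supplied"
keys = DemoKeys()
print(keys.api_key, keys.api_key_v2)  # both resolve to the user-supplied key
```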