Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor tool fixes for Florence OD and Overlay Seg masks for tracking #212

Merged
merged 5 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
detr_segmentation,
dpt_hybrid_midas,
florence2_image_caption,
florence2_object_detection,
florence2_phrase_grounding,
florence2_ocr,
florence2_roberta_vqa,
florence2_sam2_image,
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_owl():

def test_object_detection():
img = ski.data.coins()
result = florence2_object_detection(
result = florence2_phrase_grounding(
image=img,
prompt="coin",
)
Expand Down
60 changes: 39 additions & 21 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,29 +740,14 @@ def chat_with_workflow(
results = {"code": "", "test": "", "plan": []}
plan = []
success = False
self.log_progress(
{
"type": "log",
"log_content": "Creating plans",
"status": "started",
}
)
plans = write_plans(
int_chat,
T.get_tool_descriptions_by_names(
customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore
),
format_memory(working_memory),
self.planner,

plans = self._create_plans(
int_chat, customized_tool_names, working_memory, self.planner
)

if self.verbosity >= 1:
for p in plans:
# tabulate will fail if the keys are not the same for all elements
p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
_LOGGER.info(
f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)
if test_multi_plan:
self._log_plans(plans, self.verbosity)

tool_infos = retrieve_tools(
plans,
self.tool_recommender,
Expand Down Expand Up @@ -856,6 +841,39 @@ def log_progress(self, data: Dict[str, Any]) -> None:
if self.report_progress_callback is not None:
self.report_progress_callback(data)

def _create_plans(
    self,
    int_chat: List[Message],
    customized_tool_names: Optional[List[str]],
    working_memory: List[Dict[str, str]],
    planner: LMM,
) -> Dict[str, Any]:
    """Announce that planning has started and ask the planner for plans.

    A "Creating plans" progress event with status "started" is sent through
    `self.log_progress` before any planning work begins.

    Parameters:
        int_chat (List[Message]): The chat messages forwarded to `write_plans`.
        customized_tool_names (Optional[List[str]]): Tool names used to select
            the tool descriptions handed to the planner.
        working_memory (List[Dict[str, str]]): Memory entries, rendered with
            `format_memory` before being passed on.
        planner (LMM): The model passed to `write_plans`.

    Returns:
        Dict[str, Any]: The value returned by `write_plans`.
    """
    started_event = {
        "type": "log",
        "log_content": "Creating plans",
        "status": "started",
    }
    self.log_progress(started_event)

    # Resolve the planner inputs in the same order the call consumes them.
    tool_descriptions = T.get_tool_descriptions_by_names(
        customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
    )
    memory_text = format_memory(working_memory)
    return write_plans(int_chat, tool_descriptions, memory_text, planner)

def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None:
    """Log the instructions of every plan as a table.

    Nothing is logged when `verbosity` is below 1.

    Parameters:
        plans (Dict[str, Any]): Mapping whose values each provide an
            "instructions" entry that can be iterated.
        verbosity (int): Logging level; tables are emitted only for values
            of 1 or higher.
    """
    if verbosity < 1:
        return

    for plan_key in plans:
        # tabulate will fail if the keys are not the same for all elements,
        # so every instruction is wrapped in a dict with one identical key.
        rows = [{"instructions": step} for step in plans[plan_key]["instructions"]]
        table = tabulate(
            tabular_data=rows,
            headers="keys",
            tablefmt="mixed_grid",
            maxcolwidths=_MAX_TABULATE_COL_WIDTH,
        )
        _LOGGER.info(f"\n{table}")


class AzureVisionAgentCoder(VisionAgentCoder):
"""VisionAgentCoder that uses Azure OpenAI APIs for planning, coding, testing.
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
dpt_hybrid_midas,
extract_frames,
florence2_image_caption,
florence2_object_detection,
florence2_phrase_grounding,
florence2_ocr,
florence2_roberta_vqa,
florence2_sam2_image,
Expand Down
28 changes: 17 additions & 11 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,10 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> s
return answer[task] # type: ignore


def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
"""'florencev2_object_detection' is a tool that can detect and count multiple
objects given a text prompt such as category names or referring expressions. You
can optionally separate the categories in the text with commas. It returns a list
def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
"""'florence2_phrase_grounding' is a tool that can detect multiple
objects given a text prompt which can be object names or a caption. You
can optionally separate the object names in the text with commas. It returns a list
of bounding boxes with normalized coordinates, label names and associated
probability scores of 1.0.

Expand All @@ -780,7 +780,7 @@ def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str,

Example
-------
>>> florence2_object_detection('person looking at a coyote', image)
>>> florence2_phrase_grounding('person looking at a coyote', image)
[
{'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
{'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]},
Expand All @@ -792,7 +792,7 @@ def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str,
"image": image_b64,
"task": "<CAPTION_TO_PHRASE_GROUNDING>",
"prompt": prompt,
"function_name": "florence2_object_detection",
"function_name": "florence2_phrase_grounding",
}

detections = send_inference_request(data, "florence2", v2=True)
Expand Down Expand Up @@ -1418,6 +1418,7 @@ def overlay_segmentation_masks(
medias: Union[np.ndarray, List[np.ndarray]],
masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
draw_label: bool = True,
secondary_label_key: str = "tracking_label",
) -> Union[np.ndarray, List[np.ndarray]]:
"""'overlay_segmentation_masks' is a utility function that displays segmentation
masks.
Expand All @@ -1426,7 +1427,10 @@ def overlay_segmentation_masks(
medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
the masks on.
masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
dictionaries containing the masks.
dictionaries containing the masks, labels and scores.
draw_label (bool, optional): If True, the labels will be displayed on the image.
secondary_label_key (str, optional): The key to use for the secondary
tracking label which is needed in videos to display tracking information.

Returns:
np.ndarray: The image with the masks displayed.
Expand Down Expand Up @@ -1471,23 +1475,25 @@ def overlay_segmentation_masks(
for elt in masks_int[i]:
mask = elt["mask"]
label = elt["label"]
tracking_lbl = elt.get(secondary_label_key, None)
np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)

if draw_label:
draw = ImageDraw.Draw(pil_image)
text_box = draw.textbbox((0, 0), text=label, font=font)
text = tracking_lbl if tracking_lbl else label
text_box = draw.textbbox((0, 0), text=text, font=font)
x, y = _get_text_coords_from_mask(
mask,
v_gap=(text_box[3] - text_box[1]) + 10,
h_gap=(text_box[2] - text_box[0]) // 2,
)
if x != 0 and y != 0:
text_box = draw.textbbox((x, y), text=label, font=font)
text_box = draw.textbbox((x, y), text=text, font=font)
draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
draw.text((x, y), label, fill="black", font=font)
draw.text((x, y), text, fill="black", font=font)
frame_out.append(np.array(pil_image))
return frame_out[0] if len(frame_out) == 1 else frame_out

Expand Down Expand Up @@ -1663,7 +1669,7 @@ def florencev2_fine_tuned_object_detection(
florence2_ocr,
florence2_sam2_image,
florence2_sam2_video,
florence2_object_detection,
florence2_phrase_grounding,
ixc25_image_vqa,
ixc25_video_vqa,
detr_segmentation,
Expand Down
Loading