Skip to content

Commit

Permalink
Minor tool fixes for Florence OD and Overlay Seg masks for tracking (#…
Browse files Browse the repository at this point in the history
…212)

* fixed florence OD as phrase grounding

* added code to plot tracking labels which are dynamic

* adding the multi-plan log only when user invokes multi-plan mode

* fix code complexity linting error

* fix mypy issues
  • Loading branch information
shankar-vision-eng authored Aug 29, 2024
1 parent d545395 commit a8e4b62
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 35 deletions.
4 changes: 2 additions & 2 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
detr_segmentation,
dpt_hybrid_midas,
florence2_image_caption,
florence2_object_detection,
florence2_phrase_grounding,
florence2_ocr,
florence2_roberta_vqa,
florence2_sam2_image,
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_owl():

def test_object_detection():
img = ski.data.coins()
result = florence2_object_detection(
result = florence2_phrase_grounding(
image=img,
prompt="coin",
)
Expand Down
60 changes: 39 additions & 21 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,29 +744,14 @@ def chat_with_workflow(
results = {"code": "", "test": "", "plan": []}
plan = []
success = False
self.log_progress(
{
"type": "log",
"log_content": "Creating plans",
"status": "started",
}
)
plans = write_plans(
int_chat,
T.get_tool_descriptions_by_names(
customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore
),
format_memory(working_memory),
self.planner,

plans = self._create_plans(
int_chat, customized_tool_names, working_memory, self.planner
)

if self.verbosity >= 1:
for p in plans:
# tabulate will fail if the keys are not the same for all elements
p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
_LOGGER.info(
f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)
if test_multi_plan:
self._log_plans(plans, self.verbosity)

tool_infos = retrieve_tools(
plans,
self.tool_recommender,
Expand Down Expand Up @@ -860,6 +845,39 @@ def log_progress(self, data: Dict[str, Any]) -> None:
if self.report_progress_callback is not None:
self.report_progress_callback(data)

def _create_plans(
    self,
    int_chat: List[Message],
    customized_tool_names: Optional[List[str]],
    working_memory: List[Dict[str, str]],
    planner: LMM,
) -> Dict[str, Any]:
    """Report that planning has started and ask the planner for plans.

    Parameters:
        int_chat (List[Message]): The chat messages handed to `write_plans`.
        customized_tool_names (Optional[List[str]]): Tool names used to look
            up tool descriptions via `T.get_tool_descriptions_by_names`.
        working_memory (List[Dict[str, str]]): Memory entries, passed through
            `format_memory` before being given to `write_plans`.
        planner (LMM): The model forwarded to `write_plans`.

    Returns:
        Dict[str, Any]: Whatever `write_plans` returns.
    """
    # Emit the progress event before any planning work is done.
    self.log_progress(
        {
            "type": "log",
            "log_content": "Creating plans",
            "status": "started",
        }
    )

    tool_descriptions = T.get_tool_descriptions_by_names(
        customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
    )
    formatted_memory = format_memory(working_memory)
    return write_plans(int_chat, tool_descriptions, formatted_memory, planner)

def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None:
    """Log the instructions of every plan as a table.

    Nothing is logged when `verbosity` is below 1.

    Parameters:
        plans (Dict[str, Any]): Mapping whose values each hold an
            "instructions" entry that can be iterated.
        verbosity (int): Logging is performed only when this is >= 1.
    """
    if verbosity < 1:
        return

    for plan in plans.values():
        # tabulate will fail if the keys are not the same for all elements,
        # so every row is rebuilt with the single "instructions" key.
        rows = [{"instructions": step} for step in plan["instructions"]]
        table = tabulate(
            tabular_data=rows,
            headers="keys",
            tablefmt="mixed_grid",
            maxcolwidths=_MAX_TABULATE_COL_WIDTH,
        )
        _LOGGER.info(f"\n{table}")


class OllamaVisionAgentCoder(VisionAgentCoder):
"""VisionAgentCoder that uses Ollama models for planning, coding, testing.
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
dpt_hybrid_midas,
extract_frames,
florence2_image_caption,
florence2_object_detection,
florence2_phrase_grounding,
florence2_ocr,
florence2_roberta_vqa,
florence2_sam2_image,
Expand Down
28 changes: 17 additions & 11 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,10 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> s
return answer[task] # type: ignore


def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
"""'florencev2_object_detection' is a tool that can detect and count multiple
objects given a text prompt such as category names or referring expressions. You
can optionally separate the categories in the text with commas. It returns a list
def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
"""'florence2_phrase_grounding' is a tool that can detect multiple
objects given a text prompt which can be object names or a caption. You
can optionally separate the object names in the text with commas. It returns a list
of bounding boxes with normalized coordinates, label names and associated
probability scores of 1.0.
Expand All @@ -780,7 +780,7 @@ def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str,
Example
-------
>>> florence2_object_detection('person looking at a coyote', image)
>>> florence2_phrase_grounding('person looking at a coyote', image)
[
{'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
{'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]},
Expand All @@ -792,7 +792,7 @@ def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str,
"image": image_b64,
"task": "<CAPTION_TO_PHRASE_GROUNDING>",
"prompt": prompt,
"function_name": "florence2_object_detection",
"function_name": "florence2_phrase_grounding",
}

detections = send_inference_request(data, "florence2", v2=True)
Expand Down Expand Up @@ -1418,6 +1418,7 @@ def overlay_segmentation_masks(
medias: Union[np.ndarray, List[np.ndarray]],
masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
draw_label: bool = True,
secondary_label_key: str = "tracking_label",
) -> Union[np.ndarray, List[np.ndarray]]:
"""'overlay_segmentation_masks' is a utility function that displays segmentation
masks.
Expand All @@ -1426,7 +1427,10 @@ def overlay_segmentation_masks(
medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
the masks on.
masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
dictionaries containing the masks.
dictionaries containing the masks, labels and scores.
draw_label (bool, optional): If True, the labels will be displayed on the image.
secondary_label_key (str, optional): The key to use for the secondary
tracking label which is needed in videos to display tracking information.
Returns:
np.ndarray: The image with the masks displayed.
Expand Down Expand Up @@ -1471,23 +1475,25 @@ def overlay_segmentation_masks(
for elt in masks_int[i]:
mask = elt["mask"]
label = elt["label"]
tracking_lbl = elt.get(secondary_label_key, None)
np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)

if draw_label:
draw = ImageDraw.Draw(pil_image)
text_box = draw.textbbox((0, 0), text=label, font=font)
text = tracking_lbl if tracking_lbl else label
text_box = draw.textbbox((0, 0), text=text, font=font)
x, y = _get_text_coords_from_mask(
mask,
v_gap=(text_box[3] - text_box[1]) + 10,
h_gap=(text_box[2] - text_box[0]) // 2,
)
if x != 0 and y != 0:
text_box = draw.textbbox((x, y), text=label, font=font)
text_box = draw.textbbox((x, y), text=text, font=font)
draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
draw.text((x, y), label, fill="black", font=font)
draw.text((x, y), text, fill="black", font=font)
frame_out.append(np.array(pil_image))
return frame_out[0] if len(frame_out) == 1 else frame_out

Expand Down Expand Up @@ -1663,7 +1669,7 @@ def florencev2_fine_tuned_object_detection(
florence2_ocr,
florence2_sam2_image,
florence2_sam2_video,
florence2_object_detection,
florence2_phrase_grounding,
ixc25_image_vqa,
ixc25_video_vqa,
detr_segmentation,
Expand Down

0 comments on commit a8e4b62

Please sign in to comment.