diff --git a/examples/mask_app/app.py b/examples/mask_app/app.py new file mode 100644 index 00000000..23a5fc78 --- /dev/null +++ b/examples/mask_app/app.py @@ -0,0 +1,34 @@ +import cv2 +import streamlit as st +from PIL import Image +from streamlit_drawable_canvas import st_canvas + +st.title("Image Segmentation Mask App") + +uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"]) +if uploaded_file is not None: + image = Image.open(uploaded_file) + orig_size = image.size + +stroke_width = st.sidebar.slider("Stroke width: ", 1, 50, 25) +stroke_color = st.sidebar.color_picker("Stroke color hex: ") + +canvas_result = st_canvas( + fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity + stroke_width=stroke_width, + stroke_color=stroke_color, + background_color="#eee", + background_image=Image.open(uploaded_file) if uploaded_file else None, + update_streamlit=True, + height=500, + drawing_mode="freedraw", + key="canvas", +) + +if canvas_result.image_data is not None: + mask = canvas_result.image_data.astype("uint8")[..., 3] + mask[mask > 0] = 255 + if st.button("Save Mask Image") and orig_size: + mask = cv2.resize(mask, orig_size, interpolation=cv2.INTER_NEAREST) + cv2.imwrite("mask.png", mask) + st.success("Mask Image saved successfully.") diff --git a/examples/mask_app/requirements.txt b/examples/mask_app/requirements.txt new file mode 100644 index 00000000..3ce2aea0 --- /dev/null +++ b/examples/mask_app/requirements.txt @@ -0,0 +1,2 @@ +streamlit +streamlit-drawable-canvas diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index a3f09b82..44c3aa08 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -365,6 +365,7 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]] "grounding_sam_", "grounding_dino_", "extract_frames_", + "dinov_", ]: continue @@ -469,11 +470,18 @@ def chat_with_workflow( self, chat: List[Dict[str, str]], 
image: Optional[Union[str, Path]] = None, + reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, ) -> Tuple[str, List[Dict]]: question = chat[0]["content"] if image: question += f" Image name: {image}" + if reference_data: + if not ("image" in reference_data and "mask" in reference_data): + raise ValueError( + f"Reference data must contain 'image' and 'mask', but got {reference_data}" + ) + question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}" reflections = "" final_answer = "" diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index d0164e11..f36a2033 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -103,7 +103,9 @@ def overlay_bboxes( elif isinstance(image, np.ndarray): image = Image.fromarray(image) - color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])} + color = { + label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"])) + } width, height = image.size fontsize = max(12, int(min(width, height) / 40)) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index aa81d16c..38bb08d4 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -6,6 +6,7 @@ BboxIoU, BoxDistance, Crop, + DINOv, ExtractFrames, GroundingDINO, GroundingSAM, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index ef40480c..40728e62 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -372,6 +372,104 @@ def __call__( return ret_pred + +class DINOv(Tool): + r"""DINOv is a tool that can detect and segment similar objects with the given input masks. 
+ + + Example + ------- + >>> import vision_agent as va + >>> t = va.tools.DINOv() + >>> t(prompt=[{"mask":"balloon_mask.jpg", "image": "balloon.jpg"}], image="balloon.jpg") + [{'scores': [0.512, 0.212], + 'masks': [array([[0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), + array([[0, 0, 0, ..., 0, 0, 0], + ..., + [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}] + """ + + name = "dinov_" + description = "'dinov_' is a tool that can detect and segment similar objects given a reference segmentation mask." + usage = { + "required_parameters": [ + {"name": "prompt", "type": "List[Dict[str, str]]"}, + {"name": "image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you find all the balloons in this image that is similar to the provided masked area? Image name: input.jpg Reference image: balloon.jpg Reference mask: balloon_mask.jpg", + "parameters": { + "prompt": [ + {"mask": "balloon_mask.jpg", "image": "balloon.jpg"}, + ], + "image": "input.jpg", + }, + }, + { + "scenario": "Detect all the objects in this image that are similar to the provided mask. Image name: original.jpg Reference image: background.png Reference mask: mask.png", + "parameters": { + "prompt": [ + {"mask": "mask.png", "image": "background.png"}, + ], + "image": "original.jpg", + }, + }, + ], + } + + def __call__( + self, prompt: List[Dict[str, str]], image: Union[str, ImageType] + ) -> Dict: + """Invoke the DINOv model. + + Parameters: + prompt: a list of visual prompts in the form of {'mask': 'MASK_FILE_PATH', 'image': 'IMAGE_FILE_PATH'}. + image: the input image to segment. + + Returns: + A dictionary of the below keys: 'scores', 'masks' and 'mask_shape', which stores a list of detected segmentation masks and their scores. 
+ """ + image_b64 = convert_to_b64(image) + for p in prompt: + p["mask"] = convert_to_b64(p["mask"]) + p["image"] = convert_to_b64(p["image"]) + request_data = { + "prompt": prompt, + "image": image_b64, + "tool": "dinov", + } + data: Dict[str, Any] = _send_inference_request(request_data, "dinov") + if "bboxes" in data: + data["bboxes"] = [ + normalize_bbox(box, data["mask_shape"]) for box in data["bboxes"] + ] + if "masks" in data: + data["masks"] = [ + rle_decode(mask_rle=mask, shape=data["mask_shape"]) + for mask in data["masks"] + ] + data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))] + return data + + +class AgentDINOv(DINOv): + def __call__( + self, + prompt: List[Dict[str, str]], + image: Union[str, ImageType], + ) -> Dict: + rets = super().__call__(prompt, image) + mask_files = [] + for mask in rets["masks"]: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + file_name = Path(tmp.name).with_suffix(".mask.png") + Image.fromarray(mask * 255).save(file_name) + mask_files.append(str(file_name)) + rets["masks"] = mask_files + return rets + + class AgentGroundingSAM(GroundingSAM): r"""AgentGroundingSAM is the same as GroundingSAM but it saves the masks as files returns the file name. This makes it easier for agents to use. @@ -775,6 +873,7 @@ def __call__(self, equation: str) -> float: AgentGroundingSAM, ZeroShotCounting, VisualPromptCounting, + AgentDINOv, ExtractFrames, Crop, BboxArea,