Skip to content

Commit

Permalink
added reference mask support
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 18, 2024
1 parent 34580e1 commit 6b232ed
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 6 deletions.
35 changes: 35 additions & 0 deletions examples/mask_app/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import cv2
import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

# Streamlit app: upload an image, draw freehand strokes over it, and save the
# drawn region as a binary mask ("mask.png") at the uploaded image's size.
st.title("Image Segmentation Mask App")

# Default both to None so that later code can safely test them even when the
# user has not uploaded anything yet (previously `orig_size` was undefined in
# that case and clicking "Save Mask Image" raised NameError).
image = None
orig_size = None

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    orig_size = image.size  # PIL reports size as (width, height)

stroke_width = st.sidebar.slider("Stroke width: ", 1, 50, 25)
stroke_color = st.sidebar.color_picker("Stroke color hex: ")

canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color="#eee",
    # Reuse the already-opened image instead of opening the upload a second time.
    background_image=image,
    update_streamlit=True,
    height=500,
    drawing_mode="freedraw",
    key="canvas",
)

if canvas_result.image_data is not None:
    # Take the last (index 3) channel of the canvas data as the drawn-pixel
    # indicator and binarize it: any non-zero value becomes 255.
    mask = canvas_result.image_data.astype("uint8")[..., 3]
    mask[mask > 0] = 255
    if st.button("Save Mask Image"):
        if not orig_size:
            # Without an upload there is no target size to resize the mask to.
            st.warning("Please upload an image before saving the mask.")
        else:
            # Nearest-neighbour keeps the mask strictly binary after resizing;
            # cv2.resize takes dsize as (width, height), matching PIL's size.
            mask = cv2.resize(mask, orig_size, interpolation=cv2.INTER_NEAREST)
            # cv2.imwrite signals failure through its return value, not an
            # exception, so check it before reporting success.
            if cv2.imwrite("mask.png", mask):
                st.success("Mask Image saved successfully.")
            else:
                st.error("Failed to save mask image.")
2 changes: 2 additions & 0 deletions examples/mask_app/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
streamlit
streamlit-drawable-canvas
8 changes: 8 additions & 0 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
"grounding_sam_",
"grounding_dino_",
"extract_frames_",
"dinov_",
]:
continue

Expand Down Expand Up @@ -469,11 +470,18 @@ def chat_with_workflow(
self,
chat: List[Dict[str, str]],
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
) -> Tuple[str, List[Dict]]:
question = chat[0]["content"]
if image:
question += f" Image name: {image}"
if reference_data:
if not ("image" in reference_data and "mask" in reference_data):
raise ValueError(
f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
)
question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"

reflections = ""
final_answer = ""
Expand Down
11 changes: 5 additions & 6 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,22 +398,21 @@ class DINOv(Tool):
],
"examples": [
{
"scenario": "Can you find all the balloons in this image that is similar to the provided masked area?",
"scenario": "Can you find all the balloons in this image that is similar to the provided masked area? Image name: input.jpg Reference image: balloon.jpg Reference mask: balloon_mask.jpg",
"parameters": {
"prompt": [
{"mask": "reference_balloon_mask1.jpg", "image": "balloon.jpg"},
{"mask": "reference_balloon_mask2.jpg", "image": "balloon.jpg"},
{"mask": "balloon_mask.jpg", "image": "balloon.jpg"},
],
"image": "input.jpg",
},
},
{
"scenario": "Count all the objects in this image that is similar to the provided masked area? Image name: input.jpg, Reference mask: mask.jpg, Mask image: background.jpg",
"scenario": "Detect all the objects in this image that are similar to the provided mask. Image name: original.jpg Reference image: mask.png Reference mask: background.png",
"parameters": {
"prompt": [
{"mask": "reference_obj_mask1.jpg", "image": "background.jpg"},
{"mask": "mask.png", "image": "background.png"},
],
"image": "input.jpg",
"image": "original.jpg",
},
},
],
Expand Down

0 comments on commit 6b232ed

Please sign in to comment.