
Commit

Merge branch 'main' into update-comments
dillonalaird authored Apr 22, 2024
2 parents b373a74 + 501b95d commit 0d3849a
Showing 18 changed files with 342 additions and 1,149 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -11,7 +11,7 @@

Vision Agent is a library that helps you utilize agent frameworks for your vision tasks.
Many current vision problems can easily take hours or days to solve: you need to find the
right model, figure out how to use it, and possibly write programming logic around it to
accomplish the task you want, or, even more expensively, train your own model. Vision Agent
aims to provide an in-seconds experience by allowing users to describe their problem in
text and utilizing agent frameworks to solve the task for them. Check out our discord
@@ -110,6 +110,9 @@ you. For example:
| BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
| SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
| BoxDistance | BoxDistance returns the minimum distance between two bounding boxes normalized to 2 decimal places. |
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and a visual prompt. |


It also has a basic set of calculation tools such as add, subtract, multiply, and divide.
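
To ground the description above, here is a minimal, hypothetical usage sketch; the VisionAgent class name, the chat message format, and the return value are assumptions based on vision_agent/agent/vision_agent.py as touched in this commit, not text taken from the README itself.

    # Hypothetical sketch, not from the README: the class name, chat format,
    # and return value are assumptions; only the `image=` keyword is taken
    # from the chat_with_workflow changes later in this commit.
    from vision_agent.agent import VisionAgent

    agent = VisionAgent()
    # Describe the problem in plain text and let the agent pick tools
    # (e.g. the new ZeroShotCounting tool listed above).
    answer = agent.chat_with_workflow(
        [{"role": "user", "content": "How many cars are in parking_lot.jpg?"}],
        image="parking_lot.jpg",
    )
    print(answer)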
3 changes: 0 additions & 3 deletions docs/api/data.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/emb.md

This file was deleted.

2 changes: 0 additions & 2 deletions mkdocs.yml
@@ -37,7 +37,5 @@ nav:
- vision_agent.tools: api/tools.md
- vision_agent.llm: api/llm.md
- vision_agent.lmm: api/lmm.md
- vision_agent.data: api/data.md
- vision_agent.emb: api/emb.md
- vision_agent.image_utils: api/image_utils.md
- Old documentation: old.md
793 changes: 5 additions & 788 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 2 additions & 6 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "vision-agent"
version = "0.1.4"
version = "0.2.1"
description = "Toolset for Vision Agent"
authors = ["Landing AI <[email protected]>"]
readme = "README.md"
@@ -16,16 +16,12 @@ packages = [{include = "vision_agent"}]
"documentation" = "https://github.com/landing-ai/vision-agent"

[tool.poetry.dependencies] # main dependency group
python = ">=3.9,<3.12"

python = ">=3.9"
numpy = ">=1.21.0,<2.0.0"
pillow = "10.*"
requests = "2.*"
tqdm = ">=4.64.0,<5.0.0"
pandas = "2.*"
faiss-cpu = "1.*"
torch = "2.1.*" # 2.2 causes sentence-transformers to seg fault
sentence-transformers = "2.*"
openai = "1.*"
typing_extensions = "4.*"
moviepy = "1.*"
93 changes: 0 additions & 93 deletions tests/test_data.py

This file was deleted.

2 changes: 0 additions & 2 deletions vision_agent/__init__.py
@@ -1,5 +1,3 @@
from .agent import Agent
from .data import DataStore, build_data_store
from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder
from .llm import LLM, OpenAILLM
from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm
48 changes: 33 additions & 15 deletions vision_agent/agent/vision_agent.py
@@ -8,7 +8,7 @@
from PIL import Image
from tabulate import tabulate

from vision_agent.image_utils import overlay_bboxes, overlay_masks
from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.tools import TOOLS
@@ -33,6 +33,7 @@

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80


def parse_json(s: str) -> Any:
@@ -335,7 +336,9 @@ def _handle_viz_tools(

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict) or "bboxes" not in call_result:
if not isinstance(call_result, dict) or (
"bboxes" not in call_result and "masks" not in call_result
):
return image_to_data

# if the call was successful, then we can add the image data
Expand All @@ -348,11 +351,12 @@ def _handle_viz_tools(
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])
image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
image_to_data[image]["labels"].extend(call_result.get("labels", []))
image_to_data[image]["scores"].extend(call_result.get("scores", []))
image_to_data[image]["masks"].extend(call_result.get("masks", []))
if "mask_shape" in call_result:
image_to_data[image]["mask_shape"] = call_result["mask_shape"]

return image_to_data
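
The switch from direct indexing to `.get(key, [])` above is what lets a tool that returns only masks (or only boxes) be merged without a KeyError; a small standalone illustration of that behavior, with invented sample data:

    # Standalone illustration of the merge above; the sample call_result is
    # invented, the `.get(key, [])` pattern mirrors the library code.
    call_result = {"masks": ["mask_0.png"], "labels": ["box"], "scores": [0.9]}
    image_data = {"bboxes": [], "masks": [], "labels": [], "scores": []}

    for key in ("bboxes", "labels", "scores", "masks"):
        image_data[key].extend(call_result.get(key, []))

    assert image_data["masks"] == ["mask_0.png"]
    assert image_data["bboxes"] == []  # missing keys simply contribute nothing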

@@ -366,6 +370,8 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
"grounding_dino_",
"extract_frames_",
"dinov_",
"zero_shot_counting_",
"visual_prompt_counting_",
]:
continue

@@ -378,8 +384,11 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
for image_str in image_to_data:
image_path = Path(image_str)
image_data = image_to_data[image_str]
image = overlay_masks(image_path, image_data)
image = overlay_bboxes(image, image_data)
if "_counting_" in tool_result["tool_name"]:
image = overlay_heat_map(image_path, image_data)
else:
image = overlay_masks(image_path, image_data)
image = overlay_bboxes(image, image_data)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
image.save(f.name)
visualized_images.append(f.name)
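
For reference, visualize_result writes each overlaid image to a temporary .png and returns the file paths; the sketch below shows one way a caller might consume them (opening the files with PIL is an assumption, only the returned paths come from the code above).

    # Illustrative only: how a caller might view the temporary .png files
    # returned by visualize_result. The display step is an assumption.
    from PIL import Image

    for path in visualize_result(all_tool_results):
        Image.open(path).show()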
@@ -498,11 +507,21 @@ def chat_with_workflow(
if image:
question += f" Image name: {image}"
if reference_data:
if not ("image" in reference_data and "mask" in reference_data):
if not (
"image" in reference_data
and ("mask" in reference_data or "bbox" in reference_data)
):
raise ValueError(
f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}"
)
question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
visual_prompt_data = (
f"Reference mask: {reference_data['mask']}"
if "mask" in reference_data
else f"Reference bbox: {reference_data['bbox']}"
)
question += (
f" Reference image: {reference_data['image']}, {visual_prompt_data}"
)

reflections = ""
final_answer = ""
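
With the relaxed check above, the visual prompt can now be a bounding box as well as a mask; a hedged sketch of the new reference_data contract, where only the dictionary keys come from the validation above and the agent instance, chat format, and box format are assumptions:

    # Sketch of the new reference_data contract: 'image' plus either 'mask'
    # or 'bbox' (see the validation above). The box format and chat format
    # are assumptions made for illustration.
    reference_data = {
        "image": "reference_object.jpg",
        "bbox": [0.1, 0.2, 0.5, 0.6],  # assumed normalized coordinates
    }
    agent.chat_with_workflow(
        [{"role": "user", "content": "Count objects that look like the reference"}],
        image="scene.jpg",
        reference_data=reference_data,
    )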
@@ -545,7 +564,6 @@ def chat_with_workflow(
final_answer = answer_summarize(
self.answer_model, question, answers, reflections
)

visualized_output = visualize_result(all_tool_results)
all_tool_results.append({"visualized_output": visualized_output})
if len(visualized_output) > 0:
@@ -629,7 +647,7 @@ def retrieval(

self.log_progress(
f"""Going to run the following tool(s) in sequence:
{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
{tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
@@ -675,6 +693,6 @@ def create_tasks(
task_list = []
self.log_progress(
f"""Planned tasks:
{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
{tabulate(task_list, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)
return task_list
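
Both logging calls above now pass maxcolwidths=_MAX_TABULATE_COL_WIDTH so that long tool outputs wrap at 80 characters instead of stretching the table; a small self-contained example of that parameter, with an invented row:

    # Minimal demonstration of tabulate's maxcolwidths, as used in the
    # logging calls above; the row content is invented for illustration.
    from tabulate import tabulate

    rows = [{"tool": "grounding_dino_", "output": "a very long serialized result " * 10}]
    print(tabulate(rows, headers="keys", tablefmt="mixed_grid", maxcolwidths=80))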
1 change: 0 additions & 1 deletion vision_agent/data/__init__.py

This file was deleted.

