Skip to content

Commit

Permalink
merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
MingruiZhang committed Aug 26, 2024
2 parents 3385c7b + 7e149d1 commit 8476e5e
Show file tree
Hide file tree
Showing 18 changed files with 798 additions and 253 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "vision-agent"
version = "0.2.110"
version = "0.2.112"
description = "Toolset for Vision Agent"
authors = ["Landing AI <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -78,6 +78,8 @@ line_length = 88
profile = "black"

[tool.mypy]
plugins = "pydantic.mypy"

exclude = "tests"
show_error_context = true
pretty = true
Expand Down
67 changes: 58 additions & 9 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import skimage as ski
from PIL import Image

from vision_agent.tools import (
blip_image_caption,
Expand All @@ -8,15 +9,19 @@
depth_anything_v2,
detr_segmentation,
dpt_hybrid_midas,
florencev2_image_caption,
florencev2_object_detection,
florencev2_roberta_vqa,
florencev2_ocr,
florence2_image_caption,
florence2_object_detection,
florence2_ocr,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video,
generate_pose_image,
generate_soft_edge_image,
git_vqa_v2,
grounding_dino,
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
Expand Down Expand Up @@ -60,7 +65,7 @@ def test_owl():

def test_object_detection():
img = ski.data.coins()
result = florencev2_object_detection(
result = florence2_object_detection(
image=img,
prompt="coin",
)
Expand Down Expand Up @@ -88,6 +93,30 @@ def test_grounding_sam():
assert len([res["mask"] for res in result]) == 24


def test_florence2_sam2_image():
img = ski.data.coins()
result = florence2_sam2_image(
prompt="coin",
image=img,
)
assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25
assert len([res["mask"] for res in result]) == 25


def test_florence2_sam2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = florence2_sam2_video(
prompt="coin",
frames=frames,
)
assert len(result) == 10
assert len([res["label"] for res in result[0]]) == 25
assert len([res["mask"] for res in result[0]]) == 25


def test_segmentation():
img = ski.data.coins()
result = detr_segmentation(
Expand Down Expand Up @@ -133,7 +162,7 @@ def test_image_caption() -> None:

def test_florence_image_caption() -> None:
img = ski.data.rocket()
result = florencev2_image_caption(
result = florence2_image_caption(
image=img,
)
assert "The image shows a rocket on a launch pad at night" in result.strip()
Expand Down Expand Up @@ -168,13 +197,33 @@ def test_git_vqa_v2() -> None:

def test_image_qa_with_context() -> None:
img = ski.data.rocket()
result = florencev2_roberta_vqa(
result = florence2_roberta_vqa(
prompt="Is the scene captured during day or night ?",
image=img,
)
assert "night" in result.strip()


def test_ixc25_image_vqa() -> None:
img = ski.data.cat()
result = ixc25_image_vqa(
prompt="What animal is in this image?",
image=img,
)
assert "cat" in result.strip()


def test_ixc25_video_vqa() -> None:
frames = [
np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10)
]
result = ixc25_video_vqa(
prompt="What animal is in this video?",
frames=frames,
)
assert "cat" in result.strip()


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
Expand All @@ -183,9 +232,9 @@ def test_ocr() -> None:
assert any("Region-based segmentation" in res["label"] for res in result)


def test_florencev2_ocr() -> None:
def test_florence2_ocr() -> None:
img = ski.data.page()
result = florencev2_ocr(
result = florence2_ocr(
image=img,
)
assert any("Region-based segmentation" in res["label"] for res in result)
Expand Down
11 changes: 3 additions & 8 deletions vision_agent/agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,22 @@
from typing import Any, Dict

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)


def extract_json(json_str: str) -> Dict[str, Any]:
try:
json_str = json_str.replace("\n", " ")
json_dict = json.loads(json_str)
except json.JSONDecodeError:
input_json_str = json_str
if "```json" in json_str:
json_str = json_str[json_str.find("```json") + len("```json") :]
json_str = json_str[: json_str.find("```")]
elif "```" in json_str:
json_str = json_str[json_str.find("```") + len("```") :]
# get the last ``` not one from an intermediate string
json_str = json_str[: json_str.find("}```")]
try:
json_dict = json.loads(json_str)
except json.JSONDecodeError as e:
error_msg = f"Could not extract JSON from the given str: {json_str}.\nFunction input:\n{input_json_str}"
_LOGGER.exception(error_msg)
raise ValueError(error_msg) from e

json_dict = json.loads(json_str)
return json_dict # type: ignore


Expand Down
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DefaultImports:
code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions, florencev2_fine_tuning",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions",
]

@staticmethod
Expand Down
48 changes: 28 additions & 20 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import sys
import tempfile
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast

Expand Down Expand Up @@ -86,8 +87,8 @@ def format_memory(memory: List[Dict[str, str]]) -> str:
def format_plans(plans: Dict[str, Any]) -> str:
plan_str = ""
for k, v in plans.items():
plan_str += f"{k}:\n"
plan_str += "-" + "\n-".join([e["instructions"] for e in v])
plan_str += "\n" + f"{k}: {v['thoughts']}\n"
plan_str += " -" + "\n -".join([e for e in v["instructions"]])

return plan_str

Expand Down Expand Up @@ -232,13 +233,11 @@ def pick_plan(
"status": "completed" if tool_output.success else "failed",
}
)
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
tool_output_str = tool_output.logs.stdout[0]
tool_output_str = tool_output.text().strip()

if verbosity == 2:
_print_code("Code and test after attempted fix:", code)
_LOGGER.info(f"Code execution result after attempte {count}")
_LOGGER.info(f"Code execution result after attempt {count}")

count += 1

Expand All @@ -255,7 +254,21 @@ def pick_plan(
tool_output=tool_output_str[:20_000],
)
chat[-1]["content"] = prompt
best_plan = extract_json(model(chat, stream=False)) # type: ignore

count = 0
best_plan = None
while best_plan is None and count < max_retries:
try:
best_plan = extract_json(model(chat, stream=False)) # type: ignore
except JSONDecodeError as e:
_LOGGER.exception(
f"Error while extracting JSON during picking best plan {str(e)}"
)
pass
count += 1

if best_plan is None:
best_plan = {"best_plan": list(plans.keys())[0]}

if verbosity >= 1:
_LOGGER.info(f"Best plan:\n{best_plan}")
Expand Down Expand Up @@ -529,7 +542,7 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None:


def retrieve_tools(
plans: Dict[str, List[Dict[str, str]]],
plans: Dict[str, Dict[str, Any]],
tool_recommender: Sim,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
Expand All @@ -546,8 +559,8 @@ def retrieve_tools(
tool_lists: Dict[str, List[Dict[str, str]]] = {}
for k, plan in plans.items():
tool_lists[k] = []
for task in plan:
tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
for task in plan["instructions"]:
tools = tool_recommender.top_k(task, k=2, thresh=0.3)
tool_info.extend([e["doc"] for e in tools])
tool_desc.extend([e["desc"] for e in tools])
tool_lists[k].extend(
Expand Down Expand Up @@ -746,14 +759,7 @@ def chat_with_workflow(
if self.verbosity >= 1:
for p in plans:
# tabulate will fail if the keys are not the same for all elements
p_fixed = [
{
"instructions": (
e["instructions"] if "instructions" in e else ""
)
}
for e in plans[p]
]
p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
_LOGGER.info(
f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)
Expand Down Expand Up @@ -802,13 +808,15 @@ def chat_with_workflow(
)

if self.verbosity >= 1:
plan_i_fixed = [{"instructions": e} for e in plan_i["instructions"]]
_LOGGER.info(
f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
f"Picked best plan:\n{tabulate(tabular_data=plan_i_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

results = write_and_test_code(
chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
plan=f"\n{plan_i['thoughts']}\n-"
+ "\n-".join([e for e in plan_i["instructions"]]),
tool_info=tool_info,
tool_output=tool_output_str,
tool_utils=T.UTILITIES_DOCSTRING,
Expand Down
16 changes: 9 additions & 7 deletions vision_agent/agent/vision_agent_coder_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,19 @@
**Instructions**:
1. Based on the context and tools you have available, create a plan of subtasks to achieve the user request.
2. Output three different plans each utilize a different strategy or tool.
2. Output three different plans each utilize a different strategy or set of tools.
Output a list of jsons in the following format
```json
{{
"plan1":
[
{{
"instructions": str # what you should do in this task associated with a tool
}}
],
{{
"thoughts": str # your thought process for choosing this plan
"instructions": [
str # what you should do in this task associated with a tool
]
}},
"plan2": ...,
"plan3": ...
}}
Expand Down Expand Up @@ -127,7 +128,8 @@
**Instructions**:
1. Given the plans, image, and tool outputs, decide which plan is the best to achieve the user request.
2. Output a JSON object with the following format:
2. Try solving the problem yourself given the image and pick the plan that matches your solution the best.
3. Output a JSON object with the following format:
{{
"thoughts": str # your thought process for choosing the best plan
"best_plan": str # the best plan you have chosen
Expand Down
Loading

0 comments on commit 8476e5e

Please sign in to comment.