Skip to content

Commit

Permalink
Product Changes & Bug Fixes (#262)
Browse files Browse the repository at this point in the history
* added more test cases for correct format, fixed normalize bboxes for countgd

* added more functionality to artifacts

* added premade responses to execute code
  • Loading branch information
dillonalaird authored Oct 9, 2024
1 parent 0e6d8c3 commit a3894d6
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 12 deletions.
10 changes: 10 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_owl_v2_image():
)
assert 24 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_owl_v2_fine_tune_id():
Expand All @@ -80,6 +81,7 @@ def test_owl_v2_fine_tune_id():
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_owl_v2_video():
Expand All @@ -93,6 +95,7 @@ def test_owl_v2_video():

assert len(result) == 10
assert 24 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_florence2_phrase_grounding():
Expand All @@ -101,8 +104,10 @@ def test_florence2_phrase_grounding():
image=img,
prompt="coin",
)

assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_florence2_phrase_grounding_fine_tune_id():
Expand All @@ -115,6 +120,7 @@ def test_florence2_phrase_grounding_fine_tune_id():
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_florence2_phrase_grounding_video():
Expand All @@ -127,6 +133,7 @@ def test_florence2_phrase_grounding_video():
)
assert len(result) == 10
assert 2 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_florence2_phrase_grounding_video_fine_tune_id():
Expand All @@ -141,6 +148,7 @@ def test_florence2_phrase_grounding_video_fine_tune_id():
)
assert len(result) == 10
assert 16 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_template_match():
Expand Down Expand Up @@ -395,6 +403,7 @@ def test_countgd_counting() -> None:
img = ski.data.coins()
result = countgd_counting(image=img, prompt="coin")
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_countgd_example_based_counting() -> None:
Expand All @@ -404,3 +413,4 @@ def test_countgd_example_based_counting() -> None:
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
55 changes: 51 additions & 4 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,32 @@ def execute_user_code_action(
return user_result, user_obs


def add_step_descriptions(response: Dict[str, str]) -> Dict[str, str]:
response = copy.deepcopy(response)
if "response" in response:
resp_str = response["response"]
if "<execute_python>" in resp_str:
# only include descriptions for these, the rest will just have executing
# code
description_map = {
"open_code_artifact": "Reading file.",
"create_code_artifact": "Creating file.",
"edit_code_artifact": "Editing file.",
"generate_vision_code": "Generating vision code.",
"edit_vision_code": "Editing vision code.",
}
description = ""
for k, v in description_map.items():
if k in resp_str:
description += v + " "
if description == "":
description = "Executing code."
resp_str = resp_str[resp_str.find("<execute_python>") :]
resp_str = description + resp_str
response["response"] = resp_str
return response


class VisionAgent(Agent):
"""Vision Agent is an agent that can chat with the user and call tools or other
agents to generate code for it. Vision Agent uses python code to execute actions
Expand Down Expand Up @@ -335,8 +361,18 @@ def chat_with_code(
response = run_conversation(self.agent, int_chat)
if self.verbosity >= 1:
_LOGGER.info(response)
int_chat.append({"role": "assistant", "content": str(response)})
orig_chat.append({"role": "assistant", "content": str(response)})
int_chat.append(
{
"role": "assistant",
"content": str(add_step_descriptions(response)),
}
)
orig_chat.append(
{
"role": "assistant",
"content": str(add_step_descriptions(response)),
}
)

# sometimes it gets stuck in a loop, so we force it to exit
if last_response == response:
Expand Down Expand Up @@ -382,6 +418,16 @@ def chat_with_code(

obs_chat_elt: Message = {"role": "observation", "content": obs}
if media_obs and result.success:
# for view_media_artifact, we need to ensure the media is loaded
# locally so the conversation agent can actually see it
code_interpreter.download_file(
str(remote_artifacts_path.name),
str(self.local_artifacts_path),
)
artifacts.load(
self.local_artifacts_path,
Path(self.local_artifacts_path).parent,
)
obs_chat_elt["media"] = [
Path(self.local_artifacts_path).parent / media_ob
for media_ob in media_obs
Expand All @@ -407,8 +453,9 @@ def chat_with_code(
code_interpreter.download_file(
str(remote_artifacts_path.name), str(self.local_artifacts_path)
)
artifacts.load(self.local_artifacts_path)
artifacts.save()
artifacts.load(
self.local_artifacts_path, Path(self.local_artifacts_path).parent
)
return orig_chat, artifacts

def streaming_message(self, message: Dict[str, Any]) -> None:
Expand Down
19 changes: 13 additions & 6 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,26 @@ def __init__(self, remote_save_path: Union[str, Path]) -> None:

self.code_sandbox_runtime = None

def load(self, file_path: Union[str, Path]) -> None:
"""Loads are artifacts into the remote environment. If an artifact value is None
it will skip loading it.
def load(
self,
artifacts_path: Union[str, Path],
load_to: Optional[Union[str, Path]] = None,
) -> None:
"""Loads are artifacts into the load_to path. If load_to is None, it will load
into remote_save_path. If an artifact value is None it will skip loading it.
Parameters:
file_path (Union[str, Path]): The file path to load the artifacts from
artifacts_path (Union[str, Path]): The file path to load the artifacts from
"""
with open(file_path, "rb") as f:
with open(artifacts_path, "rb") as f:
self.artifacts = pkl.load(f)

load_to = self.remote_save_path.parent if load_to is None else Path(load_to)

for k, v in self.artifacts.items():
if v is not None:
mode = "w" if isinstance(v, str) else "wb"
with open(self.remote_save_path.parent / k, mode) as f:
with open(load_to / k, mode) as f:
f.write(v)

def show(self, uploaded_file_path: Optional[Union[str, Path]] = None) -> str:
Expand Down
6 changes: 4 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ def countgd_counting(
{'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58},
]
"""
image_size = image.shape[:2]
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
prompt = prompt.replace(", ", " .")
Expand All @@ -712,7 +713,7 @@ def countgd_counting(
bboxes_formatted = [
ODResponseData(
label=bbox["label"],
bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])),
bbox=normalize_bbox(bbox["bounding_box"], image_size),
score=round(bbox["score"], 2),
)
for bbox in bboxes_per_frame
Expand Down Expand Up @@ -757,6 +758,7 @@ def countgd_example_based_counting(
{'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58},
]
"""
image_size = image.shape[:2]
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
visual_prompts = [
Expand All @@ -771,7 +773,7 @@ def countgd_example_based_counting(
bboxes_formatted = [
ODResponseData(
label=bbox["label"],
bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])),
bbox=normalize_bbox(bbox["bounding_box"], image_size),
score=round(bbox["score"], 2),
)
for bbox in bboxes_per_frame
Expand Down

0 comments on commit a3894d6

Please sign in to comment.