Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Product Changes & Bug Fixes #262

Merged
merged 3 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_owl_v2_image():
)
assert 24 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_owl_v2_fine_tune_id():
Expand All @@ -80,6 +81,7 @@ def test_owl_v2_fine_tune_id():
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_owl_v2_video():
Expand All @@ -93,6 +95,7 @@ def test_owl_v2_video():

assert len(result) == 10
assert 24 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_florence2_phrase_grounding():
Expand All @@ -101,8 +104,10 @@ def test_florence2_phrase_grounding():
image=img,
prompt="coin",
)

assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_florence2_phrase_grounding_fine_tune_id():
Expand All @@ -115,6 +120,7 @@ def test_florence2_phrase_grounding_fine_tune_id():
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_florence2_phrase_grounding_video():
Expand All @@ -127,6 +133,7 @@ def test_florence2_phrase_grounding_video():
)
assert len(result) == 10
assert 2 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_florence2_phrase_grounding_video_fine_tune_id():
Expand All @@ -141,6 +148,7 @@ def test_florence2_phrase_grounding_video_fine_tune_id():
)
assert len(result) == 10
assert 16 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_template_match():
Expand Down Expand Up @@ -395,6 +403,7 @@ def test_countgd_counting() -> None:
img = ski.data.coins()
result = countgd_counting(image=img, prompt="coin")
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_countgd_example_based_counting() -> None:
Expand All @@ -404,3 +413,4 @@ def test_countgd_example_based_counting() -> None:
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
55 changes: 51 additions & 4 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,32 @@ def execute_user_code_action(
return user_result, user_obs


def add_step_descriptions(response: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of *response* with a human-readable action description
    prepended to its "response" text.

    Only applies when the response text contains an ``<execute_python>`` tag:
    everything before the tag is replaced by short descriptions of the known
    meta-tool calls found in the text (e.g. "Generating vision code."), or the
    generic "Executing code." when none match. The input dict is never mutated.

    Parameters:
        response (Dict[str, str]): Conversation-agent response, possibly
            containing a "response" key with an ``<execute_python>`` block.

    Returns:
        Dict[str, str]: A deep copy with the "response" text rewritten, or an
            unmodified copy when no rewrite applies.
    """
    result = copy.deepcopy(response)
    if "response" not in result:
        return result

    text = result["response"]
    if "<execute_python>" not in text:
        return result

    # only include descriptions for these, the rest will just have executing
    # code
    known_actions = (
        ("open_code_artifact", "Reading file."),
        ("create_code_artifact", "Creating file."),
        ("edit_code_artifact", "Editing file."),
        ("generate_vision_code", "Generating vision code."),
        ("edit_vision_code", "Editing vision code."),
    )
    prefix = "".join(desc + " " for tool, desc in known_actions if tool in text)
    if not prefix:
        prefix = "Executing code."

    result["response"] = prefix + text[text.find("<execute_python>") :]
    return result


class VisionAgent(Agent):
"""Vision Agent is an agent that can chat with the user and call tools or other
agents to generate code for it. Vision Agent uses python code to execute actions
Expand Down Expand Up @@ -335,8 +361,18 @@ def chat_with_code(
response = run_conversation(self.agent, int_chat)
if self.verbosity >= 1:
_LOGGER.info(response)
int_chat.append({"role": "assistant", "content": str(response)})
orig_chat.append({"role": "assistant", "content": str(response)})
int_chat.append(
{
"role": "assistant",
"content": str(add_step_descriptions(response)),
}
)
orig_chat.append(
{
"role": "assistant",
"content": str(add_step_descriptions(response)),
}
)

# sometimes it gets stuck in a loop, so we force it to exit
if last_response == response:
Expand Down Expand Up @@ -382,6 +418,16 @@ def chat_with_code(

obs_chat_elt: Message = {"role": "observation", "content": obs}
if media_obs and result.success:
# for view_media_artifact, we need to ensure the media is loaded
# locally so the conversation agent can actually see it
code_interpreter.download_file(
str(remote_artifacts_path.name),
str(self.local_artifacts_path),
)
artifacts.load(
self.local_artifacts_path,
Path(self.local_artifacts_path).parent,
)
obs_chat_elt["media"] = [
Path(self.local_artifacts_path).parent / media_ob
for media_ob in media_obs
Expand All @@ -407,8 +453,9 @@ def chat_with_code(
code_interpreter.download_file(
str(remote_artifacts_path.name), str(self.local_artifacts_path)
)
artifacts.load(self.local_artifacts_path)
artifacts.save()
artifacts.load(
self.local_artifacts_path, Path(self.local_artifacts_path).parent
)
return orig_chat, artifacts

def streaming_message(self, message: Dict[str, Any]) -> None:
Expand Down
19 changes: 13 additions & 6 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,26 @@ def __init__(self, remote_save_path: Union[str, Path]) -> None:

self.code_sandbox_runtime = None

def load(self, file_path: Union[str, Path]) -> None:
"""Loads are artifacts into the remote environment. If an artifact value is None
it will skip loading it.
def load(
self,
artifacts_path: Union[str, Path],
load_to: Optional[Union[str, Path]] = None,
) -> None:
"""Loads are artifacts into the load_to path. If load_to is None, it will load
into remote_save_path. If an artifact value is None it will skip loading it.

Parameters:
file_path (Union[str, Path]): The file path to load the artifacts from
artifacts_path (Union[str, Path]): The file path to load the artifacts from
"""
with open(file_path, "rb") as f:
with open(artifacts_path, "rb") as f:
self.artifacts = pkl.load(f)

load_to = self.remote_save_path.parent if load_to is None else Path(load_to)

for k, v in self.artifacts.items():
if v is not None:
mode = "w" if isinstance(v, str) else "wb"
with open(self.remote_save_path.parent / k, mode) as f:
with open(load_to / k, mode) as f:
f.write(v)

def show(self, uploaded_file_path: Optional[Union[str, Path]] = None) -> str:
Expand Down
6 changes: 4 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ def countgd_counting(
{'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
]
"""
image_size = image.shape[:2]
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
prompt = prompt.replace(", ", " .")
Expand All @@ -712,7 +713,7 @@ def countgd_counting(
bboxes_formatted = [
ODResponseData(
label=bbox["label"],
bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])),
bbox=normalize_bbox(bbox["bounding_box"], image_size),
score=round(bbox["score"], 2),
)
for bbox in bboxes_per_frame
Expand Down Expand Up @@ -757,6 +758,7 @@ def countgd_example_based_counting(
{'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58]},
]
"""
image_size = image.shape[:2]
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
visual_prompts = [
Expand All @@ -771,7 +773,7 @@ def countgd_example_based_counting(
bboxes_formatted = [
ODResponseData(
label=bbox["label"],
bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])),
bbox=normalize_bbox(bbox["bounding_box"], image_size),
score=round(bbox["score"], 2),
)
for bbox in bboxes_per_frame
Expand Down
Loading