From a3894d66bdad1f73535bc9111c8b016df7a29fdb Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 9 Oct 2024 00:20:37 -0700 Subject: [PATCH] Product Changes & Bug Fixes (#262) * added more test cases for correct format, fixed normalize bboxes for countgd * added more functionality to artifacts * added premade responses to execute code --- tests/integ/test_tools.py | 10 ++++++ vision_agent/agent/vision_agent.py | 55 +++++++++++++++++++++++++++--- vision_agent/tools/meta_tools.py | 19 +++++++---- vision_agent/tools/tools.py | 6 ++-- 4 files changed, 78 insertions(+), 12 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index fffd379d..4f5c674f 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -68,6 +68,7 @@ def test_owl_v2_image(): ) assert 24 <= len(result) <= 26 assert [res["label"] for res in result] == ["coin"] * len(result) + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) def test_owl_v2_fine_tune_id(): @@ -80,6 +81,7 @@ def test_owl_v2_fine_tune_id(): # this calls a fine-tuned florence2 model which is going to be worse at this task assert 14 <= len(result) <= 26 assert [res["label"] for res in result] == ["coin"] * len(result) + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) def test_owl_v2_video(): @@ -93,6 +95,7 @@ def test_owl_v2_video(): assert len(result) == 10 assert 24 <= len([res["label"] for res in result[0]]) <= 26 + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]]) def test_florence2_phrase_grounding(): @@ -101,8 +104,10 @@ def test_florence2_phrase_grounding(): image=img, prompt="coin", ) + assert len(result) == 25 assert [res["label"] for res in result] == ["coin"] * 25 + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) def test_florence2_phrase_grounding_fine_tune_id(): @@ -115,6 +120,7 @@ def test_florence2_phrase_grounding_fine_tune_id(): # this calls a fine-tuned florence2 model which is going to be worse at this task assert 14 <= len(result) <= 26 assert [res["label"] for res in result] == ["coin"] * len(result) + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) def test_florence2_phrase_grounding_video(): @@ -127,6 +133,7 @@ def test_florence2_phrase_grounding_video(): ) assert len(result) == 10 assert 2 <= len([res["label"] for res in result[0]]) <= 26 + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]]) def test_florence2_phrase_grounding_video_fine_tune_id(): @@ -141,6 +148,7 @@ def test_florence2_phrase_grounding_video_fine_tune_id(): ) assert len(result) == 10 assert 16 <= len([res["label"] for res in result[0]]) <= 26 + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]]) def test_template_match(): @@ -395,6 +403,7 @@ def test_countgd_counting() -> None: img = ski.data.coins() result = countgd_counting(image=img, prompt="coin") assert len(result) == 24 + assert [res["label"] for res in result] == ["coin"] * 24 def test_countgd_example_based_counting() -> None: @@ -404,3 +413,4 @@ def test_countgd_example_based_counting() -> None: image=img, ) assert len(result) == 24 + assert [res["label"] for res in result] == ["coin"] * 24 diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 3fd38df1..ba6e1d64 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -149,6 +149,32 @@ def execute_user_code_action( return user_result, user_obs +def add_step_descriptions(response: Dict[str, str]) -> Dict[str, str]: + response = copy.deepcopy(response) + if "response" in response: + resp_str = response["response"] + if "" in resp_str: + # only include descriptions for these, the rest will just have executing + # code + description_map = { + "open_code_artifact": "Reading file.", + "create_code_artifact": "Creating file.", + "edit_code_artifact": "Editing file.", + "generate_vision_code": "Generating vision code.", + "edit_vision_code": "Editing vision code.", + } + description = "" + for k, v in description_map.items(): + if k in resp_str: + description += v + " " + if description == "": + description = "Executing code." + resp_str = resp_str[resp_str.find("") :] + resp_str = description + resp_str + response["response"] = resp_str + return response + + class VisionAgent(Agent): """Vision Agent is an agent that can chat with the user and call tools or other agents to generate code for it. Vision Agent uses python code to execute actions @@ -335,8 +361,18 @@ def chat_with_code( response = run_conversation(self.agent, int_chat) if self.verbosity >= 1: _LOGGER.info(response) - int_chat.append({"role": "assistant", "content": str(response)}) - orig_chat.append({"role": "assistant", "content": str(response)}) + int_chat.append( + { + "role": "assistant", + "content": str(add_step_descriptions(response)), + } + ) + orig_chat.append( + { + "role": "assistant", + "content": str(add_step_descriptions(response)), + } + ) # sometimes it gets stuck in a loop, so we force it to exit if last_response == response: @@ -382,6 +418,16 @@ def chat_with_code( obs_chat_elt: Message = {"role": "observation", "content": obs} if media_obs and result.success: + # for view_media_artifact, we need to ensure the media is loaded + # locally so the conversation agent can actually see it + code_interpreter.download_file( + str(remote_artifacts_path.name), + str(self.local_artifacts_path), + ) + artifacts.load( + self.local_artifacts_path, + Path(self.local_artifacts_path).parent, + ) obs_chat_elt["media"] = [ Path(self.local_artifacts_path).parent / media_ob for media_ob in media_obs @@ -407,8 +453,9 @@ def chat_with_code( code_interpreter.download_file( str(remote_artifacts_path.name), str(self.local_artifacts_path) ) - artifacts.load(self.local_artifacts_path) - artifacts.save() + artifacts.load( + self.local_artifacts_path, Path(self.local_artifacts_path).parent + ) return orig_chat, artifacts def streaming_message(self, message: Dict[str, Any]) -> None: diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index dc910300..78a2ecae 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -92,19 +92,26 @@ def __init__(self, remote_save_path: Union[str, Path]) -> None: self.code_sandbox_runtime = None - def load(self, file_path: Union[str, Path]) -> None: - """Loads are artifacts into the remote environment. If an artifact value is None - it will skip loading it. + def load( + self, + artifacts_path: Union[str, Path], + load_to: Optional[Union[str, Path]] = None, + ) -> None: + """Loads are artifacts into the load_to path. If load_to is None, it will load + into remote_save_path. If an artifact value is None it will skip loading it. Parameters: - file_path (Union[str, Path]): The file path to load the artifacts from + artifacts_path (Union[str, Path]): The file path to load the artifacts from """ - with open(file_path, "rb") as f: + with open(artifacts_path, "rb") as f: self.artifacts = pkl.load(f) + + load_to = self.remote_save_path.parent if load_to is None else Path(load_to) + for k, v in self.artifacts.items(): if v is not None: mode = "w" if isinstance(v, str) else "wb" - with open(self.remote_save_path.parent / k, mode) as f: + with open(load_to / k, mode) as f: f.write(v) def show(self, uploaded_file_path: Optional[Union[str, Path]] = None) -> str: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 71646c45..bf4da892 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -700,6 +700,7 @@ def countgd_counting( {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58}, ] """ + image_size = image.shape[:2] buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] prompt = prompt.replace(", ", " .") @@ -712,7 +713,7 @@ def countgd_counting( bboxes_formatted = [ ODResponseData( label=bbox["label"], - bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])), + bbox=normalize_bbox(bbox["bounding_box"], image_size), score=round(bbox["score"], 2), ) for bbox in bboxes_per_frame @@ -757,6 +758,7 @@ def countgd_example_based_counting( {'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58}, ] """ + image_size = image.shape[:2] buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] visual_prompts = [ @@ -771,7 +773,7 @@ def countgd_example_based_counting( bboxes_formatted = [ ODResponseData( label=bbox["label"], - bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])), + bbox=normalize_bbox(bbox["bounding_box"], image_size), score=round(bbox["score"], 2), ) for bbox in bboxes_per_frame