From 4001524deae74f8a5d3433a16ae21eaeb28358a2 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 10:40:40 -0700 Subject: [PATCH 1/9] visualized output/reflection to handle extract_frames_ --- vision_agent/agent/vision_agent.py | 88 ++++++++++++++++++------------ 1 file changed, 52 insertions(+), 36 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 2f1d58b4..b9b91919 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -278,44 +278,59 @@ def parse_reflect(reflect: str) -> bool: def visualize_result(all_tool_results: List[Dict]) -> List[str]: image_to_data: Dict[str, Dict] = {} for tool_result in all_tool_results: - if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]: + # only handle bbox/mask tools or frame extraction + if tool_result["tool_name"] not in [ + "grounding_sam_", + "grounding_dino_", + "extract_frames_", + ]: continue - parameters = tool_result["parameters"] - # parameters can either be a dictionary or list, parameters can also be malformed - # becaus the LLM builds them - if isinstance(parameters, dict): - if "image" not in parameters: - continue - parameters = [parameters] - elif isinstance(tool_result["parameters"], list): - if len(tool_result["parameters"]) < 1 or ( - "image" not in tool_result["parameters"][0] - ): - continue - - for param, call_result in zip(parameters, tool_result["call_results"]): - # calls can fail, so we need to check if the call was successful - if not isinstance(call_result, dict): - continue - if "bboxes" not in call_result: - continue - - # if the call was successful, then we can add the image data - image = param["image"] - if image not in image_to_data: - image_to_data[image] = { - "bboxes": [], - "masks": [], - "labels": [], - "scores": [], - } - - image_to_data[image]["bboxes"].extend(call_result["bboxes"]) - image_to_data[image]["labels"].extend(call_result["labels"]) - image_to_data[image]["scores"].extend(call_result["scores"]) - if "masks" in call_result: - image_to_data[image]["masks"].extend(call_result["masks"]) + if tool_result["tool_name"] == "extract_frames_": + for video_file_output in tool_result["call_results"]: + for frame, _ in video_file_output: + image = frame + if image not in image_to_data: + image_to_data[image] = { + "bboxes": [], + "masks": [], + "labels": [], + "scores": [], + } + else: # handle grounding_sam_ and grounding_dino_ + parameters = tool_result["parameters"] + # parameters can either be a dictionary or list, parameters can also be malformed + # becaus the LLM builds them + if isinstance(parameters, dict): + if "image" not in parameters: + continue + parameters = [parameters] + elif isinstance(tool_result["parameters"], list): + if len(tool_result["parameters"]) < 1 or ( + "image" not in tool_result["parameters"][0] + ): + continue + + for param, call_result in zip(parameters, tool_result["call_results"]): + # calls can fail, so we need to check if the call was successful + if not isinstance(call_result, dict) or "bboxes" not in call_result: + continue + + # if the call was successful, then we can add the image data + image = param["image"] + if image not in image_to_data: + image_to_data[image] = { + "bboxes": [], + "masks": [], + "labels": [], + "scores": [], + } + + image_to_data[image]["bboxes"].extend(call_result["bboxes"]) + image_to_data[image]["labels"].extend(call_result["labels"]) + image_to_data[image]["scores"].extend(call_result["scores"]) + if "masks" in call_result: + image_to_data[image]["masks"].extend(call_result["masks"]) visualized_images = [] for image in image_to_data: @@ -459,6 +474,7 @@ def chat_with_workflow( self.answer_model, question, answers, reflections ) + __import__("ipdb").set_trace() visualized_output = visualize_result(all_tool_results) all_tool_results.append({"visualized_output": visualized_output}) reflection = self_reflect( From dae1031f17123d4f217c154eed00d19669077bab Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 10:42:17 -0700 Subject: [PATCH 2/9] remove ipdb --- vision_agent/agent/vision_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index b9b91919..a04e65eb 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -474,7 +474,6 @@ def chat_with_workflow( self.answer_model, question, answers, reflections ) - __import__("ipdb").set_trace() visualized_output = visualize_result(all_tool_results) all_tool_results.append({"visualized_output": visualized_output}) reflection = self_reflect( From 3af591da70a51e660d696d2443df4f1e0b97be10 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 14:19:23 -0700 Subject: [PATCH 3/9] added json mode for lmm, upgraded gpt-4-turbo --- vision_agent/agent/vision_agent.py | 24 +++++++++++++++--------- vision_agent/llm/llm.py | 2 +- vision_agent/lmm/lmm.py | 12 ++++++++---- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index a04e65eb..8e02db0a 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -255,7 +255,7 @@ def self_reflect( ) -> str: prompt = VISION_AGENT_REFLECTION.format( question=question, - tools=format_tools(tools), + tools=format_tools({k: v["description"] for k, v in tools.items()}), tool_results=str(tool_result), final_answer=final_answer, ) @@ -268,11 +268,16 @@ def self_reflect( return reflect_model(prompt) -def parse_reflect(reflect: str) -> bool: - # GPT-4V has a hard time following directions, so make the criteria less strict - return ( +def parse_reflect(reflect: str) -> Dict[str, Any]: + try: + return parse_json(reflect) + except Exception: + _LOGGER.error(f"Failed parse json reflection: {reflect}") + # LMMs have a hard time following directions, so make the criteria less strict + finish = ( "finish" in reflect.lower() and len(reflect) < 100 ) or "finish" in reflect.lower()[-10:] + return {"Finish": finish, "Reflection": reflect} def visualize_result(all_tool_results: List[Dict]) -> List[str]: @@ -389,7 +394,7 @@ def __init__( OpenAILLM(temperature=0.1) if answer_model is None else answer_model ) self.reflect_model = ( - OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model + OpenAILMM(json_mode=True, temperature=0.1) if reflect_model is None else reflect_model ) self.max_retries = max_retries self.tools = TOOLS @@ -485,13 +490,14 @@ def chat_with_workflow( visualized_output[0] if len(visualized_output) > 0 else image, ) self.log_progress(f"Reflection: {reflection}") - if parse_reflect(reflection): + parsed_reflection = parse_reflect(reflection) + if parsed_reflection["Finish"]: break else: - reflections += "\n" + reflection - # '' is a symbol to indicate the end of the chat, which is useful for streaming logs. + reflections += "\n" + parsed_reflection["Reflection"] + # '' is a symbol to indicate the end of the chat, which is useful for streaming logs. self.log_progress( - f"The Vision Agent has concluded this chat. {final_answer}" + f"The Vision Agent has concluded this chat. {final_answer}" ) if visualize_output: diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py index 9022ef73..9352f58b 100644 --- a/vision_agent/llm/llm.py +++ b/vision_agent/llm/llm.py @@ -33,7 +33,7 @@ class OpenAILLM(LLM): def __init__( self, - model_name: str = "gpt-4-turbo-preview", + model_name: str = "gpt-4-turbo", api_key: Optional[str] = None, json_mode: bool = False, **kwargs: Any diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 0d63b158..738ac004 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -99,9 +99,10 @@ class OpenAILMM(LMM): def __init__( self, - model_name: str = "gpt-4-vision-preview", + model_name: str = "gpt-4-turbo", api_key: Optional[str] = None, max_tokens: int = 1024, + json_mode: bool = False, **kwargs: Any, ): if not api_key: @@ -111,7 +112,10 @@ def __init__( self.client = OpenAI(api_key=api_key) self.model_name = model_name - self.max_tokens = max_tokens + if "max_tokens" not in kwargs: + kwargs["max_tokens"] = max_tokens + if json_mode: + kwargs["response_format"] = {"type": "json_object"} self.kwargs = kwargs def __call__( @@ -153,7 +157,7 @@ def chat( ) response = self.client.chat.completions.create( - model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs # type: ignore + model=self.model_name, messages=fixed_chat, **self.kwargs # type: ignore ) return cast(str, response.choices[0].message.content) @@ -181,7 +185,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str ) response = self.client.chat.completions.create( - model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs # type: ignore + model=self.model_name, messages=message, **self.kwargs # type: ignore ) return cast(str, response.choices[0].message.content) From 9773aced578fa6b7d495ced07cd9a6f3b24dd753 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 14:19:37 -0700 Subject: [PATCH 4/9] updated reflection prompt --- vision_agent/agent/vision_agent_prompts.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py index a54ae6a6..9a91b46e 100644 --- a/vision_agent/agent/vision_agent_prompts.py +++ b/vision_agent/agent/vision_agent_prompts.py @@ -1,4 +1,13 @@ -VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. Do not make vague steps like re-evaluate the threshold, instead make concrete steps like use a threshold of 0.5 or whatever threshold you think would fix this issue. If the task cannot be completed with the existing tools, respond with Finish. Use complete sentences. +VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. + +Please note that: +1. You must ONLY output parsible JSON format. If the agents output was correct set "Finish" to true, else set "Finish" to false. An example output looks like: +{{"Finish": true, "Reflection": "The agent's answer was correct."}} +2. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. An example output looks like: + {{"Finish": false, "Reflection": "The agent's answer was incorrect. The agent should use the following tools with the following parameters: + Step 1: Use 'grounding_dion_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives. + Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}} +3. If the task cannot be completed with the existing tools, set "Finish" to true. User's question: {question} From e99f8380c9adc20014f7d5182043a56d8c82a89b Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 14:43:16 -0700 Subject: [PATCH 5/9] refactor to make function simpler --- vision_agent/agent/vision_agent.py | 107 +++++++++++++++++------------ 1 file changed, 62 insertions(+), 45 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 8e02db0a..87b600e1 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -268,7 +268,7 @@ def self_reflect( return reflect_model(prompt) -def parse_reflect(reflect: str) -> Dict[str, Any]: +def parse_reflect(reflect: str) -> Any: try: return parse_json(reflect) except Exception: @@ -280,6 +280,64 @@ def parse_reflect(reflect: str) -> Dict[str, Any]: return {"Finish": finish, "Reflection": reflect} +def _handle_extract_frames(image_to_data: Dict[str, Dict], tool_result: Dict) -> Dict[str, Dict]: + image_to_data = image_to_data.copy() + # handle extract_frames_ case, useful if it extracts frames but doesn't do + # any following processing + for video_file_output in tool_result["call_results"]: + for frame, _ in video_file_output: + image = frame + if image not in image_to_data: + image_to_data[image] = { + "bboxes": [], + "masks": [], + "labels": [], + "scores": [], + } + return image_to_data + + +def _handle_viz_tools(image_to_data: Dict[str, Dict], tool_result: Dict) -> Dict[str, Dict]: + image_to_data = image_to_data.copy() + + # handle grounding_sam_ and grounding_dino_ + parameters = tool_result["parameters"] + # parameters can either be a dictionary or list, parameters can also be malformed + # becaus the LLM builds them + if isinstance(parameters, dict): + if "image" not in parameters: + return image_to_data + parameters = [parameters] + elif isinstance(tool_result["parameters"], list): + if len(tool_result["parameters"]) < 1 or ( + "image" not in tool_result["parameters"][0] + ): + return image_to_data + + for param, call_result in zip(parameters, tool_result["call_results"]): + # calls can fail, so we need to check if the call was successful + if not isinstance(call_result, dict) or "bboxes" not in call_result: + return image_to_data + + # if the call was successful, then we can add the image data + image = param["image"] + if image not in image_to_data: + image_to_data[image] = { + "bboxes": [], + "masks": [], + "labels": [], + "scores": [], + } + + image_to_data[image]["bboxes"].extend(call_result["bboxes"]) + image_to_data[image]["labels"].extend(call_result["labels"]) + image_to_data[image]["scores"].extend(call_result["scores"]) + if "masks" in call_result: + image_to_data[image]["masks"].extend(call_result["masks"]) + + return image_to_data + + def visualize_result(all_tool_results: List[Dict]) -> List[str]: image_to_data: Dict[str, Dict] = {} for tool_result in all_tool_results: @@ -292,50 +350,9 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]: continue if tool_result["tool_name"] == "extract_frames_": - for video_file_output in tool_result["call_results"]: - for frame, _ in video_file_output: - image = frame - if image not in image_to_data: - image_to_data[image] = { - "bboxes": [], - "masks": [], - "labels": [], - "scores": [], - } - else: # handle grounding_sam_ and grounding_dino_ - parameters = tool_result["parameters"] - # parameters can either be a dictionary or list, parameters can also be malformed - # becaus the LLM builds them - if isinstance(parameters, dict): - if "image" not in parameters: - continue - parameters = [parameters] - elif isinstance(tool_result["parameters"], list): - if len(tool_result["parameters"]) < 1 or ( - "image" not in tool_result["parameters"][0] - ): - continue - - for param, call_result in zip(parameters, tool_result["call_results"]): - # calls can fail, so we need to check if the call was successful - if not isinstance(call_result, dict) or "bboxes" not in call_result: - continue - - # if the call was successful, then we can add the image data - image = param["image"] - if image not in image_to_data: - image_to_data[image] = { - "bboxes": [], - "masks": [], - "labels": [], - "scores": [], - } - - image_to_data[image]["bboxes"].extend(call_result["bboxes"]) - image_to_data[image]["labels"].extend(call_result["labels"]) - image_to_data[image]["scores"].extend(call_result["scores"]) - if "masks" in call_result: - image_to_data[image]["masks"].extend(call_result["masks"]) + image_to_data = _handle_extract_frames(image_to_data, tool_result) + else: + image_to_data = _handle_viz_tools(image_to_data, tool_result) visualized_images = [] for image in image_to_data: From 085e12ff1643aa9fb9f4f161a1cc6f683e1801dc Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 15:08:52 -0700 Subject: [PATCH 6/9] updated reflection prompt, add tool usage doc --- vision_agent/agent/vision_agent.py | 23 ++++++++++++++++++---- vision_agent/agent/vision_agent_prompts.py | 14 ++++++++----- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 87b600e1..10affc65 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -37,10 +37,10 @@ def parse_json(s: str) -> Any: s = ( - s.replace(": true", ": True") - .replace(": false", ": False") - .replace(":true", ": True") - .replace(":false", ": False") + s.replace(": True", ": true") + .replace(": False", ": false") + .replace(":True", ": true") + .replace(":False", ": false") .replace("```", "") .strip() ) @@ -62,6 +62,19 @@ def format_tools(tools: Dict[int, Any]) -> str: return tool_str +def format_tool_usage(tools: Dict[int, Any], tool_result: List[Dict]) -> str: + usage = [] + name_to_usage = {v["name"]: v["usage"] for v in tools.values()} + for tool_res in tool_result: + if "tool_name" in tool_res: + usage.append((tool_res["tool_name"], name_to_usage[tool_res["tool_name"]])) + + usage_str = "" + for tool_name, tool_usage in usage: + usage_str += f"{tool_name} - {tool_usage}\n" + return usage_str + + def topological_sort(tasks: List[Dict]) -> List[Dict]: in_degree = {task["id"]: 0 for task in tasks} for task in tasks: @@ -256,6 +269,7 @@ def self_reflect( prompt = VISION_AGENT_REFLECTION.format( question=question, tools=format_tools({k: v["description"] for k, v in tools.items()}), + tool_usage=format_tool_usage(tools, tool_result), tool_results=str(tool_result), final_answer=final_answer, ) @@ -269,6 +283,7 @@ def self_reflect( def parse_reflect(reflect: str) -> Any: + reflect = reflect.strip() try: return parse_json(reflect) except Exception: diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py index 9a91b46e..cd7878c0 100644 --- a/vision_agent/agent/vision_agent_prompts.py +++ b/vision_agent/agent/vision_agent_prompts.py @@ -1,13 +1,14 @@ -VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. +VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. Please note that: 1. You must ONLY output parsible JSON format. If the agents output was correct set "Finish" to true, else set "Finish" to false. An example output looks like: {{"Finish": true, "Reflection": "The agent's answer was correct."}} -2. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. An example output looks like: - {{"Finish": false, "Reflection": "The agent's answer was incorrect. The agent should use the following tools with the following parameters: - Step 1: Use 'grounding_dion_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives. +2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or, using your own judgement, utilized incorrectly. +3. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. An example output looks like: + {{"Finish": false, "Reflection": "I can see from teh visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters: + Step 1: Use 'grounding_dino_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives. Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}} -3. If the task cannot be completed with the existing tools, set "Finish" to true. +4. If the task cannot be completed with the existing tools or by adjusting the parameters, set "Finish" to true. User's question: {question} @@ -17,6 +18,9 @@ Tasks and tools used: {tool_results} +Tool's used API documentation: +{tool_usage} + Final answer: {final_answer} From d41c6b2ec6682350f372f84e29d96873c7839689 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 15:11:43 -0700 Subject: [PATCH 7/9] fixed format issue --- vision_agent/agent/vision_agent.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 10affc65..c2ad1883 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -295,7 +295,9 @@ def parse_reflect(reflect: str) -> Any: return {"Finish": finish, "Reflection": reflect} -def _handle_extract_frames(image_to_data: Dict[str, Dict], tool_result: Dict) -> Dict[str, Dict]: +def _handle_extract_frames( + image_to_data: Dict[str, Dict], tool_result: Dict +) -> Dict[str, Dict]: image_to_data = image_to_data.copy() # handle extract_frames_ case, useful if it extracts frames but doesn't do # any following processing @@ -312,7 +314,9 @@ def _handle_extract_frames(image_to_data: Dict[str, Dict], tool_result: Dict) -> return image_to_data -def _handle_viz_tools(image_to_data: Dict[str, Dict], tool_result: Dict) -> Dict[str, Dict]: +def _handle_viz_tools( + image_to_data: Dict[str, Dict], tool_result: Dict +) -> Dict[str, Dict]: image_to_data = image_to_data.copy() # handle grounding_sam_ and grounding_dino_ @@ -426,7 +430,9 @@ def __init__( OpenAILLM(temperature=0.1) if answer_model is None else answer_model ) self.reflect_model = ( - OpenAILMM(json_mode=True, temperature=0.1) if reflect_model is None else reflect_model + OpenAILMM(json_mode=True, temperature=0.1) + if reflect_model is None + else reflect_model ) self.max_retries = max_retries self.tools = TOOLS From 7f2140f1058d50fffeddaebdd1db253d52d36ec1 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 15:16:59 -0700 Subject: [PATCH 8/9] fixed type issue --- vision_agent/agent/vision_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index c2ad1883..8553eb25 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -374,9 +374,9 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]: image_to_data = _handle_viz_tools(image_to_data, tool_result) visualized_images = [] - for image in image_to_data: - image_path = Path(image) - image_data = image_to_data[image] + for image_str in image_to_data: + image_path = Path(image_str) + image_data = image_to_data[image_str] image = overlay_masks(image_path, image_data) image = overlay_bboxes(image, image_data) with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: From 9f8b7e24ccee6a3b5501454f48d08cc8776b4d53 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 15 Apr 2024 15:35:14 -0700 Subject: [PATCH 9/9] fixed test case --- tests/test_llm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_llm.py b/tests/test_llm.py index 0a671ca5..bbcc203e 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -18,7 +18,7 @@ def test_generate_with_mock(openai_llm_mock): # noqa: F811 response = llm.generate("test prompt") assert response == "mocked response" openai_llm_mock.chat.completions.create.assert_called_once_with( - model="gpt-4-turbo-preview", + model="gpt-4-turbo", messages=[{"role": "user", "content": "test prompt"}], ) @@ -31,7 +31,7 @@ def test_chat_with_mock(openai_llm_mock): # noqa: F811 response = llm.chat([{"role": "user", "content": "test prompt"}]) assert response == "mocked response" openai_llm_mock.chat.completions.create.assert_called_once_with( - model="gpt-4-turbo-preview", + model="gpt-4-turbo", messages=[{"role": "user", "content": "test prompt"}], ) @@ -44,14 +44,14 @@ def test_call_with_mock(openai_llm_mock): # noqa: F811 response = llm("test prompt") assert response == "mocked response" openai_llm_mock.chat.completions.create.assert_called_once_with( - model="gpt-4-turbo-preview", + model="gpt-4-turbo", messages=[{"role": "user", "content": "test prompt"}], ) response = llm([{"role": "user", "content": "test prompt"}]) assert response == "mocked response" openai_llm_mock.chat.completions.create.assert_called_with( - model="gpt-4-turbo-preview", + model="gpt-4-turbo", messages=[{"role": "user", "content": "test prompt"}], )