diff --git a/tests/test_llm.py b/tests/test_llm.py
index 0a671ca5..bbcc203e 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -18,7 +18,7 @@ def test_generate_with_mock(openai_llm_mock):  # noqa: F811
     response = llm.generate("test prompt")
     assert response == "mocked response"
     openai_llm_mock.chat.completions.create.assert_called_once_with(
-        model="gpt-4-turbo-preview",
+        model="gpt-4-turbo",
         messages=[{"role": "user", "content": "test prompt"}],
     )
 
@@ -31,7 +31,7 @@ def test_chat_with_mock(openai_llm_mock):  # noqa: F811
     response = llm.chat([{"role": "user", "content": "test prompt"}])
     assert response == "mocked response"
     openai_llm_mock.chat.completions.create.assert_called_once_with(
-        model="gpt-4-turbo-preview",
+        model="gpt-4-turbo",
         messages=[{"role": "user", "content": "test prompt"}],
     )
 
@@ -44,14 +44,14 @@ def test_call_with_mock(openai_llm_mock):  # noqa: F811
     response = llm("test prompt")
     assert response == "mocked response"
     openai_llm_mock.chat.completions.create.assert_called_once_with(
-        model="gpt-4-turbo-preview",
+        model="gpt-4-turbo",
         messages=[{"role": "user", "content": "test prompt"}],
     )
 
     response = llm([{"role": "user", "content": "test prompt"}])
     assert response == "mocked response"
     openai_llm_mock.chat.completions.create.assert_called_with(
-        model="gpt-4-turbo-preview",
+        model="gpt-4-turbo",
         messages=[{"role": "user", "content": "test prompt"}],
     )
 
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 10c98735..4c193aae 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -37,10 +37,10 @@
 
 def parse_json(s: str) -> Any:
     s = (
-        s.replace(": true", ": True")
-        .replace(": false", ": False")
-        .replace(":true", ": True")
-        .replace(":false", ": False")
+        s.replace(": True", ": true")
+        .replace(": False", ": false")
+        .replace(":True", ": true")
+        .replace(":False", ": false")
         .replace("```", "")
         .strip()
     )
@@ -62,6 +62,19 @@ def format_tools(tools: Dict[int, Any]) -> str:
     return tool_str
 
 
+def format_tool_usage(tools: Dict[int, Any], tool_result: List[Dict]) -> str:
+    usage = []
+    name_to_usage = {v["name"]: v["usage"] for v in tools.values()}
+    for tool_res in tool_result:
+        if "tool_name" in tool_res:
+            usage.append((tool_res["tool_name"], name_to_usage[tool_res["tool_name"]]))
+
+    usage_str = ""
+    for tool_name, tool_usage in usage:
+        usage_str += f"{tool_name} - {tool_usage}\n"
+    return usage_str
+
+
 def topological_sort(tasks: List[Dict]) -> List[Dict]:
     in_degree = {task["id"]: 0 for task in tasks}
     for task in tasks:
@@ -255,7 +268,8 @@ def self_reflect(
 ) -> str:
     prompt = VISION_AGENT_REFLECTION.format(
         question=question,
-        tools=format_tools(tools),
+        tools=format_tools({k: v["description"] for k, v in tools.items()}),
+        tool_usage=format_tool_usage(tools, tool_result),
         tool_results=str(tool_result),
         final_answer=final_answer,
     )
@@ -268,41 +282,28 @@
     return reflect_model(prompt)
 
 
-def parse_reflect(reflect: str) -> bool:
-    # GPT-4V has a hard time following directions, so make the criteria less strict
-    return (
+def parse_reflect(reflect: str) -> Any:
+    reflect = reflect.strip()
+    try:
+        return parse_json(reflect)
+    except Exception:
+        _LOGGER.error(f"Failed to parse JSON reflection: {reflect}")
+    # LMMs have a hard time following directions, so make the criteria less strict
+    finish = (
         "finish" in reflect.lower() and len(reflect) < 100
     ) or "finish" in reflect.lower()[-10:]
-
-
-def visualize_result(all_tool_results: List[Dict]) -> List[str]:
-    image_to_data: Dict[str, Dict] = {}
-    for tool_result in all_tool_results:
-        if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
-            continue
-
-        parameters = tool_result["parameters"]
-        # parameters can either be a dictionary or list, parameters can also be malformed
-        # becaus the LLM builds them
-        if isinstance(parameters, dict):
-            if "image" not in parameters:
-                continue
-            parameters = [parameters]
-        elif isinstance(tool_result["parameters"], list):
-            if len(tool_result["parameters"]) < 1 or (
-                "image" not in tool_result["parameters"][0]
-            ):
-                continue
-
-        for param, call_result in zip(parameters, tool_result["call_results"]):
-            # calls can fail, so we need to check if the call was successful
-            if not isinstance(call_result, dict):
-                continue
-            if "bboxes" not in call_result:
-                continue
-
-            # if the call was successful, then we can add the image data
-            image = param["image"]
+    return {"Finish": finish, "Reflection": reflect}
+
+
+def _handle_extract_frames(
+    image_to_data: Dict[str, Dict], tool_result: Dict
+) -> Dict[str, Dict]:
+    image_to_data = image_to_data.copy()
+    # handle extract_frames_ case, useful if it extracts frames but doesn't do
+    # any following processing
+    for video_file_output in tool_result["call_results"]:
+        for frame, _ in video_file_output:
+            image = frame
             if image not in image_to_data:
                 image_to_data[image] = {
                     "bboxes": [],
@@ -310,17 +311,72 @@
                     "labels": [],
                     "scores": [],
                 }
+    return image_to_data
+
+
+def _handle_viz_tools(
+    image_to_data: Dict[str, Dict], tool_result: Dict
+) -> Dict[str, Dict]:
+    image_to_data = image_to_data.copy()
+
+    # handle grounding_sam_ and grounding_dino_
+    parameters = tool_result["parameters"]
+    # parameters can either be a dictionary or list, parameters can also be malformed
+    # because the LLM builds them
+    if isinstance(parameters, dict):
+        if "image" not in parameters:
+            return image_to_data
+        parameters = [parameters]
+    elif isinstance(tool_result["parameters"], list):
+        if len(tool_result["parameters"]) < 1 or (
+            "image" not in tool_result["parameters"][0]
+        ):
+            return image_to_data
+
+    for param, call_result in zip(parameters, tool_result["call_results"]):
+        # calls can fail, so we need to check if the call was successful
+        if not isinstance(call_result, dict) or "bboxes" not in call_result:
+            return image_to_data
+
+        # if the call was successful, then we can add the image data
+        image = param["image"]
+        if image not in image_to_data:
+            image_to_data[image] = {
+                "bboxes": [],
+                "masks": [],
+                "labels": [],
+                "scores": [],
+            }
+
+        image_to_data[image]["bboxes"].extend(call_result["bboxes"])
+        image_to_data[image]["labels"].extend(call_result["labels"])
+        image_to_data[image]["scores"].extend(call_result["scores"])
+        if "masks" in call_result:
+            image_to_data[image]["masks"].extend(call_result["masks"])
+
+    return image_to_data
+
 
-            image_to_data[image]["bboxes"].extend(call_result["bboxes"])
-            image_to_data[image]["labels"].extend(call_result["labels"])
-            image_to_data[image]["scores"].extend(call_result["scores"])
-            if "masks" in call_result:
-                image_to_data[image]["masks"].extend(call_result["masks"])
+def visualize_result(all_tool_results: List[Dict]) -> List[str]:
+    image_to_data: Dict[str, Dict] = {}
+    for tool_result in all_tool_results:
+        # only handle bbox/mask tools or frame extraction
+        if tool_result["tool_name"] not in [
+            "grounding_sam_",
+            "grounding_dino_",
+            "extract_frames_",
+        ]:
+            continue
+
+        if tool_result["tool_name"] == "extract_frames_":
+            image_to_data = _handle_extract_frames(image_to_data, tool_result)
+        else:
+            image_to_data = _handle_viz_tools(image_to_data, tool_result)
 
     visualized_images = []
-    for image in image_to_data:
-        image_path = Path(image)
-        image_data = image_to_data[image]
+    for image_str in image_to_data:
+        image_path = Path(image_str)
+        image_data = image_to_data[image_str]
         image = overlay_masks(image_path, image_data)
         image = overlay_bboxes(image, image_data)
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
@@ -374,7 +430,9 @@ def __init__(
             OpenAILLM(temperature=0.1) if answer_model is None else answer_model
         )
         self.reflect_model = (
-            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
+            OpenAILMM(json_mode=True, temperature=0.1)
+            if reflect_model is None
+            else reflect_model
         )
         self.max_retries = max_retries
         self.tools = TOOLS
@@ -470,11 +528,12 @@ def chat_with_workflow(
                 visualized_output[0] if len(visualized_output) > 0 else image,
             )
             self.log_progress(f"Reflection: {reflection}")
-            if parse_reflect(reflection):
+            parsed_reflection = parse_reflect(reflection)
+            if parsed_reflection["Finish"]:
                 break
            else:
-                reflections += "\n" + reflection
-        # '' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+                reflections += "\n" + parsed_reflection["Reflection"]
+        # '' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
             f"The Vision Agent has concluded this chat. {final_answer}"
         )
diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py
index a54ae6a6..cd7878c0 100644
--- a/vision_agent/agent/vision_agent_prompts.py
+++ b/vision_agent/agent/vision_agent_prompts.py
@@ -1,4 +1,14 @@
-VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. Do not make vague steps like re-evaluate the threshold, instead make concrete steps like use a threshold of 0.5 or whatever threshold you think would fix this issue. If the task cannot be completed with the existing tools, respond with Finish. Use complete sentences.
+VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-reflection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used.
+
+Please note that:
+1. You must ONLY output parsable JSON. If the agent's output was correct, set "Finish" to true; otherwise set "Finish" to false. An example output looks like:
+{{"Finish": true, "Reflection": "The agent's answer was correct."}}
+2. You must utilize the image with the visualized bounding boxes or masks and determine whether the tools were used correctly or, using your own judgement, incorrectly.
+3. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. An example output looks like:
+   {{"Finish": false, "Reflection": "I can see from the visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters:
+   Step 1: Use 'grounding_dino_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives.
+   Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}}
+4. If the task cannot be completed with the existing tools or by adjusting the parameters, set "Finish" to true.
 
 User's question: {question}
 
@@ -8,6 +18,9 @@
 Tasks and tools used:
 {tool_results}
 
+API documentation for the tools used:
+{tool_usage}
+
 Final answer:
 {final_answer}
 
diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 9022ef73..9352f58b 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -33,7 +33,7 @@ class OpenAILLM(LLM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-turbo-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         json_mode: bool = False,
         **kwargs: Any
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 0d63b158..738ac004 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -99,9 +99,10 @@ class OpenAILMM(LMM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-vision-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         max_tokens: int = 1024,
+        json_mode: bool = False,
         **kwargs: Any,
     ):
         if not api_key:
@@ -111,7 +112,10 @@ def __init__(
             self.client = OpenAI(api_key=api_key)
 
         self.model_name = model_name
-        self.max_tokens = max_tokens
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
+        if json_mode:
+            kwargs["response_format"] = {"type": "json_object"}
         self.kwargs = kwargs
 
     def __call__(
@@ -153,7 +157,7 @@ def chat(
         )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
 
@@ -181,7 +185,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
         )
 
        response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=message, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
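Reviewer note: for anyone who wants to poke at the new reflection flow without checking out the branch, below is a minimal, self-contained sketch of the JSON-or-fallback parsing contract this patch introduces. The parse_reflection helper is illustrative only (it is not part of the library); the real behavior lives in parse_reflect above, fed by OpenAILMM(json_mode=True), which sets response_format={"type": "json_object"} on the OpenAI call.

# Illustrative sketch (not part of the patch): prefer strict JSON from a
# json_mode reflection model, fall back to the old lenient "finish" keyword
# check for plain-text replies.
import json
from typing import Any, Dict


def parse_reflection(reflect: str) -> Dict[str, Any]:
    reflect = reflect.strip()
    try:
        # expected path when the reflect model runs with json_mode=True
        return json.loads(reflect)
    except json.JSONDecodeError:
        # lenient fallback, equivalent to the pre-patch behavior
        finish = ("finish" in reflect.lower() and len(reflect) < 100) or (
            "finish" in reflect.lower()[-10:]
        )
        return {"Finish": finish, "Reflection": reflect}


if __name__ == "__main__":
    # JSON reply from a json_mode reflection model
    print(parse_reflection('{"Finish": false, "Reflection": "Raise box_threshold to 0.7."}'))
    # Plain-text reply still degrades gracefully
    print(parse_reflection("Finish"))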