diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index d7886178..9180b38b 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -5,7 +5,7 @@ import sys import tempfile from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast from langsmith import traceable from PIL import Image @@ -20,9 +20,11 @@ CODE, FIX_BUG, FULL_TASK, + PICK_PLAN, PLAN, - REFLECT, + PREVIOUS_FAILED, SIMPLE_TEST, + TEST_PLANS, USER_REQ, ) from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM @@ -80,6 +82,15 @@ def format_memory(memory: List[Dict[str, str]]) -> str: return output_str +def format_plans(plans: Dict[str, Any]) -> str: + plan_str = "" + for k, v in plans.items(): + plan_str += f"{k}:\n" + plan_str += "-" + "\n-".join([e["instructions"] for e in v]) + + return plan_str + + def extract_code(code: str) -> str: if "\n```python" in code: start = "\n```python" @@ -140,12 +151,12 @@ def extract_image( @traceable -def write_plan( +def write_plans( chat: List[Message], tool_desc: str, working_memory: str, model: LMM, -) -> List[Dict[str, str]]: +) -> Dict[str, Any]: chat = copy.deepcopy(chat) if chat[-1]["role"] != "user": raise ValueError("Last chat message must be from the user.") @@ -154,14 +165,84 @@ def write_plan( context = USER_REQ.format(user_request=user_request) prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory) chat[-1]["content"] = prompt - return extract_json(model.chat(chat))["plan"] # type: ignore + return extract_json(model.chat(chat)) + + +@traceable +def pick_plan( + chat: List[Message], + plans: Dict[str, Any], + tool_info: str, + model: LMM, + code_interpreter: CodeInterpreter, + verbosity: int = 0, +) -> Tuple[str, str]: + chat = copy.deepcopy(chat) + if chat[-1]["role"] != "user": + raise ValueError("Last chat message must be from the user.") + + plan_str = format_plans(plans) + prompt = TEST_PLANS.format( + docstring=tool_info, plans=plan_str, previous_attempts="" + ) + + code = extract_code(model(prompt)) + tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code)) + tool_output_str = "" + if len(tool_output.logs.stdout) > 0: + tool_output_str = tool_output.logs.stdout[0] + + if verbosity >= 1: + _print_code("Initial code and tests:", code) + _LOGGER.info(f"Initial code execution result:\n{tool_output.text()}") + + # retry if the tool output is empty or code fails + count = 1 + while (not tool_output.success or tool_output_str == "") and count < 3: + prompt = TEST_PLANS.format( + docstring=tool_info, + plans=plan_str, + previous_attempts=PREVIOUS_FAILED.format( + code=code, error=tool_output.text() + ), + ) + code = extract_code(model(prompt)) + tool_output = code_interpreter.exec_isolation( + DefaultImports.prepend_imports(code) + ) + tool_output_str = "" + if len(tool_output.logs.stdout) > 0: + tool_output_str = tool_output.logs.stdout[0] + + if verbosity == 1: + _print_code("Code and test after attempted fix:", code) + _LOGGER.info(f"Code execution result after attempte {count}") + + count += 1 + + user_req = chat[-1]["content"] + context = USER_REQ.format(user_request=user_req) + # because the tool picker model gets the image as well, we have to be careful with + # how much text we send it, so we truncate the tool output to 20,000 characters + prompt = PICK_PLAN.format( + context=context, + plans=format_plans(plans), + tool_output=tool_output_str[:20_000], + ) + chat[-1]["content"] = prompt + best_plan = extract_json(model(chat)) + if verbosity >= 1: + _LOGGER.info(f"Best plan:\n{best_plan}") + return best_plan["best_plan"], tool_output_str @traceable def write_code( coder: LMM, chat: List[Message], + plan: str, tool_info: str, + tool_output: str, feedback: str, ) -> str: chat = copy.deepcopy(chat) @@ -171,7 +252,8 @@ def write_code( user_request = chat[-1]["content"] prompt = CODE.format( docstring=tool_info, - question=user_request, + question=FULL_TASK.format(user_request=user_request, subtasks=plan), + tool_output=tool_output, feedback=feedback, ) chat[-1]["content"] = prompt @@ -203,27 +285,11 @@ def write_test( return extract_code(tester(chat)) -@traceable -def reflect( - chat: List[Message], - plan: str, - code: str, - model: LMM, -) -> Dict[str, Union[str, bool]]: - chat = copy.deepcopy(chat) - if chat[-1]["role"] != "user": - raise ValueError("Last chat message must be from the user.") - - user_request = chat[-1]["content"] - context = USER_REQ.format(user_request=user_request) - prompt = REFLECT.format(context=context, plan=plan, code=code) - chat[-1]["content"] = prompt - return extract_json(model(chat)) - - def write_and_test_code( chat: List[Message], + plan: str, tool_info: str, + tool_output: str, tool_utils: str, working_memory: List[Dict[str, str]], coder: LMM, @@ -241,7 +307,14 @@ def write_and_test_code( "status": "started", } ) - code = write_code(coder, chat, tool_info, format_memory(working_memory)) + code = write_code( + coder, + chat, + plan, + tool_info, + tool_output, + format_memory(working_memory), + ) test = write_test( tester, chat, tool_utils, code, format_memory(working_memory), media ) @@ -412,11 +485,11 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None: def retrieve_tools( - plan: List[Dict[str, str]], + plans: Dict[str, List[Dict[str, str]]], tool_recommender: Sim, log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, -) -> str: +) -> Dict[str, str]: log_progress( { "type": "tools", @@ -425,27 +498,29 @@ def retrieve_tools( ) tool_info = [] tool_desc = [] - tool_list: List[Dict[str, str]] = [] - for task in plan: - tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3) - tool_info.extend([e["doc"] for e in tools]) - tool_desc.extend([e["desc"] for e in tools]) - tool_list.extend( - {"description": e["desc"], "documentation": e["doc"]} for e in tools - ) - log_progress( - { - "type": "tools", - "status": "completed", - "payload": list({v["description"]: v for v in tool_list}.values()), - } - ) + tool_lists: Dict[str, List[Dict[str, str]]] = {} + for k, plan in plans.items(): + tool_lists[k] = [] + for task in plan: + tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3) + tool_info.extend([e["doc"] for e in tools]) + tool_desc.extend([e["desc"] for e in tools]) + tool_lists[k].extend( + {"description": e["desc"], "documentation": e["doc"]} for e in tools + ) if verbosity == 2: tool_desc_str = "\n".join(set(tool_desc)) _LOGGER.info(f"Tools Description:\n{tool_desc_str}") - tool_info_set = set(tool_info) - return "\n\n".join(tool_info_set) + + tool_lists_unique = {} + for k in tool_lists: + tool_lists_unique[k] = "\n\n".join( + set(e["documentation"] for e in tool_lists[k]) + ) + all_tools = "\n\n".join(set(tool_info)) + tool_lists_unique["all"] = all_tools + return tool_lists_unique class VisionAgent(Agent): @@ -543,7 +618,6 @@ def __call__( def chat_with_workflow( self, chat: List[Message], - self_reflection: bool = False, display_visualization: bool = False, ) -> Dict[str, Any]: """Chat with Vision Agent and return intermediate information regarding the task. @@ -554,7 +628,6 @@ def chat_with_workflow( [{"role": "user", "content": "describe your task here..."}] or if it contains media files, it should be in the format of: [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}] - self_reflection (bool): Whether to reflect on the task and debug the code. display_visualization (bool): If True, it opens a new window locally to show the image(s) created by visualization code (if there is any). @@ -581,7 +654,10 @@ def chat_with_workflow( int_chat = cast( List[Message], - [{"role": c["role"], "content": c["content"]} for c in chat], + [ + {"role": c["role"], "content": c["content"], "media": c["media"]} + for c in chat + ], ) code = "" @@ -599,13 +675,45 @@ def chat_with_workflow( "status": "started", } ) - plan_i = write_plan( + plans = write_plans( int_chat, T.TOOL_DESCRIPTIONS, format_memory(working_memory), self.planner, ) - plan_i_str = "\n-".join([e["instructions"] for e in plan_i]) + + if self.verbosity >= 1: + for p in plans: + _LOGGER.info( + f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + ) + + tool_infos = retrieve_tools( + plans, + self.tool_recommender, + self.log_progress, + self.verbosity, + ) + best_plan, tool_output_str = pick_plan( + int_chat, + plans, + tool_infos["all"], + self.coder, + code_interpreter, + verbosity=self.verbosity, + ) + + if best_plan in plans and best_plan in tool_infos: + plan_i = plans[best_plan] + tool_info = tool_infos[best_plan] + else: + if self.verbosity >= 1: + _LOGGER.warning( + f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info." + ) + k = list(plans.keys())[0] + plan_i = plans[k] + tool_info = tool_infos[k] self.log_progress( { @@ -616,18 +724,16 @@ def chat_with_workflow( ) if self.verbosity >= 1: _LOGGER.info( - f"\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) - tool_info = retrieve_tools( - plan_i, - self.tool_recommender, - self.log_progress, - self.verbosity, - ) results = write_and_test_code( - chat=int_chat, + chat=[ + {"role": c["role"], "content": c["content"]} for c in int_chat + ], + plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]), tool_info=tool_info, + tool_output=tool_output_str, tool_utils=T.UTILITIES_DOCSTRING, working_memory=working_memory, coder=self.coder, @@ -644,35 +750,6 @@ def chat_with_workflow( working_memory.extend(results["working_memory"]) # type: ignore plan.append({"code": code, "test": test, "plan": plan_i}) - if not self_reflection: - break - - self.log_progress( - { - "type": "self_reflection", - "status": "started", - } - ) - reflection = reflect( - int_chat, - FULL_TASK.format( - user_request=chat[0]["content"], subtasks=plan_i_str - ), - code, - self.planner, - ) - if self.verbosity > 0: - _LOGGER.info(f"Reflection: {reflection}") - feedback = cast(str, reflection["feedback"]) - success = cast(bool, reflection["success"]) - self.log_progress( - { - "type": "self_reflection", - "status": "completed" if success else "failed", - "payload": reflection, - } - ) - working_memory.append({"code": f"{code}\n{test}", "feedback": feedback}) retries += 1 execution_result = cast(Execution, results["test_result"]) diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py index 09105ca1..8f5e689b 100644 --- a/vision_agent/agent/vision_agent_prompts.py +++ b/vision_agent/agent/vision_agent_prompts.py @@ -19,7 +19,7 @@ PLAN = """ -**Context** +**Context**: {context} **Tools Available**: @@ -29,23 +29,110 @@ {feedback} **Instructions**: -1. Based on the context and tools you have available, write a plan of subtasks to achieve the user request. -2. Go over the users request step by step and ensure each step is represented as a clear subtask in your plan. +1. Based on the context and tools you have available, create a plan of subtasks to achieve the user request. +2. Output three different plans each utilize a different strategy or tool. Output a list of jsons in the following format ```json {{ - "plan": + "plan1": [ {{ "instructions": str # what you should do in this task associated with a tool }} - ] + ], + "plan2": ..., + "plan3": ... }} ``` """ + +TEST_PLANS = """ +**Role**: You are a software programmer responsible for testing different tools. + +**Task**: Your responsibility is to take a set of several plans and test the different tools for each plan. + +**Documentation**: +This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task. They are available through importing `from vision_agent.tools import *`. + +{docstring} + +**Plans**: +{plans} + +{previous_attempts} + +**Instructions**: +1. Write a program to load the media and call each tool and save it's output. +2. Create a dictionary where the keys are the tool name and the values are the tool outputs. Remove any array types from the printed dictionary. +3. Print this final dictionary. + +**Example**: +plan1: +- Load the image from the provided file path 'image.jpg'. +- Use the 'owl_v2' tool with the prompt 'person' to detect and count the number of people in the image. +plan2: +- Load the image from the provided file path 'image.jpg'. +- Use the 'grounding_sam' tool with the prompt 'person' to detect and count the number of people in the image. +- Count the number of detected objects labeled as 'person'. +plan3: +- Load the image from the provided file path 'image.jpg'. +- Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people. + +```python +from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting +image = load_image("image.jpg") +owl_v2_out = owl_v2("person", image) + +gsam_out = grounding_sam("person", image) +gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out] + +loca_out = loca_zero_shot_counting(image) +loca_out = loca_out["count"] + +final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}} +print(final_out) +``` +""" + + +PREVIOUS_FAILED = """ +**Previous Failed Attempts**: +You previously ran this code: +```python +{code} +``` + +But got the following error or no stdout: +{error} +""" + + +PICK_PLAN = """ +**Role**: You are a software programmer. + +**Task**: Your responsibility is to pick the best plan from the three plans provided. + +**Context**: +{context} + +**Plans**: +{plans} + +**Tool Output**: +{tool_output} + +**Instructions**: +1. Given the plans, image, and tool outputs, decide which plan is the best to achieve the user request. +2. Output a JSON object with the following format: +{{ + "thoughts": str # your thought process for choosing the best plan + "best_plan": str # the best plan you have chosen +}} +""" + CODE = """ **Role**: You are a software programmer. @@ -64,6 +151,9 @@ **User Instructions**: {question} +**Tool Output**: +{tool_output} + **Previous Feedback**: {feedback} @@ -72,7 +162,6 @@ 2. **Algorithm/Method Selection**: Decide on the most efficient way. 3. **Pseudocode Creation**: Write down the steps you will follow in pseudocode. 4. **Code Generation**: Translate your pseudocode into executable Python code. Ensure you use correct arguments, remember coordinates are always returned normalized from `vision_agent.tools`. All images from `vision_agent.tools` are in RGB format, red is (255, 0, 0) and blue is (0, 0, 255). -5. **Logging**: Log the output of the custom functions that were provided to you from `from vision_agent.tools import *`. Use a debug flag in the function parameters to toggle logging on and off. """ TEST = """ @@ -147,7 +236,6 @@ def find_text(image_path: str, text: str) -> str: ``` """ - SIMPLE_TEST = """ **Role**: As a tester, your task is to create a simple test case for the provided code. This test case should verify the fundamental functionality under normal conditions. diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index e6c86844..d97a1d97 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -1,4 +1,5 @@ import base64 +import io import json import logging import os @@ -8,6 +9,7 @@ import requests from openai import AzureOpenAI, OpenAI +from PIL import Image import vision_agent.tools as T from vision_agent.tools.prompts import CHOOSE_PARAMS, SYSTEM_PROMPT @@ -15,12 +17,40 @@ _LOGGER = logging.getLogger(__name__) -def encode_image(image: Union[str, Path]) -> str: - with open(image, "rb") as f: - encoded_image = base64.b64encode(f.read()).decode("utf-8") +def encode_image_bytes(image: bytes) -> str: + image = Image.open(io.BytesIO(image)).convert("RGB") # type: ignore + buffer = io.BytesIO() + image.save(buffer, format="PNG") # type: ignore + encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") return encoded_image +def encode_media(media: Union[str, Path]) -> str: + extension = "png" + extension = Path(media).suffix + if extension.lower() not in { + ".jpg", + ".jpeg", + ".png", + ".webp", + ".bmp", + ".mp4", + ".mov", + }: + raise ValueError(f"Unsupported image extension: {extension}") + + image_bytes = b"" + if extension.lower() in {".mp4", ".mov"}: + frames = T.extract_frames(media) + image = frames[len(frames) // 2] + buffer = io.BytesIO() + Image.fromarray(image[0]).convert("RGB").save(buffer, format="PNG") + image_bytes = buffer.getvalue() + else: + image_bytes = open(media, "rb").read() + return encode_image_bytes(image_bytes) + + TextOrImage = Union[str, List[Union[str, Path]]] Message = Dict[str, TextOrImage] @@ -54,7 +84,7 @@ def __init__( self, model_name: str = "gpt-4o", api_key: Optional[str] = None, - max_tokens: int = 1024, + max_tokens: int = 4096, json_mode: bool = False, **kwargs: Any, ): @@ -97,20 +127,14 @@ def chat( fixed_c = {"role": c["role"]} fixed_c["content"] = [{"type": "text", "text": c["content"]}] # type: ignore if "media" in c: - for image in c["media"]: - extension = Path(image).suffix - if extension.lower() == ".jpeg" or extension.lower() == ".jpg": - extension = "jpg" - elif extension.lower() == ".png": - extension = "png" - else: - raise ValueError(f"Unsupported image extension: {extension}") - encoded_image = encode_image(image) + for media in c["media"]: + encoded_media = encode_media(media) + fixed_c["content"].append( # type: ignore { "type": "image_url", "image_url": { - "url": f"data:image/{extension};base64,{encoded_image}", # type: ignore + "url": f"data:image/png;base64,{encoded_media}", # type: ignore "detail": "low", }, }, @@ -138,13 +162,12 @@ def generate( ] if media and len(media) > 0: for m in media: - extension = Path(m).suffix - encoded_image = encode_image(m) + encoded_media = encode_media(m) message[0]["content"].append( { "type": "image_url", "image_url": { - "url": f"data:image/{extension};base64,{encoded_image}", + "url": f"data:image/png;base64,{encoded_media}", "detail": "low", }, }, @@ -241,7 +264,7 @@ def __init__( api_key: Optional[str] = None, api_version: str = "2024-02-01", azure_endpoint: Optional[str] = None, - max_tokens: int = 1024, + max_tokens: int = 4096, json_mode: bool = False, **kwargs: Any, ): @@ -312,7 +335,7 @@ def chat( fixed_chat = [] for message in chat: if "media" in message: - message["images"] = [encode_image(m) for m in message["media"]] + message["images"] = [encode_media(m) for m in message["media"]] del message["media"] fixed_chat.append(message) url = f"{self.url}/chat" @@ -343,7 +366,7 @@ def generate( json_data = json.dumps(data) if media and len(media) > 0: for m in media: - data["images"].append(encode_image(m)) # type: ignore + data["images"].append(encode_media(m)) # type: ignore response = requests.post(url, data=json_data) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 61d118da..b354274f 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -11,19 +11,19 @@ clip, closest_box_distance, closest_mask_distance, + depth_anything_v2, + detr_segmentation, + dpt_hybrid_midas, extract_frames, florencev2_image_caption, - get_tool_documentation, florencev2_object_detection, - detr_segmentation, - depth_anything_v2, - generate_soft_edge_image, - dpt_hybrid_midas, + florencev2_roberta_vqa, generate_pose_image, + generate_soft_edge_image, + get_tool_documentation, git_vqa_v2, grounding_dino, grounding_sam, - florencev2_roberta_vqa, load_image, loca_visual_prompt_counting, loca_zero_shot_counting, diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index fb9b004b..821064c8 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -362,8 +362,10 @@ def from_exception(exec: Exception, traceback_raw: List[str]) -> "Execution": return Execution( error=Error( name=exec.__class__.__name__, - value=str(exec), - traceback_raw=traceback_raw, + value=_remove_escape_and_color_codes(str(exec)), + traceback_raw=[ + _remove_escape_and_color_codes(line) for line in traceback_raw + ], ) ) @@ -378,8 +380,11 @@ def from_e2b_execution(exec: E2BExecution) -> "Execution": # type: ignore error=( Error( name=exec.error.name, - value=exec.error.value, - traceback_raw=exec.error.traceback_raw, + value=_remove_escape_and_color_codes(exec.error.value), + traceback_raw=[ + _remove_escape_and_color_codes(line) + for line in exec.error.traceback_raw + ], ) if exec.error else None