From 296af2da991273d5b500daa8056354b59ffcc2e0 Mon Sep 17 00:00:00 2001 From: GitHub Actions Bot Date: Fri, 11 Oct 2024 02:57:23 +0000 Subject: [PATCH 1/3] [skip ci] chore(release): vision-agent 0.2.161 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ca6726a7..e541b5c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "vision-agent" -version = "0.2.160" +version = "0.2.161" description = "Toolset for Vision Agent" authors = ["Landing AI "] readme = "README.md" From ca1dc573cff9978e876c5ca22c5c6d06c3b8b099 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 10 Oct 2024 20:50:58 -0700 Subject: [PATCH 2/3] Separate out planner (#256) * separated out planner, renamed chat methods * fixed circular imports * added type for plan context * add planner as separate call to vision agent * export plan context * fixed circular imports * fixed wrong key * better json parsing * more test cases for json parsing * have planner visualize results * add more guard rails to remove double chat * revert changes with planning step for now * revert to original prompts * fix type issue * fix format issue * skip examples for flake8 * fix names and readme * fixed type error * fix countgd integ test * synced code with new code interpreter arg --- .github/workflows/ci_cd.yml | 2 +- README.md | 12 +- docs/index.md | 12 +- examples/chat/app.py | 32 +- tests/integ/test_tools.py | 2 +- tests/unit/test_utils.py | 7 + vision_agent/agent/__init__.py | 8 + vision_agent/agent/agent_utils.py | 78 ++- vision_agent/agent/vision_agent.py | 66 +- vision_agent/agent/vision_agent_coder.py | 652 +++++------------- .../agent/vision_agent_coder_prompts.py | 203 ------ vision_agent/agent/vision_agent_planner.py | 553 +++++++++++++++ .../agent/vision_agent_planner_prompts.py | 199 ++++++ vision_agent/tools/__init__.py | 1 - vision_agent/tools/meta_tools.py | 87 ++- 15 files changed, 1174 insertions(+), 740 deletions(-) create mode 100644 vision_agent/agent/vision_agent_planner.py create mode 100644 vision_agent/agent/vision_agent_planner_prompts.py diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml index ce25f286..17757846 100644 --- a/.github/workflows/ci_cd.yml +++ b/.github/workflows/ci_cd.yml @@ -43,7 +43,7 @@ jobs: - name: Linting run: | # stop the build if there are Python syntax errors or undefined names - poetry run flake8 . --exclude .venv --count --show-source --statistics + poetry run flake8 . --exclude .venv,examples --count --show-source --statistics - name: Check Format run: | poetry run black --check --diff --color . diff --git a/README.md b/README.md index 29292d65..e34e265e 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ continuing, for example it may want to execute code and look at the output befor letting the user respond. ### Chatting and Artifacts -If you run `chat_with_code` you will also notice an `Artifact` object. `Artifact`'s +If you run `chat_with_artifacts` you will also notice an `Artifact` object. `Artifact`'s are a way to sync files between local and remote environments. The agent will read and write to the artifact object, which is just a pickle object, when it wants to save or load files. 
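The heart of [PATCH 2/3] is splitting planning out of `VisionAgentCoder` into a standalone `VisionAgentPlanner` that returns a `PlanContext`, which the coder then consumes. As a minimal sketch of how the two pieces compose after this refactor (a hypothetical usage example, assuming default Anthropic model credentials are configured and a local `workers.png` exists, as in the README example below):

```python
import vision_agent as va

chat = [
    {
        "role": "user",
        "content": "Can you count the number of workers wearing helmets?",
        "media": ["workers.png"],
    }
]

# The planner drafts candidate plans, tests them, and returns a PlanContext
# holding the plans, the chosen best_plan, its thoughts, and tool docs/output.
planner = va.agent.VisionAgentPlanner(verbosity=1)
plan_context = planner.generate_plan(chat)

# The coder turns the selected plan into tested code; calling
# coder.generate_code(chat) would run both steps in one call.
coder = va.agent.VisionAgentCoder()
results = coder.generate_code_from_plan(chat, plan_context)
print(results["code"])
```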
@@ -118,7 +118,7 @@ with open("image.png", "rb") as f: artifacts["image.png"] = f.read() agent = va.agent.VisionAgent() -response, artifacts = agent.chat_with_code( +response, artifacts = agent.chat_with_artifacts( [ { "role": "user", @@ -298,11 +298,11 @@ mode by passing in the verbose argument: ``` ### Detailed Usage -You can also have it return more information by calling `chat_with_workflow`. The format +You can also have it return more information by calling `generate_code`. The format of the input is a list of dictionaries with the keys `role`, `content`, and `media`: ```python ->>> results = agent.chat_with_workflow([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}]) +>>> results = agent.generate_code([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}]) >>> print(results) { "code": "from vision_agent.tools import ..." @@ -331,7 +331,7 @@ conv = [ "media": ["workers.png"], } ] -result = agent.chat_with_workflow(conv) +result = agent.generate_code(conv) code = result["code"] conv.append({"role": "assistant", "content": code}) conv.append( @@ -340,7 +340,7 @@ conv.append( "content": "Can you also return the number of workers wearing safety gear?", } ) -result = agent.chat_with_workflow(conv) +result = agent.generate_code(conv) ``` diff --git a/docs/index.md b/docs/index.md index ee04f3d6..08c808a9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -97,7 +97,7 @@ continuing, for example it may want to execute code and look at the output befor letting the user respond. ### Chatting and Artifacts -If you run `chat_with_code` you will also notice an `Artifact` object. `Artifact`'s +If you run `chat_with_artifacts` you will also notice an `Artifact` object. `Artifact`'s are a way to sync files between local and remote environments. The agent will read and write to the artifact object, which is just a pickle object, when it wants to save or load files. @@ -114,7 +114,7 @@ with open("image.png", "rb") as f: artifacts["image.png"] = f.read() agent = va.agent.VisionAgent() -response, artifacts = agent.chat_with_code( +response, artifacts = agent.chat_with_artifacts( [ { "role": "user", @@ -294,11 +294,11 @@ mode by passing in the verbose argument: ``` ### Detailed Usage -You can also have it return more information by calling `chat_with_workflow`. The format +You can also have it return more information by calling `generate_code`. The format of the input is a list of dictionaries with the keys `role`, `content`, and `media`: ```python ->>> results = agent.chat_with_workflow([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}]) +>>> results = agent.generate_code([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}]) >>> print(results) { "code": "from vision_agent.tools import ..." 
@@ -327,7 +327,7 @@ conv = [ "media": ["workers.png"], } ] -result = agent.chat_with_workflow(conv) +result = agent.generate_code(conv) code = result["code"] conv.append({"role": "assistant", "content": code}) conv.append( @@ -336,7 +336,7 @@ "content": "Can you also return the number of workers wearing safety gear?", } ) -result = agent.chat_with_workflow(conv) +result = agent.generate_code(conv) ``` diff --git a/examples/chat/app.py b/examples/chat/app.py index 0389b2f1..b07e308e 100644 --- a/examples/chat/app.py +++ b/examples/chat/app.py @@ -27,7 +27,7 @@ "style": {"bottom": "calc(50% - 4.25rem", "right": "0.4rem"}, } # set artifacts remote_path to WORKSPACE -artifacts = va.tools.Artifacts(WORKSPACE / "artifacts.pkl") +artifacts = va.tools.meta_tools.Artifacts(WORKSPACE / "artifacts.pkl") if Path("artifacts.pkl").exists(): artifacts.load("artifacts.pkl") else: @@ -109,16 +109,26 @@ def main(): len(st.session_state.messages) == 0 or prompt != st.session_state.messages[-1]["content"] ): - st.session_state.messages.append( - {"role": "user", "content": prompt} - ) - messages.chat_message("user").write(prompt) - message_thread = threading.Thread( - target=update_messages, - args=(st.session_state.messages, message_lock), - ) - message_thread.daemon = True - message_thread.start() + # occasionally the last user message is resent twice, so only append it if it is new + user_messages = [ + msg + for msg in st.session_state.messages + if msg["role"] == "user" + ] + last_user_message = None + if len(user_messages) > 0: + last_user_message = user_messages[-1]["content"] + if last_user_message is None or last_user_message != prompt: + st.session_state.messages.append( + {"role": "user", "content": prompt} + ) + messages.chat_message("user").write(prompt) + message_thread = threading.Thread( + target=update_messages, + args=(st.session_state.messages, message_lock), + ) + message_thread.daemon = True + message_thread.start() st.session_state.input_text = "" with tabs[1]: diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 4f5c674f..9fd9f15c 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -413,4 +413,4 @@ def test_countgd_example_based_counting() -> None: image=img, ) assert len(result) == 24 - assert [res["label"] for res in result] == ["coin"] * 24 + assert [res["label"] for res in result] == ["object"] * 24 diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 73471a30..f82ec4c5 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -63,3 +63,10 @@ def test(): assert "import os" in out assert "!pip install pandas" not in out assert "!pip install dummy" in out + + +def test_chat_agent_case(): + a = """{"thoughts": "The user has chosen to use the plan with owl_v2 and specified a threshold of 0.4. I'll now generate the vision code based on this plan and the user's modification.", "response": "Certainly! I'll generate the code using owl_v2 with a threshold of 0.4 as you requested. Let me create that for you now.\n\ngenerate_vision_code(artifacts, 'count_workers_with_helmets.py', 'Can you write code to count the number of workers wearing helmets?', media=['/Users/dillonlaird/landing.ai/vision-agent/examples/chat/workspace/workers.png'], plan={'thoughts': 'Using owl_v2_image seems most appropriate as it can detect and count multiple objects given a text prompt. 
This tool is specifically designed for object detection tasks like counting workers wearing helmets.', 'instructions': ['Load the image using load_image(\'/Users/dillonlaird/landing.ai/vision-agent/examples/chat/workspace/workers.png\')', 'Use owl_v2_image with the prompt \'worker wearing helmet\' to detect and count workers with helmets', 'Count the number of detections returned by owl_v2_image to get the final count of workers wearing helmets']}, plan_thoughts='Use a threshold of 0.4 as specified by the user', plan_context_artifact='worker_helmet_plan.json')", "let_user_respond": false}""" + a_json = extract_json(a) + assert "thoughts" in a_json + assert "response" in a_json diff --git a/vision_agent/agent/__init__.py b/vision_agent/agent/__init__.py index 793f44cf..d143a2ab 100644 --- a/vision_agent/agent/__init__.py +++ b/vision_agent/agent/__init__.py @@ -7,3 +7,11 @@ OpenAIVisionAgentCoder, VisionAgentCoder, ) +from .vision_agent_planner import ( + AnthropicVisionAgentPlanner, + AzureVisionAgentPlanner, + OllamaVisionAgentPlanner, + OpenAIVisionAgentPlanner, + PlanContext, + VisionAgentPlanner, +) diff --git a/vision_agent/agent/agent_utils.py b/vision_agent/agent/agent_utils.py index 624ad608..9b7ea02a 100644 --- a/vision_agent/agent/agent_utils.py +++ b/vision_agent/agent/agent_utils.py @@ -2,10 +2,17 @@ import logging import re import sys -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional + +from rich.console import Console +from rich.style import Style +from rich.syntax import Syntax + +import vision_agent.tools as T logging.basicConfig(stream=sys.stdout) _LOGGER = logging.getLogger(__name__) +_CONSOLE = Console() def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]: @@ -41,11 +48,16 @@ def _strip_markdown_code(inp_str: str) -> str: def extract_json(json_str: str) -> Dict[str, Any]: json_str_mod = json_str.replace("\n", " ").strip() - json_str_mod = json_str_mod.replace("'", '"') json_str_mod = json_str_mod.replace(": True", ": true").replace( ": False", ": false" ) + # sometimes the json is in single quotes + try: + return json.loads(json_str_mod.replace("'", '"')) # type: ignore + except json.JSONDecodeError: + pass + try: return json.loads(json_str_mod) # type: ignore except json.JSONDecodeError: @@ -83,3 +95,65 @@ def remove_installs_from_code(code: str) -> str: pattern = r"\n!pip install.*?(\n|\Z)\n" code = re.sub(pattern, "", code, flags=re.DOTALL) return code + + +def format_memory(memory: List[Dict[str, str]]) -> str: + output_str = "" + for i, m in enumerate(memory): + output_str += f"### Feedback {i}:\n" + output_str += f"Code {i}:\n```python\n{m['code']}```\n\n" + output_str += f"Feedback {i}: {m['feedback']}\n\n" + if "edits" in m: + output_str += f"Edits {i}:\n{m['edits']}\n" + output_str += "\n" + + return output_str + + +def format_plans(plans: Dict[str, Any]) -> str: + plan_str = "" + for k, v in plans.items(): + plan_str += "\n" + f"{k}: {v['thoughts']}\n" + plan_str += " -" + "\n -".join([e for e in v["instructions"]]) + + return plan_str + + +class DefaultImports: + """Container for default imports used in the code execution.""" + + common_imports = [ + "import os", + "import numpy as np", + "from vision_agent.tools import *", + "from typing import *", + "from pillow_heif import register_heif_opener", + "register_heif_opener()", + ] + + @staticmethod + def to_code_string() -> str: + return "\n".join(DefaultImports.common_imports + T.__new_tools__) + + @staticmethod + def prepend_imports(code: str) -> str: + """Run 
this method to prepend the default imports to the code. + NOTE: be sure to run this method after the custom tools have been registered. + """ + return DefaultImports.to_code_string() + "\n\n" + code + + +def print_code(title: str, code: str, test: Optional[str] = None) -> None: + _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True)) + _CONSOLE.print("=" * 30 + " Code " + "=" * 30) + _CONSOLE.print( + Syntax( + DefaultImports.prepend_imports(code), + "python", + theme="gruvbox-dark", + line_numbers=True, + ) + ) + if test: + _CONSOLE.print("=" * 30 + " Test " + "=" * 30) + _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True)) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 1a38468f..6e1621f0 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -14,8 +14,8 @@ VA_CODE, ) from vision_agent.lmm import LMM, AnthropicLMM, Message, OpenAILMM -from vision_agent.tools import META_TOOL_DOCSTRING from vision_agent.tools.meta_tools import ( + META_TOOL_DOCSTRING, Artifacts, check_and_load_image, use_extra_vision_agent_args, @@ -195,9 +195,8 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: """Initialize the VisionAgent. @@ -207,14 +206,17 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. - code_interpreter (Optional[CodeInterpreter]): if not None, use this CodeInterpreter + callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback + function to send intermediate update messages. + code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values + it can be one of: None, "local" or "e2b". If None, it will read from + the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter + object is provided it will use that. """ self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent self.max_iterations = 12 self.verbosity = verbosity - self.code_sandbox_runtime = code_sandbox_runtime self.code_interpreter = code_interpreter self.callback_message = callback_message if self.verbosity >= 1: @@ -233,7 +235,7 @@ def __call__( input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None, artifacts: Optional[Artifacts] = None, - ) -> List[Message]: + ) -> str: """Chat with VisionAgent and get the conversation response. Parameters: @@ -250,10 +252,28 @@ def __call__( input = [{"role": "user", "content": input}] if media is not None: input[0]["media"] = [media] - results, _ = self.chat_with_code(input, artifacts) - return results + results, _ = self.chat_with_artifacts(input, artifacts) + return results[-1]["content"] # type: ignore + + def chat( + self, + chat: List[Message], + ) -> List[Message]: + """Chat with VisionAgent. It will use code to execute actions to accomplish + its tasks. 
+ + Parameters: + chat (List[Message]): A conversation in the format of: + [{"role": "user", "content": "describe your task here..."}] + or if it contains media files, it should be in the format of: + [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}] + + Returns: + List[Message]: The conversation response. + """ + return self.chat_with_artifacts(chat)[0] - def chat_with_code( + def chat_with_artifacts( self, chat: List[Message], artifacts: Optional[Artifacts] = None, @@ -287,11 +307,13 @@ def chat_with_code( # this is setting remote artifacts path artifacts = Artifacts(WORKSPACE / "artifacts.pkl") + # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues code_interpreter = ( self.code_interpreter if self.code_interpreter is not None + and not isinstance(self.code_interpreter, str) else CodeInterpreterFactory.new_instance( - code_sandbox_runtime=self.code_sandbox_runtime, + code_sandbox_runtime=self.code_interpreter, ) ) with code_interpreter: @@ -480,8 +502,8 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: """Initialize the VisionAgent using OpenAI LMMs. @@ -491,7 +513,12 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. + callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback + function to send intermediate update messages. + code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values + it can be one of: None, "local" or "e2b". If None, it will read from + the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter + object is provided it will use that. """ agent = OpenAILMM(temperature=0.0, json_mode=True) if agent is None else agent @@ -499,8 +526,8 @@ def __init__( agent, verbosity, local_artifacts_path, - code_sandbox_runtime, callback_message, + code_interpreter, ) @@ -510,8 +537,8 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: """Initialize the VisionAgent using Anthropic LMMs. @@ -521,7 +548,12 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. + callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback + function to send intermediate update messages. + code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values + it can be one of: None, "local" or "e2b". If None, it will read from + the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter + object is provided it will use that. 
""" agent = AnthropicLMM(temperature=0.0) if agent is None else agent @@ -529,6 +561,6 @@ def __init__( agent, verbosity, local_artifacts_path, - code_sandbox_runtime, callback_message, + code_interpreter, ) diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 1e5030a2..f1246f09 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -2,32 +2,33 @@ import logging import os import sys -from json import JSONDecodeError from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast -from rich.console import Console -from rich.style import Style -from rich.syntax import Syntax from tabulate import tabulate import vision_agent.tools as T -from vision_agent.agent import Agent +from vision_agent.agent.agent import Agent from vision_agent.agent.agent_utils import ( + DefaultImports, extract_code, extract_json, + format_memory, + print_code, remove_installs_from_code, ) from vision_agent.agent.vision_agent_coder_prompts import ( CODE, FIX_BUG, FULL_TASK, - PICK_PLAN, - PLAN, - PREVIOUS_FAILED, SIMPLE_TEST, - TEST_PLANS, - USER_REQ, +) +from vision_agent.agent.vision_agent_planner import ( + AnthropicVisionAgentPlanner, + AzureVisionAgentPlanner, + OllamaVisionAgentPlanner, + OpenAIVisionAgentPlanner, + PlanContext, ) from vision_agent.lmm import ( LMM, @@ -40,241 +41,11 @@ from vision_agent.tools.meta_tools import get_diff from vision_agent.utils import CodeInterpreterFactory, Execution from vision_agent.utils.execute import CodeInterpreter -from vision_agent.utils.image_utils import b64_to_pil -from vision_agent.utils.sim import AzureSim, OllamaSim, Sim -from vision_agent.utils.video import play_video logging.basicConfig(stream=sys.stdout) WORKSPACE = Path(os.getenv("WORKSPACE", "")) _LOGGER = logging.getLogger(__name__) _MAX_TABULATE_COL_WIDTH = 80 -_CONSOLE = Console() - - -class DefaultImports: - """Container for default imports used in the code execution.""" - - common_imports = [ - "import os", - "import numpy as np", - "from vision_agent.tools import *", - "from typing import *", - "from pillow_heif import register_heif_opener", - "register_heif_opener()", - ] - - @staticmethod - def to_code_string() -> str: - return "\n".join(DefaultImports.common_imports + T.__new_tools__) - - @staticmethod - def prepend_imports(code: str) -> str: - """Run this method to prepend the default imports to the code. - NOTE: be sure to run this method after the custom tools have been registered. 
- """ - return DefaultImports.to_code_string() + "\n\n" + code - - -def format_memory(memory: List[Dict[str, str]]) -> str: - output_str = "" - for i, m in enumerate(memory): - output_str += f"### Feedback {i}:\n" - output_str += f"Code {i}:\n```python\n{m['code']}```\n\n" - output_str += f"Feedback {i}: {m['feedback']}\n\n" - if "edits" in m: - output_str += f"Edits {i}:\n{m['edits']}\n" - output_str += "\n" - - return output_str - - -def format_plans(plans: Dict[str, Any]) -> str: - plan_str = "" - for k, v in plans.items(): - plan_str += "\n" + f"{k}: {v['thoughts']}\n" - plan_str += " -" + "\n -".join([e for e in v["instructions"]]) - - return plan_str - - -def write_plans( - chat: List[Message], - tool_desc: str, - working_memory: str, - model: LMM, -) -> Dict[str, Any]: - chat = copy.deepcopy(chat) - if chat[-1]["role"] != "user": - raise ValueError("Last chat message must be from the user.") - - user_request = chat[-1]["content"] - context = USER_REQ.format(user_request=user_request) - prompt = PLAN.format( - context=context, - tool_desc=tool_desc, - feedback=working_memory, - ) - chat[-1]["content"] = prompt - return extract_json(model(chat, stream=False)) # type: ignore - - -def pick_plan( - chat: List[Message], - plans: Dict[str, Any], - tool_info: str, - model: LMM, - code_interpreter: CodeInterpreter, - media: List[str], - log_progress: Callable[[Dict[str, Any]], None], - verbosity: int = 0, - max_retries: int = 3, -) -> Tuple[Dict[str, str], str]: - log_progress( - { - "type": "log", - "log_content": "Generating code to pick the best plan", - "status": "started", - } - ) - - chat = copy.deepcopy(chat) - if chat[-1]["role"] != "user": - raise ValueError("Last chat message must be from the user.") - - plan_str = format_plans(plans) - prompt = TEST_PLANS.format( - docstring=tool_info, plans=plan_str, previous_attempts="", media=media - ) - - code = extract_code(model(prompt, stream=False)) # type: ignore - log_progress( - { - "type": "log", - "log_content": "Executing code to test plans", - "code": DefaultImports.prepend_imports(code), - "status": "running", - } - ) - tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code)) - # Because of the way we trace function calls the trace information ends up in the - # results. We don't want to show this info to the LLM so we don't include it in the - # tool_output_str. 
- tool_output_str = tool_output.text(include_results=False).strip() - - if verbosity == 2: - _print_code("Initial code and tests:", code) - _LOGGER.info(f"Initial code execution result:\n{tool_output_str}") - - log_progress( - { - "type": "log", - "log_content": ( - "Code execution succeeded" - if tool_output.success - else "Code execution failed" - ), - "code": DefaultImports.prepend_imports(code), - # "payload": tool_output.to_json(), - "status": "completed" if tool_output.success else "failed", - } - ) - - # retry if the tool output is empty or code fails - count = 0 - while ( - not tool_output.success - or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0) - ) and count < max_retries: - prompt = TEST_PLANS.format( - docstring=tool_info, - plans=plan_str, - previous_attempts=PREVIOUS_FAILED.format( - code=code, error="\n".join(tool_output_str.splitlines()[-50:]) - ), - media=media, - ) - log_progress( - { - "type": "log", - "log_content": "Retrying code to test plans", - "status": "running", - "code": DefaultImports.prepend_imports(code), - } - ) - code = extract_code(model(prompt, stream=False)) # type: ignore - tool_output = code_interpreter.exec_isolation( - DefaultImports.prepend_imports(code) - ) - log_progress( - { - "type": "log", - "log_content": ( - "Code execution succeeded" - if tool_output.success - else "Code execution failed" - ), - "code": DefaultImports.prepend_imports(code), - # "payload": tool_output.to_json(), - "status": "completed" if tool_output.success else "failed", - } - ) - tool_output_str = tool_output.text(include_results=False).strip() - - if verbosity == 2: - _print_code("Code and test after attempted fix:", code) - _LOGGER.info(f"Code execution result after attempt {count + 1}") - _LOGGER.info(f"{tool_output_str}") - - count += 1 - - if verbosity >= 1: - _print_code("Final code:", code) - - user_req = chat[-1]["content"] - context = USER_REQ.format(user_request=user_req) - # because the tool picker model gets the image as well, we have to be careful with - # how much text we send it, so we truncate the tool output to 20,000 characters - prompt = PICK_PLAN.format( - context=context, - plans=format_plans(plans), - tool_output=tool_output_str[:20_000], - ) - chat[-1]["content"] = prompt - - count = 0 - plan_thoughts = None - while plan_thoughts is None and count < max_retries: - try: - plan_thoughts = extract_json(model(chat, stream=False)) # type: ignore - except JSONDecodeError as e: - _LOGGER.exception( - f"Error while extracting JSON during picking best plan {str(e)}" - ) - pass - count += 1 - - if ( - plan_thoughts is None - or "best_plan" not in plan_thoughts - or ("best_plan" in plan_thoughts and plan_thoughts["best_plan"] not in plans) - ): - _LOGGER.info(f"Failed to pick best plan. Using the first plan. 
{plan_thoughts}") - plan_thoughts = {"best_plan": list(plans.keys())[0]} - - if "thoughts" not in plan_thoughts: - plan_thoughts["thoughts"] = "" - - if verbosity >= 1: - _LOGGER.info(f"Best plan:\n{plan_thoughts}") - log_progress( - { - "type": "log", - "log_content": "Picked best plan", - "status": "completed", - "payload": plans[plan_thoughts["best_plan"]], - } - ) - return plan_thoughts, "```python\n" + code + "\n```\n" + tool_output_str def write_code( @@ -393,7 +164,7 @@ def write_and_test_code( } ) if verbosity == 2: - _print_code("Initial code and tests:", code, test) + print_code("Initial code and tests:", code, test) _LOGGER.info( f"Initial code execution result:\n{result.text(include_logs=True)}" ) @@ -418,7 +189,7 @@ def write_and_test_code( count += 1 if verbosity >= 1: - _print_code("Final code and tests:", code, test) + print_code("Final code and tests:", code, test) return { "code": code, @@ -537,7 +308,7 @@ def debug_code( } ) if verbosity == 2: - _print_code("Code and test after attempted fix:", code, test) + print_code("Code and test after attempted fix:", code, test) _LOGGER.info( f"Reflection: {fixed_code_and_test['reflections']}\nCode execution result after attempted fix: {result.text(include_logs=True)}" ) @@ -545,62 +316,6 @@ def debug_code( return code, test, result -def _print_code(title: str, code: str, test: Optional[str] = None) -> None: - _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True)) - _CONSOLE.print("=" * 30 + " Code " + "=" * 30) - _CONSOLE.print( - Syntax( - DefaultImports.prepend_imports(code), - "python", - theme="gruvbox-dark", - line_numbers=True, - ) - ) - if test: - _CONSOLE.print("=" * 30 + " Test " + "=" * 30) - _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True)) - - -def retrieve_tools( - plans: Dict[str, Dict[str, Any]], - tool_recommender: Sim, - log_progress: Callable[[Dict[str, Any]], None], - verbosity: int = 0, -) -> Dict[str, str]: - log_progress( - { - "type": "log", - "log_content": ("Retrieving tools for each plan"), - "status": "started", - } - ) - tool_info = [] - tool_desc = [] - tool_lists: Dict[str, List[Dict[str, str]]] = {} - for k, plan in plans.items(): - tool_lists[k] = [] - for task in plan["instructions"]: - tools = tool_recommender.top_k(task, k=2, thresh=0.3) - tool_info.extend([e["doc"] for e in tools]) - tool_desc.extend([e["desc"] for e in tools]) - tool_lists[k].extend( - {"description": e["desc"], "documentation": e["doc"]} for e in tools - ) - - if verbosity == 2: - tool_desc_str = "\n".join(set(tool_desc)) - _LOGGER.info(f"Tools Description:\n{tool_desc_str}") - - tool_lists_unique = {} - for k in tool_lists: - tool_lists_unique[k] = "\n\n".join( - set(e["documentation"] for e in tool_lists[k]) - ) - all_tools = "\n\n".join(set(tool_info)) - tool_lists_unique["all"] = all_tools - return tool_lists_unique - - class VisionAgentCoder(Agent): """Vision Agent Coder is an agentic framework that can output code based on a user request. 
It can plan tasks, retrieve relevant tools, write code, write tests and @@ -616,23 +331,22 @@ class VisionAgentCoder(Agent): def __init__( self, - planner: Optional[LMM] = None, + planner: Optional[Agent] = None, coder: Optional[LMM] = None, tester: Optional[LMM] = None, debugger: Optional[LMM] = None, - tool_recommender: Optional[Sim] = None, verbosity: int = 0, report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, - code_sandbox_runtime: Optional[str] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: """Initialize the Vision Agent Coder. Parameters: - planner (Optional[LMM]): The planner model to use. Defaults to AnthropicLMM. + planner (Optional[Agent]): The planner model to use. Defaults to + AnthropicVisionAgentPlanner. coder (Optional[LMM]): The coder model to use. Defaults to AnthropicLMM. tester (Optional[LMM]): The tester model to use. Defaults to AnthropicLMM. debugger (Optional[LMM]): The debugger model to use. Defaults to AnthropicLMM. - tool_recommender (Optional[Sim]): The tool recommender model to use. verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the highest verbosity level which will output all intermediate debugging code. @@ -641,14 +355,17 @@ def __init__( in a web application where multiple VisionAgentCoder instances are running in parallel. This callback ensures that the progress are not mixed up. - code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A - code sandbox is used to run the generated code. It can be one of the - following values: None, "local" or "e2b". If None, VisionAgentCoder - will read the value from the environment variable CODE_SANDBOX_RUNTIME. - If it's also None, the local python runtime environment will be used. + code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values + it can be one of: None, "local" or "e2b". If None, it will read from + the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter + object is provided it will use that. """ - self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner + self.planner = ( + AnthropicVisionAgentPlanner(verbosity=verbosity) + if planner is None + else planner + ) self.coder = AnthropicLMM(temperature=0.0) if coder is None else coder self.tester = AnthropicLMM(temperature=0.0) if tester is None else tester self.debugger = AnthropicLMM(temperature=0.0) if debugger is None else debugger @@ -656,21 +373,15 @@ def __init__( if self.verbosity > 0: _LOGGER.setLevel(logging.INFO) - self.tool_recommender = ( - Sim(T.TOOLS_DF, sim_key="desc") - if tool_recommender is None - else tool_recommender - ) self.report_progress_callback = report_progress_callback - self.code_sandbox_runtime = code_sandbox_runtime + self.code_interpreter = code_interpreter def __call__( self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None, ) -> str: - """Chat with VisionAgentCoder and return intermediate information regarding the - task. + """Generate code based on a user request. 
Parameters: input (Union[str, List[Message]]): A conversation in the format of @@ -686,46 +397,58 @@ def __call__( input = [{"role": "user", "content": input}] if media is not None: input[0]["media"] = [media] - results = self.chat_with_workflow(input) - results.pop("working_memory") - return results["code"] # type: ignore + code_and_context = self.generate_code(input) + return code_and_context["code"] # type: ignore - def chat_with_workflow( + def generate_code_from_plan( self, chat: List[Message], - test_multi_plan: bool = True, - display_visualization: bool = False, - custom_tool_names: Optional[List[str]] = None, + plan_context: PlanContext, + code_interpreter: Optional[CodeInterpreter] = None, ) -> Dict[str, Any]: - """Chat with VisionAgentCoder and return intermediate information regarding the - task. + """Generates code and other intermediate outputs from a chat input and a plan. + The plan includes: + - plans: The plans generated by the planner. + - best_plan: The best plan selected by the planner. + - plan_thoughts: The thoughts of the planner, including any modifications + to the plan. + - tool_doc: The tool documentation for the best plan. + - tool_output: The tool output from the tools used by the best plan. Parameters: - chat (List[Message]): A conversation - in the format of: - [{"role": "user", "content": "describe your task here..."}] - or if it contains media files, it should be in the format of: - [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}] - test_multi_plan (bool): If True, it will test tools for multiple plans and - pick the best one based off of the tool results. If False, it will go - with the first plan. - display_visualization (bool): If True, it opens a new window locally to - show the image(s) created by visualization code (if there is any). - custom_tool_names (List[str]): A list of custom tools for the agent to pick - and use. If not provided, default to full tool set from vision_agent.tools. + chat (List[Message]): A conversation in the format of + [{"role": "user", "content": "describe your task here..."}]. + plan_context (PlanContext): The context of the plan, including the plans, + best_plan, plan_thoughts, tool_doc, and tool_output. + code_interpreter (Optional[CodeInterpreter]): The code interpreter to use + for running the generated code. If None, a new instance is + created for the duration of the call. Returns: - Dict[str, Any]: A dictionary containing the code, test, test result, plan, - and working memory of the agent. + Dict[str, Any]: A dictionary containing the code output by the + VisionAgentCoder and other intermediate outputs, including: + - status (str): Whether the agent completed or failed to generate + the code. + - code (str): The code output by the VisionAgentCoder. + - test (str): The test output by the VisionAgentCoder. + - test_result (Execution): The result of the test execution. + - plans (Dict[str, Any]): The plans generated by the planner. + - plan_thoughts (str): The thoughts of the planner. + - working_memory (List[Dict[str, str]]): The working memory of the agent. 
""" - if not chat: raise ValueError("Chat cannot be empty.") # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues - with CodeInterpreterFactory.new_instance( - code_sandbox_runtime=self.code_sandbox_runtime - ) as code_interpreter: + code_interpreter = ( + self.code_interpreter + if self.code_interpreter is not None + and not isinstance(self.code_interpreter, str) + else CodeInterpreterFactory.new_instance( + code_sandbox_runtime=self.code_interpreter, + ) + ) + with code_interpreter: chat = copy.deepcopy(chat) media_list = [] for chat_i in chat: @@ -759,74 +482,22 @@ def chat_with_workflow( code = "" test = "" working_memory: List[Dict[str, str]] = [] - results = {"code": "", "test": "", "plan": []} - plan = [] - success = False - - plans = self._create_plans( - int_chat, custom_tool_names, working_memory, self.planner - ) - - if test_multi_plan: - self._log_plans(plans, self.verbosity) - - tool_infos = retrieve_tools( - plans, - self.tool_recommender, - self.log_progress, - self.verbosity, - ) - - if test_multi_plan: - plan_thoughts, tool_output_str = pick_plan( - int_chat, - plans, - tool_infos["all"], - self.coder, - code_interpreter, - media_list, - self.log_progress, - verbosity=self.verbosity, - ) - best_plan = plan_thoughts["best_plan"] - plan_thoughts_str = plan_thoughts["thoughts"] - else: - best_plan = list(plans.keys())[0] - tool_output_str = "" - plan_thoughts_str = "" - - if best_plan in plans and best_plan in tool_infos: - plan_i = plans[best_plan] - tool_info = tool_infos[best_plan] - else: - if self.verbosity >= 1: - _LOGGER.warning( - f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info." - ) - k = list(plans.keys())[0] - plan_i = plans[k] - tool_info = tool_infos[k] - - self.log_progress( - { - "type": "log", - "log_content": "Creating plans", - "status": "completed", - "payload": tool_info, - } - ) + plan = plan_context.plans[plan_context.best_plan] + tool_doc = plan_context.tool_doc + tool_output_str = plan_context.tool_output + plan_thoughts_str = str(plan_context.plan_thoughts) if self.verbosity >= 1: - plan_i_fixed = [{"instructions": e} for e in plan_i["instructions"]] + plan_fixed = [{"instructions": e} for e in plan["instructions"]] _LOGGER.info( - f"Picked best plan:\n{tabulate(tabular_data=plan_i_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + f"Picked best plan:\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) results = write_and_test_code( chat=[{"role": c["role"], "content": c["content"]} for c in int_chat], - plan=f"\n{plan_i['thoughts']}\n-" - + "\n-".join([e for e in plan_i["instructions"]]), - tool_info=tool_info, + plan=f"\n{plan['thoughts']}\n-" + + "\n-".join([e for e in plan["instructions"]]), + tool_info=tool_doc, tool_output=tool_output_str, plan_thoughts=plan_thoughts_str, tool_utils=T.UTILITIES_DOCSTRING, @@ -842,64 +513,83 @@ def chat_with_workflow( success = cast(bool, results["success"]) code = remove_installs_from_code(cast(str, results["code"])) test = remove_installs_from_code(cast(str, results["test"])) - working_memory.extend(results["working_memory"]) # type: ignore - plan.append({"code": code, "test": test, "plan": plan_i}) + working_memory.extend(results["working_memory"]) execution_result = cast(Execution, results["test_result"]) - if display_visualization: - for res in execution_result.results: - if res.png: - b64_to_pil(res.png).show() - if 
res.mp4: - play_video(res.mp4) - - return { "status": "completed" if success else "failed", "code": DefaultImports.prepend_imports(code), "test": test, "test_result": execution_result, - "plans": plans, + "plans": plan_context.plans, "plan_thoughts": plan_thoughts_str, "working_memory": working_memory, } - def log_progress(self, data: Dict[str, Any]) -> None: - if self.report_progress_callback is not None: - self.report_progress_callback(data) - - def _create_plans( + def generate_code( self, - int_chat: List[Message], - customized_tool_names: Optional[List[str]], - working_memory: List[Dict[str, str]], - planner: LMM, + chat: List[Message], + test_multi_plan: bool = True, + custom_tool_names: Optional[List[str]] = None, ) -> Dict[str, Any]: - self.log_progress( - { - "type": "log", - "log_content": "Creating plans", - "status": "started", - } - ) - plans = write_plans( - int_chat, - T.get_tool_descriptions_by_names( - customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore - ), - format_memory(working_memory), - planner, + """Generates code and other intermediate outputs from a chat input. + + Parameters: + chat (List[Message]): A conversation in the format of + [{"role": "user", "content": "describe your task here..."}]. + test_multi_plan (bool): Whether to test multiple plans or just the best plan. + custom_tool_names (Optional[List[str]]): A list of custom tool names to use + for the planner. + + Returns: + Dict[str, Any]: A dictionary containing the code output by the + VisionAgentCoder and other intermediate outputs, including: + - status (str): Whether the agent completed or failed to generate + the code. + - code (str): The code output by the VisionAgentCoder. + - test (str): The test output by the VisionAgentCoder. + - test_result (Execution): The result of the test execution. + - plans (Dict[str, Any]): The plans generated by the planner. + - plan_thoughts (str): The thoughts of the planner. + - working_memory (List[Dict[str, str]]): The working memory of the agent. 
+ """ + if not chat: + raise ValueError("Chat cannot be empty.") + + # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues + code_interpreter = ( + self.code_interpreter + if self.code_interpreter is not None + and not isinstance(self.code_interpreter, str) + else CodeInterpreterFactory.new_instance( + code_sandbox_runtime=self.code_interpreter, + ) ) - return plans + with code_interpreter: + plan_context = self.planner.generate_plan( # type: ignore + chat, + test_multi_plan=test_multi_plan, + custom_tool_names=custom_tool_names, + code_interpreter=code_interpreter, + ) - def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None: - if verbosity >= 1: - for p in plans: - # tabulate will fail if the keys are not the same for all elements - p_fixed = [{"instructions": e} for e in plans[p]["instructions"]] - _LOGGER.info( - f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" - ) + code_and_context = self.generate_code_from_plan( + chat, + plan_context, + code_interpreter=code_interpreter, + ) + return code_and_context + + def chat(self, chat: List[Message]) -> List[Message]: + chat = copy.deepcopy(chat) + code = self.generate_code(chat) + chat.append({"role": "agent", "content": code["code"]}) + return chat + + def log_progress(self, data: Dict[str, Any]) -> None: + if self.report_progress_callback is not None: + self.report_progress_callback(data) class OpenAIVisionAgentCoder(VisionAgentCoder): @@ -907,17 +597,18 @@ class OpenAIVisionAgentCoder(VisionAgentCoder): def __init__( self, - planner: Optional[LMM] = None, + planner: Optional[Agent] = None, coder: Optional[LMM] = None, tester: Optional[LMM] = None, debugger: Optional[LMM] = None, - tool_recommender: Optional[Sim] = None, verbosity: int = 0, report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, - code_sandbox_runtime: Optional[str] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: self.planner = ( - OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner + OpenAIVisionAgentPlanner(verbosity=verbosity) + if planner is None + else planner ) self.coder = OpenAILMM(temperature=0.0) if coder is None else coder self.tester = OpenAILMM(temperature=0.0) if tester is None else tester @@ -926,13 +617,8 @@ def __init__( if self.verbosity > 0: _LOGGER.setLevel(logging.INFO) - self.tool_recommender = ( - Sim(T.TOOLS_DF, sim_key="desc") - if tool_recommender is None - else tool_recommender - ) self.report_progress_callback = report_progress_callback - self.code_sandbox_runtime = code_sandbox_runtime + self.code_interpreter = code_interpreter class AnthropicVisionAgentCoder(VisionAgentCoder): @@ -940,17 +626,20 @@ class AnthropicVisionAgentCoder(VisionAgentCoder): def __init__( self, - planner: Optional[LMM] = None, + planner: Optional[Agent] = None, coder: Optional[LMM] = None, tester: Optional[LMM] = None, debugger: Optional[LMM] = None, - tool_recommender: Optional[Sim] = None, verbosity: int = 0, report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, - code_sandbox_runtime: Optional[str] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: # NOTE: Claude doesn't have an official JSON mode - self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner + self.planner = ( + AnthropicVisionAgentPlanner(verbosity=verbosity) + if planner is None + else planner + ) self.coder = 
AnthropicLMM(temperature=0.0) if coder is None else coder self.tester = AnthropicLMM(temperature=0.0) if tester is None else tester self.debugger = AnthropicLMM(temperature=0.0) if debugger is None else debugger @@ -958,15 +647,8 @@ def __init__( if self.verbosity > 0: _LOGGER.setLevel(logging.INFO) - # Anthropic does not offer any embedding models and instead recomends Voyage, - # we're using OpenAI's embedder for now. - self.tool_recommender = ( - Sim(T.TOOLS_DF, sim_key="desc") - if tool_recommender is None - else tool_recommender - ) self.report_progress_callback = report_progress_callback - self.code_sandbox_runtime = code_sandbox_runtime + self.code_interpreter = code_interpreter class OllamaVisionAgentCoder(VisionAgentCoder): @@ -988,17 +670,17 @@ class OllamaVisionAgentCoder(VisionAgentCoder): def __init__( self, - planner: Optional[LMM] = None, + planner: Optional[Agent] = None, coder: Optional[LMM] = None, tester: Optional[LMM] = None, debugger: Optional[LMM] = None, - tool_recommender: Optional[Sim] = None, verbosity: int = 0, report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: super().__init__( planner=( - OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True) + OllamaVisionAgentPlanner(verbosity=verbosity) if planner is None else planner ), @@ -1017,13 +699,9 @@ def __init__( if debugger is None else debugger ), - tool_recommender=( - OllamaSim(T.TOOLS_DF, sim_key="desc") - if tool_recommender is None - else tool_recommender - ), verbosity=verbosity, report_progress_callback=report_progress_callback, + code_interpreter=code_interpreter, ) @@ -1043,22 +721,22 @@ class AzureVisionAgentCoder(VisionAgentCoder): def __init__( self, - planner: Optional[LMM] = None, + planner: Optional[Agent] = None, coder: Optional[LMM] = None, tester: Optional[LMM] = None, debugger: Optional[LMM] = None, - tool_recommender: Optional[Sim] = None, verbosity: int = 0, report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: """Initialize the Vision Agent Coder. Parameters: - planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM. + planner (Optional[Agent]): The planner model to use. Defaults to + AzureVisionAgentPlanner. coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM. tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM. debugger (Optional[LMM]): The debugger model to - tool_recommender (Optional[Sim]): The tool recommender model to use. verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the highest verbosity level which will output all intermediate debugging code. 
@@ -1069,7 +747,7 @@ def __init__( """ super().__init__( planner=( - AzureOpenAILMM(temperature=0.0, json_mode=True) + AzureVisionAgentPlanner(verbosity=verbosity) if planner is None else planner ), @@ -1078,11 +756,7 @@ def __init__( debugger=( AzureOpenAILMM(temperature=0.0) if debugger is None else debugger ), - tool_recommender=( - AzureSim(T.TOOLS_DF, sim_key="desc") - if tool_recommender is None - else tool_recommender - ), verbosity=verbosity, report_progress_callback=report_progress_callback, + code_interpreter=code_interpreter, ) diff --git a/vision_agent/agent/vision_agent_coder_prompts.py b/vision_agent/agent/vision_agent_coder_prompts.py index 07f2c6e2..66eb4c29 100644 --- a/vision_agent/agent/vision_agent_coder_prompts.py +++ b/vision_agent/agent/vision_agent_coder_prompts.py @@ -1,8 +1,3 @@ -USER_REQ = """ -## User Request -{user_request} -""" - FULL_TASK = """ ## User Request {user_request} @@ -18,204 +13,6 @@ """ -PLAN = """ -**Context**: -{context} - -**Tools Available**: -{tool_desc} - -**Previous Feedback**: -{feedback} - -**Instructions**: -1. Based on the context and tools you have available, create a plan of subtasks to achieve the user request. -2. For each subtask, be sure to include the tool(s) you want to use to accomplish that subtask. -3. Output three different plans each utilize a different strategy or set of tools ordering them from most likely to least likely to succeed. - -Output a list of jsons in the following format: - -```json -{{ - "plan1": - {{ - "thoughts": str # your thought process for choosing this plan - "instructions": [ - str # what you should do in this task associated with a tool - ] - }}, - "plan2": ..., - "plan3": ... -}} -``` -""" - - -TEST_PLANS = """ -**Role**: You are a software programmer responsible for testing different tools. - -**Task**: Your responsibility is to take a set of several plans and test the different tools for each plan. - -**Documentation**: -This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task. They are available through importing `from vision_agent.tools import *`. - -{docstring} - -**Plans**: -{plans} - -**Previous Attempts**: -{previous_attempts} - -**Examples**: ---- EXAMPLE1 --- -plan1: -- Load the image from the provided file path 'image.jpg'. -- Use the 'owl_v2_image' tool with the prompt 'person' to detect and count the number of people in the image. -plan2: -- Load the image from the provided file path 'image.jpg'. -- Use the 'florence2_sam2_image' tool with the prompt 'person' to detect and count the number of people in the image. -- Count the number of detected objects labeled as 'person'. -plan3: -- Load the image from the provided file path 'image.jpg'. -- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people. 
- -```python -from vision_agent.tools import load_image, owl_v2_image, florence2_sam2_image, countgd_counting -image = load_image("image.jpg") -owl_v2_out = owl_v2_image("person", image) - -f2s2_out = florence2_sam2_image("person", image) -# strip out the masks from the output becuase they don't provide useful information when printed -f2s2_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in f2s2_out] - -cgd_out = countgd_counting(image) - -final_out = {{"owl_v2_image": owl_v2_out, "florence2_sam2_image": f2s2, "countgd_counting": cgd_out}} -print(final_out) ---- END EXAMPLE1 --- - ---- EXAMPLE2 --- -plan1: -- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames_and_timestamps' tool. -- Use the 'owl_v2_video' tool with the prompt 'person' to detect where the people are in the video. -plan2: -- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames_and_timestamps' tool. -- Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video. -plan3: -- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames_and_timestamps' tool. -- Use the 'florence2_sam2_video_tracking' tool with the prompt 'person' to detect where the people are in the video. - - -```python -import numpy as np -from vision_agent.tools import extract_frames_and_timestamps, owl_v2_video, florence2_phrase_grounding, florence2_sam2_video_tracking - -# sample at 1 FPS and use the first 10 frames to reduce processing time -frames = extract_frames_and_timestamps("video.mp4", 1) -frames = [f["frame"] for f in frames][:10] - -# strip arrays from the output to make it easier to read -def remove_arrays(o): - if isinstance(o, list): - return [remove_arrays(e) for e in o] - elif isinstance(o, dict): - return {{k: remove_arrays(v) for k, v in o.items()}} - elif isinstance(o, np.ndarray): - return "array: " + str(o.shape) - else: - return o - -# return the counts of each label per frame to help determine the stability of the model results -def get_counts(preds): - counts = {{}} - for i, pred_frame in enumerate(preds): - counts_i = {{}} - for pred in pred_frame: - label = pred["label"].split(":")[1] if ":" in pred["label"] else pred["label"] - counts_i[label] = counts_i.get(label, 0) + 1 - counts[f"frame_{{i}}"] = counts_i - return counts - - -# plan1 -owl_v2_out = owl_v2_video("person", frames) -owl_v2_counts = get_counts(owl_v2_out) - -# plan2 -florence2_out = [florence2_phrase_grounding("person", f) for f in frames] -florence2_counts = get_counts(florence2_out) - -# plan3 -f2s2_tracking_out = florence2_sam2_video_tracking("person", frames) -remove_arrays(f2s2_tracking_out) -f2s2_counts = get_counts(f2s2_tracking_out) - -final_out = {{ - "owl_v2_video": owl_v2_out, - "florence2_phrase_grounding": florence2_out, - "florence2_sam2_video_tracking": f2s2_out, -}} - -counts = {{ - "owl_v2_video": owl_v2_counts, - "florence2_phrase_grounding": florence2_counts, - "florence2_sam2_video_tracking": f2s2_counts, -}} - -print(final_out) -print(labels_and_scores) -print(counts) -``` ---- END EXAMPLE2 --- - -**Instructions**: -1. Write a program to load the media and call each tool and print it's output along with other relevant information. -2. Create a dictionary where the keys are the tool name and the values are the tool outputs. Remove numpy arrays from the printed dictionary. -3. Your test case MUST run only on the given images which are {media} -4. Print this final dictionary. -5. 
For video input, sample at 1 FPS and use the first 10 frames only to reduce processing time. -""" - - -PREVIOUS_FAILED = """ -**Previous Failed Attempts**: -You previously ran this code: -```python -{code} -``` - -But got the following error or no stdout: -{error} -""" - - -PICK_PLAN = """ -**Role**: You are an advanced AI model that can understand the user request and construct plans to accomplish it. - -**Task**: Your responsibility is to pick the best plan from the three plans provided. - -**Context**: -{context} - -**Plans**: -{plans} - -**Tool Output**: -{tool_output} - -**Instructions**: -1. Re-read the user request, plans, tool outputs and examine the image. -2. Solve the problem yourself given the image and pick the most accurate plan that matches your solution the best. -3. Add modifications to improve the plan including: changing a tool, adding thresholds, string matching. -3. Output a JSON object with the following format: -{{ - "predicted_answer": str # the answer you would expect from the best plan - "thoughts": str # your thought process for choosing the best plan over other plans and any modifications you made - "best_plan": str # the best plan you have chosen, must be `plan1`, `plan2`, or `plan3` -}} -""" - CODE = """ **Role**: You are a software programmer. diff --git a/vision_agent/agent/vision_agent_planner.py b/vision_agent/agent/vision_agent_planner.py new file mode 100644 index 00000000..bb7ac3ba --- /dev/null +++ b/vision_agent/agent/vision_agent_planner.py @@ -0,0 +1,553 @@ +import copy +import logging +from json import JSONDecodeError +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast + +from pydantic import BaseModel + +import vision_agent.tools as T +from vision_agent.agent import Agent +from vision_agent.agent.agent_utils import ( + DefaultImports, + extract_code, + extract_json, + format_memory, + format_plans, + print_code, +) +from vision_agent.agent.vision_agent_planner_prompts import ( + PICK_PLAN, + PLAN, + PREVIOUS_FAILED, + TEST_PLANS, + USER_REQ, +) +from vision_agent.lmm import ( + LMM, + AnthropicLMM, + AzureOpenAILMM, + Message, + OllamaLMM, + OpenAILMM, +) +from vision_agent.utils.execute import ( + CodeInterpreter, + CodeInterpreterFactory, + Execution, +) +from vision_agent.utils.sim import AzureSim, OllamaSim, Sim + +_LOGGER = logging.getLogger(__name__) + + +class PlanContext(BaseModel): + plans: Dict[str, Dict[str, Union[str, List[str]]]] + best_plan: str + plan_thoughts: str + tool_output: str + tool_doc: str + test_results: Optional[Execution] + + +def retrieve_tools( + plans: Dict[str, Dict[str, Any]], + tool_recommender: Sim, + log_progress: Callable[[Dict[str, Any]], None], + verbosity: int = 0, +) -> Dict[str, str]: + log_progress( + { + "type": "log", + "log_content": ("Retrieving tools for each plan"), + "status": "started", + } + ) + tool_info = [] + tool_desc = [] + tool_lists: Dict[str, List[Dict[str, str]]] = {} + for k, plan in plans.items(): + tool_lists[k] = [] + for task in plan["instructions"]: + tools = tool_recommender.top_k(task, k=2, thresh=0.3) + tool_info.extend([e["doc"] for e in tools]) + tool_desc.extend([e["desc"] for e in tools]) + tool_lists[k].extend( + {"description": e["desc"], "documentation": e["doc"]} for e in tools + ) + + if verbosity == 2: + tool_desc_str = "\n".join(set(tool_desc)) + _LOGGER.info(f"Tools Description:\n{tool_desc_str}") + + tool_lists_unique = {} + for k in tool_lists: + tool_lists_unique[k] = "\n\n".join( + set(e["documentation"] for e in 
tool_lists[k]) + ) + all_tools = "\n\n".join(set(tool_info)) + tool_lists_unique["all"] = all_tools + return tool_lists_unique + + +def write_plans( + chat: List[Message], tool_desc: str, working_memory: str, model: LMM +) -> Dict[str, Any]: + chat = copy.deepcopy(chat) + if chat[-1]["role"] != "user": + raise ValueError("Last message in chat must be from user") + + user_request = chat[-1]["content"] + context = USER_REQ.format(user_request=user_request) + prompt = PLAN.format( + context=context, + tool_desc=tool_desc, + feedback=working_memory, + ) + chat[-1]["content"] = prompt + return extract_json(model(chat, stream=False)) # type: ignore + + +def write_and_exec_plan_tests( + plans: Dict[str, Any], + tool_info: str, + media: List[str], + model: LMM, + log_progress: Callable[[Dict[str, Any]], None], + code_interpreter: CodeInterpreter, + verbosity: int = 0, + max_retries: int = 3, +) -> Tuple[str, Execution]: + + plan_str = format_plans(plans) + prompt = TEST_PLANS.format( + docstring=tool_info, plans=plan_str, previous_attempts="", media=media + ) + + code = extract_code(model(prompt, stream=False)) # type: ignore + log_progress( + { + "type": "log", + "log_content": "Executing code to test plans", + "code": DefaultImports.prepend_imports(code), + "status": "running", + } + ) + tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code)) + # Because of the way we trace function calls the trace information ends up in the + # results. We don't want to show this info to the LLM so we don't include it in the + # tool_output_str. + tool_output_str = tool_output.text(include_results=False).strip() + + if verbosity == 2: + print_code("Initial code and tests:", code) + _LOGGER.info(f"Initial code execution result:\n{tool_output_str}") + + log_progress( + { + "type": "log", + "log_content": ( + "Code execution succeeded" + if tool_output.success + else "Code execution failed" + ), + "code": DefaultImports.prepend_imports(code), + # "payload": tool_output.to_json(), + "status": "completed" if tool_output.success else "failed", + } + ) + + # retry if the tool output is empty or code fails + count = 0 + tool_output_str = tool_output.text(include_results=False).strip() + while ( + not tool_output.success + or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0) + ) and count < max_retries: + prompt = TEST_PLANS.format( + docstring=tool_info, + plans=plan_str, + previous_attempts=PREVIOUS_FAILED.format( + code=code, error="\n".join(tool_output_str.splitlines()[-50:]) + ), + media=media, + ) + log_progress( + { + "type": "log", + "log_content": "Retrying code to test plans", + "status": "running", + "code": DefaultImports.prepend_imports(code), + } + ) + code = extract_code(model(prompt, stream=False)) # type: ignore + tool_output = code_interpreter.exec_isolation( + DefaultImports.prepend_imports(code) + ) + log_progress( + { + "type": "log", + "log_content": ( + "Code execution succeeded" + if tool_output.success + else "Code execution failed" + ), + "code": DefaultImports.prepend_imports(code), + # "payload": tool_output.to_json(), + "status": "completed" if tool_output.success else "failed", + } + ) + tool_output_str = tool_output.text(include_results=False).strip() + + if verbosity == 2: + print_code("Code and test after attempted fix:", code) + _LOGGER.info(f"Code execution result after attempt {count + 1}") + _LOGGER.info(f"{tool_output_str}") + + count += 1 + + return code, tool_output + + +def write_plan_thoughts( + chat: List[Message], + plans: 
Dict[str, Any], + tool_output_str: str, + model: LMM, + max_retries: int = 3, +) -> Dict[str, str]: + user_req = chat[-1]["content"] + context = USER_REQ.format(user_request=user_req) + # because the tool picker model gets the image as well, we have to be careful with + # how much text we send it, so we truncate the tool output to 20,000 characters + prompt = PICK_PLAN.format( + context=context, + plans=format_plans(plans), + tool_output=tool_output_str[:20_000], + ) + chat[-1]["content"] = prompt + count = 0 + + plan_thoughts = None + while plan_thoughts is None and count < max_retries: + try: + plan_thoughts = extract_json(model(chat, stream=False)) # type: ignore + except JSONDecodeError as e: + _LOGGER.exception( + f"Error while extracting JSON during picking best plan {str(e)}" + ) + pass + count += 1 + + if ( + plan_thoughts is None + or "best_plan" not in plan_thoughts + or ("best_plan" in plan_thoughts and plan_thoughts["best_plan"] not in plans) + ): + _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}") + plan_thoughts = {"best_plan": list(plans.keys())[0]} + + if "thoughts" not in plan_thoughts: + plan_thoughts["thoughts"] = "" + return plan_thoughts + + +def pick_plan( + chat: List[Message], + plans: Dict[str, Any], + tool_info: str, + model: LMM, + code_interpreter: CodeInterpreter, + media: List[str], + log_progress: Callable[[Dict[str, Any]], None], + verbosity: int = 0, + max_retries: int = 3, +) -> Tuple[Dict[str, str], str, Execution]: + log_progress( + { + "type": "log", + "log_content": "Generating code to pick the best plan", + "status": "started", + } + ) + + chat = copy.deepcopy(chat) + if chat[-1]["role"] != "user": + raise ValueError("Last chat message must be from the user.") + + code, tool_output = write_and_exec_plan_tests( + plans, + tool_info, + media, + model, + log_progress, + code_interpreter, + verbosity, + max_retries, + ) + + if verbosity >= 1: + print_code("Final code:", code) + + plan_thoughts = write_plan_thoughts( + chat, + plans, + tool_output.text(include_results=False).strip(), + model, + max_retries, + ) + + if verbosity >= 1: + _LOGGER.info(f"Best plan:\n{plan_thoughts}") + log_progress( + { + "type": "log", + "log_content": "Picked best plan", + "status": "completed", + "payload": plans[plan_thoughts["best_plan"]], + } + ) + # return plan_thoughts, "```python\n" + code + "\n```\n" + tool_output_str + return plan_thoughts, code, tool_output + + +class VisionAgentPlanner(Agent): + def __init__( + self, + planner: Optional[LMM] = None, + tool_recommender: Optional[Sim] = None, + verbosity: int = 0, + report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, + ) -> None: + self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner + self.verbosity = verbosity + if self.verbosity > 0: + _LOGGER.setLevel(logging.INFO) + + self.tool_recommender = ( + Sim(T.TOOLS_DF, sim_key="desc") + if tool_recommender is None + else tool_recommender + ) + self.report_progress_callback = report_progress_callback + self.code_interpreter = code_interpreter + + def __call__( + self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None + ) -> str: + if isinstance(input, str): + input = [{"role": "user", "content": input}] + if media is not None: + input[0]["media"] = [media] + planning_context = self.generate_plan(input) + return str(planning_context.plans[planning_context.best_plan]) + + def generate_plan( + self, + 
chat: List[Message], + test_multi_plan: bool = True, + custom_tool_names: Optional[List[str]] = None, + code_interpreter: Optional[CodeInterpreter] = None, + ) -> PlanContext: + if not chat: + raise ValueError("Chat cannot be empty") + + code_interpreter = ( + code_interpreter + if code_interpreter is not None + else ( + self.code_interpreter + if not isinstance(self.code_interpreter, str) + else CodeInterpreterFactory.new_instance(self.code_interpreter) + ) + ) + code_interpreter = cast(CodeInterpreter, code_interpreter) + with code_interpreter: + chat = copy.deepcopy(chat) + media_list = [] + for chat_i in chat: + if "media" in chat_i: + for media in chat_i["media"]: + media = ( + media + if type(media) is str + and media.startswith(("http", "https")) + else code_interpreter.upload_file(cast(str, media)) + ) + chat_i["content"] += f" Media name {media}" # type: ignore + media_list.append(str(media)) + + int_chat = cast( + List[Message], + [ + ( + { + "role": c["role"], + "content": c["content"], + "media": c["media"], + } + if "media" in c + else {"role": c["role"], "content": c["content"]} + ) + for c in chat + ], + ) + + working_memory: List[Dict[str, str]] = [] + + plans = write_plans( + chat, + T.get_tool_descriptions_by_names( + custom_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore + ), + format_memory(working_memory), + self.planner, + ) + + tool_docs = retrieve_tools( + plans, + self.tool_recommender, + self.log_progress, + self.verbosity, + ) + if test_multi_plan: + plan_thoughts, code, tool_output = pick_plan( + int_chat, + plans, + tool_docs["all"], + self.planner, + code_interpreter, + media_list, + self.log_progress, + self.verbosity, + ) + best_plan = plan_thoughts["best_plan"] + plan_thoughts_str = plan_thoughts["thoughts"] + tool_output_str = ( + "```python\n" + + code + + "\n```\n" + + tool_output.text(include_results=False).strip() + ) + else: + best_plan = list(plans.keys())[0] + tool_output_str = "" + plan_thoughts_str = "" + tool_output = None + + if best_plan in plans and best_plan in tool_docs: + tool_doc = tool_docs[best_plan] + else: + if self.verbosity >= 1: + _LOGGER.warning( + f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info." 
+ ) + k = list(plans.keys())[0] + best_plan = k + tool_doc = tool_docs[k] + + return PlanContext( + plans=plans, + best_plan=best_plan, + plan_thoughts=plan_thoughts_str, + tool_output=tool_output_str, + test_results=tool_output, + tool_doc=tool_doc, + ) + + def log_progress(self, log: Dict[str, Any]) -> None: + if self.report_progress_callback is not None: + self.report_progress_callback(log) + + +class AnthropicVisionAgentPlanner(VisionAgentPlanner): + def __init__( + self, + planner: Optional[LMM] = None, + tool_recommender: Optional[Sim] = None, + verbosity: int = 0, + report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, + ) -> None: + super().__init__( + planner=AnthropicLMM(temperature=0.0) if planner is None else planner, + tool_recommender=tool_recommender, + verbosity=verbosity, + report_progress_callback=report_progress_callback, + code_interpreter=code_interpreter, + ) + + +class OpenAIVisionAgentPlanner(VisionAgentPlanner): + def __init__( + self, + planner: Optional[LMM] = None, + tool_recommender: Optional[Sim] = None, + verbosity: int = 0, + report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, + ) -> None: + super().__init__( + planner=( + OpenAILMM(temperature=0.0, json_mode=True) + if planner is None + else planner + ), + tool_recommender=tool_recommender, + verbosity=verbosity, + report_progress_callback=report_progress_callback, + code_interpreter=code_interpreter, + ) + + +class OllamaVisionAgentPlanner(VisionAgentPlanner): + def __init__( + self, + planner: Optional[LMM] = None, + tool_recommender: Optional[Sim] = None, + verbosity: int = 0, + report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, + ) -> None: + super().__init__( + planner=( + OllamaLMM(model_name="llama3.1", temperature=0.0) + if planner is None + else planner + ), + tool_recommender=( + OllamaSim(T.TOOLS_DF, sim_key="desc") + if tool_recommender is None + else tool_recommender + ), + verbosity=verbosity, + report_progress_callback=report_progress_callback, + code_interpreter=code_interpreter, + ) + + +class AzureVisionAgentPlanner(VisionAgentPlanner): + def __init__( + self, + planner: Optional[LMM] = None, + tool_recommender: Optional[Sim] = None, + verbosity: int = 0, + report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[Union[str, CodeInterpreter]] = None, + ) -> None: + super().__init__( + planner=( + AzureOpenAILMM(temperature=0.0, json_mode=True) + if planner is None + else planner + ), + tool_recommender=( + AzureSim(T.TOOLS_DF, sim_key="desc") + if tool_recommender is None + else tool_recommender + ), + verbosity=verbosity, + report_progress_callback=report_progress_callback, + code_interpreter=code_interpreter, + ) diff --git a/vision_agent/agent/vision_agent_planner_prompts.py b/vision_agent/agent/vision_agent_planner_prompts.py new file mode 100644 index 00000000..833e2c9b --- /dev/null +++ b/vision_agent/agent/vision_agent_planner_prompts.py @@ -0,0 +1,199 @@ +USER_REQ = """ +## User Request +{user_request} +""" + +PLAN = """ +**Context**: +{context} + +**Tools Available**: +{tool_desc} + +**Previous Feedback**: +{feedback} + +**Instructions**: +1. Based on the context and tools you have available, create a plan of subtasks to achieve the user request. +2. 
For each subtask, be sure to include the tool(s) you want to use to accomplish that subtask.
+3. Output three different plans, each utilizing a different strategy or set of tools, ordered from most likely to least likely to succeed.
+
+Output a JSON object in the following format:
+
+```json
+{{
+    "plan1":
+        {{
+            "thoughts": str # your thought process for choosing this plan
+            "instructions": [
+                str # what you should do in this task associated with a tool
+            ]
+        }},
+    "plan2": ...,
+    "plan3": ...
+}}
+```
+"""
+
+TEST_PLANS = """
+**Role**: You are a software programmer responsible for testing different tools.
+
+**Task**: Your responsibility is to take a set of several plans and test the different tools for each plan.
+
+**Documentation**:
+This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task. They are available through importing `from vision_agent.tools import *`.
+
+{docstring}
+
+**Plans**:
+{plans}
+
+**Previous Attempts**:
+{previous_attempts}
+
+**Examples**:
+--- EXAMPLE1 ---
+plan1:
+- Load the image from the provided file path 'image.jpg'.
+- Use the 'owl_v2_image' tool with the prompt 'person' to detect and count the number of people in the image.
+plan2:
+- Load the image from the provided file path 'image.jpg'.
+- Use the 'florence2_sam2_image' tool with the prompt 'person' to detect and count the number of people in the image.
+- Count the number of detected objects labeled as 'person'.
+plan3:
+- Load the image from the provided file path 'image.jpg'.
+- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
+
+```python
+from vision_agent.tools import load_image, owl_v2_image, florence2_sam2_image, countgd_counting
+image = load_image("image.jpg")
+owl_v2_out = owl_v2_image("person", image)
+
+f2s2_out = florence2_sam2_image("person", image)
+# strip out the masks from the output because they don't provide useful information when printed
+f2s2_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in f2s2_out]
+
+cgd_out = countgd_counting(image)
+
+final_out = {{"owl_v2_image": owl_v2_out, "florence2_sam2_image": f2s2_out, "countgd_counting": cgd_out}}
+print(final_out)
+```
+--- END EXAMPLE1 ---
+
+--- EXAMPLE2 ---
+plan1:
+- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames_and_timestamps' tool.
+- Use the 'owl_v2_video' tool with the prompt 'person' to detect where the people are in the video.
+plan2:
+- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames_and_timestamps' tool.
+- Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video.
+plan3:
+- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames_and_timestamps' tool.
+- Use the 'florence2_sam2_video_tracking' tool with the prompt 'person' to detect where the people are in the video.
+
+
+```python
+import numpy as np
+from vision_agent.tools import extract_frames_and_timestamps, owl_v2_video, florence2_phrase_grounding, florence2_sam2_video_tracking

+# sample at 1 FPS and use the first 10 frames to reduce processing time
+frames = extract_frames_and_timestamps("video.mp4", 1)
+frames = [f["frame"] for f in frames][:10]
+
+# strip arrays from the output to make it easier to read
+def remove_arrays(o):
+    if isinstance(o, list):
+        return [remove_arrays(e) for e in o]
+    elif isinstance(o, dict):
+        return {{k: remove_arrays(v) for k, v in o.items()}}
+    elif isinstance(o, np.ndarray):
+        return "array: " + str(o.shape)
+    else:
+        return o
+
+# return the counts of each label per frame to help determine the stability of the model results
+def get_counts(preds):
+    counts = {{}}
+    for i, pred_frame in enumerate(preds):
+        counts_i = {{}}
+        for pred in pred_frame:
+            label = pred["label"].split(":")[1] if ":" in pred["label"] else pred["label"]
+            counts_i[label] = counts_i.get(label, 0) + 1
+        counts[f"frame_{{i}}"] = counts_i
+    return counts
+
+
+# plan1
+owl_v2_out = owl_v2_video("person", frames)
+owl_v2_counts = get_counts(owl_v2_out)
+
+# plan2
+florence2_out = [florence2_phrase_grounding("person", f) for f in frames]
+florence2_counts = get_counts(florence2_out)
+
+# plan3
+f2s2_tracking_out = florence2_sam2_video_tracking("person", frames)
+f2s2_tracking_out = remove_arrays(f2s2_tracking_out)
+f2s2_counts = get_counts(f2s2_tracking_out)
+
+final_out = {{
+    "owl_v2_video": owl_v2_out,
+    "florence2_phrase_grounding": florence2_out,
+    "florence2_sam2_video_tracking": f2s2_tracking_out,
+}}
+
+counts = {{
+    "owl_v2_video": owl_v2_counts,
+    "florence2_phrase_grounding": florence2_counts,
+    "florence2_sam2_video_tracking": f2s2_counts,
+}}
+
+print(final_out)
+print(counts)
+```
+--- END EXAMPLE2 ---
+
+**Instructions**:
+1. Write a program to load the media, call each tool, and print its output along with other relevant information.
+2. Create a dictionary where the keys are the tool name and the values are the tool outputs. Remove numpy arrays from the printed dictionary.
+3. Your test case MUST run only on the given images, which are {media}.
+4. Print this final dictionary.
+5. For video input, sample at 1 FPS and use the first 10 frames only to reduce processing time.
+"""
+
+PREVIOUS_FAILED = """
+**Previous Failed Attempts**:
+You previously ran this code:
+```python
+{code}
+```
+
+But got the following error or no stdout:
+{error}
+"""
+
+PICK_PLAN = """
+**Role**: You are an advanced AI model that can understand the user request and construct plans to accomplish it.
+
+**Task**: Your responsibility is to pick the best plan from the three plans provided.
+
+**Context**:
+{context}
+
+**Plans**:
+{plans}
+
+**Tool Output**:
+{tool_output}
+
+**Instructions**:
+1. Re-read the user request, plans, tool outputs and examine the image.
+2. Solve the problem yourself given the image and pick the most accurate plan that matches your solution the best.
+3. Add modifications to improve the plan, including: changing a tool, adding thresholds, string matching.
+4. 
Output a JSON object with the following format: +{{ + "predicted_answer": str # the answer you would expect from the best plan + "thoughts": str # your thought process for choosing the best plan over other plans and any modifications you made + "best_plan": str # the best plan you have chosen, must be `plan1`, `plan2`, or `plan3` +}} +""" diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index da74f677..2a75aa2b 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,6 +1,5 @@ from typing import Callable, List, Optional -from .meta_tools import META_TOOL_DOCSTRING, Artifacts from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT from .tool_utils import get_tool_descriptions_by_names from .tools import ( diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index c9fc7be0..0fb46cee 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -13,7 +13,9 @@ from IPython.display import display import vision_agent as va +from vision_agent.agent.agent_utils import extract_json from vision_agent.clients.landing_public_api import LandingPublicAPI +from vision_agent.lmm import AnthropicLMM from vision_agent.lmm.types import Message from vision_agent.tools.tool_utils import get_tool_documentation from vision_agent.tools.tools import TOOL_DESCRIPTIONS @@ -338,6 +340,85 @@ def edit_code_artifact( return open_code_artifact(artifacts, name, cur_line) +def generate_vision_plan( + artifacts: Artifacts, + name: str, + chat: str, + media: List[str], + test_multi_plan: bool = True, + custom_tool_names: Optional[List[str]] = None, +) -> str: + """Generates a plan to solve vision based tasks. + + Parameters: + artifacts (Artifacts): The artifacts object to save the plan to. + name (str): The name of the artifact to save the plan context to. + chat (str): The chat message from the user. + media (List[str]): The media files to use. + test_multi_plan (bool): Do not change this parameter. + custom_tool_names (Optional[List[str]]): Do not change this parameter. + + Returns: + str: The generated plan. 
+ + Examples + -------- + >>> generate_vision_plan(artifacts, "plan.json", "Can you detect the dogs in this image?", ["image.jpg"]) + [Start Plan Context] + plan1: This is a plan to detect dogs in an image + -load image + -detect dogs + -return detections + [End Plan Context] + """ + + if ZMQ_PORT is not None: + agent = va.agent.VisionAgentPlanner( + report_progress_callback=lambda inp: report_progress_callback( + int(ZMQ_PORT), inp + ) + ) + else: + agent = va.agent.VisionAgentPlanner() + + fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}] + response = agent.generate_plan( + fixed_chat, + test_multi_plan=test_multi_plan, + custom_tool_names=custom_tool_names, + ) + if response.test_results is not None: + redisplay_results(response.test_results) + response.test_results = None + artifacts[name] = response.model_dump_json() + media_names = extract_json( + AnthropicLMM()( # type: ignore + f"""Extract any media file names from this output in the following JSON format: +{{"media": ["image1.jpg", "image2.jpg"]}} + +{artifacts[name]}""" + ) + ) + if "media" in media_names and isinstance(media_names, dict): + for media in media_names["media"]: + if isinstance(media, str): + with open(media, "rb") as f: + artifacts[media] = f.read() + + output_str = f"[Start Plan Context, saved at {name}]" + for plan in response.plans.keys(): + output_str += f"\n{plan}: {response.plans[plan]['thoughts'].strip()}\n" # type: ignore + output_str += " -" + "\n -".join( + e.strip() for e in response.plans[plan]["instructions"] + ) + + output_str += f"\nbest plan: {response.best_plan}\n" + output_str += "thoughts: " + response.plan_thoughts.strip() + "\n" + output_str += "[End Plan Context]" + print(output_str) + return output_str + + def generate_vision_code( artifacts: Artifacts, name: str, @@ -368,7 +449,6 @@ def detect_dogs(image_path: str): dogs = owl_v2("dog", image) return dogs """ - if ZMQ_PORT is not None: agent = va.agent.VisionAgentCoder( report_progress_callback=lambda inp: report_progress_callback( @@ -379,7 +459,7 @@ def detect_dogs(image_path: str): agent = va.agent.VisionAgentCoder(verbosity=int(VERBOSITY)) fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}] - response = agent.chat_with_workflow( + response = agent.generate_code( fixed_chat, test_multi_plan=test_multi_plan, custom_tool_names=custom_tool_names, @@ -459,7 +539,7 @@ def detect_dogs(image_path: str): fixed_chat_history.append({"role": "assistant", "content": code}) fixed_chat_history.append({"role": "user", "content": chat}) - response = agent.chat_with_workflow( + response = agent.generate_code( fixed_chat_history, test_multi_plan=False, custom_tool_names=customized_tool_names, @@ -748,6 +828,7 @@ def use_object_detection_fine_tuning( open_code_artifact, create_code_artifact, edit_code_artifact, + generate_vision_plan, generate_vision_code, edit_vision_code, write_media_artifact, From 5775fddba219ebd3740b6454ccfdfebb1dfb7f03 Mon Sep 17 00:00:00 2001 From: GitHub Actions Bot Date: Fri, 11 Oct 2024 03:53:37 +0000 Subject: [PATCH 3/3] [skip ci] chore(release): vision-agent 0.2.162 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e541b5c2..05222562 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "vision-agent" -version = "0.2.161" +version = "0.2.162" description = "Toolset for Vision Agent" authors = ["Landing AI "] readme = "README.md"