diff --git a/poetry.lock b/poetry.lock index 7f1126b1..b6356fcf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -557,13 +557,13 @@ files = [ [[package]] name = "e2b" -version = "0.17.2a56" +version = "0.17.2a57" description = "E2B SDK that give agents cloud environments" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "e2b-0.17.2a56-py3-none-any.whl", hash = "sha256:19db2c8fce72f4fd08f7d5538184a0b237551817b9e80879b65503090a7b59b9"}, - {file = "e2b-0.17.2a56.tar.gz", hash = "sha256:7932ec1b7ab4e588d8769280698725085103b34eaca33e678d4b1a42bc2ff8fd"}, + {file = "e2b-0.17.2a57-py3-none-any.whl", hash = "sha256:db1bfd4cb65d10833faab2df386db35ed3fd7ab1ebee452414d8d006da848119"}, + {file = "e2b-0.17.2a57.tar.gz", hash = "sha256:92f77fdfa646ad83a40ed1e7bdc3c25fd76238eea016a5c96668f5f6d9807548"}, ] [package.dependencies] @@ -591,28 +591,6 @@ attrs = ">=21.3.0" e2b = ">=0.17.2a50,<0.18.0" httpx = ">=0.20.0,<0.28.0" -[[package]] -name = "eva-decord" -version = "0.6.1" -description = "EVA's Decord Video Loader" -optional = false -python-versions = ">=3.7.0" -files = [ - {file = "eva_decord-0.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a81c49d11c3f93c23b40fb106854d6c0b5548508e4b7971ade50c4d1ae4ad68f"}, - {file = "eva_decord-0.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0d7d4c6a698ac4ad3b14c3c85773bba8570d8a1431204a237365e17a940f48c7"}, - {file = "eva_decord-0.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f1e756887aa1833dadd0aee0f4e3b3dc10a9080b53a73001501c22eec311f78b"}, - {file = "eva_decord-0.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ae41d7958b7d6fc3af66ae1b4072d6f938abe04f2016b56891688ac8a78ee158"}, - {file = "eva_decord-0.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b2aae6fa0968ef5816fe09109aa87227cc5dbc5e3b0ae3a24c1de8d948776799"}, - {file = "eva_decord-0.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b44e20f401f4e7a52e6b1a6cb95fe06a40de4f02be5386da07c6d8f4851ab4ed"}, - {file = "eva_decord-0.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:af1a74f414fc84c35b45478aed7868b5afd323fb2b5c50e916ef7efa17524fb1"}, - {file = "eva_decord-0.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c64446dab22acb0ae44f3ee3190cb923fb538c74a4aa22a7fd8340ce3642c5cb"}, - {file = "eva_decord-0.6.1-py3-none-manylinux2010_x86_64.whl", hash = "sha256:75dabf364f2df5dc4c78d685cdeca29733ac422f53508a3c117f1387f1d0ef81"}, - {file = "eva_decord-0.6.1-py3-none-win_amd64.whl", hash = "sha256:f9f09369bef73075d945383bfaf1e41c3db118e7148719369ea134506e4bb525"}, -] - -[package.dependencies] -numpy = ">=1.14.0" - [[package]] name = "exceptiongroup" version = "1.2.2" @@ -657,19 +635,19 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.15.4" +version = "3.16.0" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, - {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, + {file = "filelock-3.16.0-py3-none-any.whl", hash = "sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609"}, + {file = "filelock-3.16.0.tar.gz", hash = "sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] -typing = ["typing-extensions (>=4.8)"] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.1.1)", "pytest (>=8.3.2)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.3)"] +typing = ["typing-extensions (>=4.12.2)"] [[package]] name = "flake8" @@ -1206,13 +1184,13 @@ test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout" [[package]] name = "langsmith" -version = "0.1.115" +version = "0.1.117" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.115-py3-none-any.whl", hash = "sha256:04e35cfd4c2d4ff1ea10bb577ff43957b05ebb3d9eb4e06e200701f4a2b4ac9f"}, - {file = "langsmith-0.1.115.tar.gz", hash = "sha256:3b775377d858d32354f3ee0dd1ed637068cfe9a1f13e7b3bfa82db1615cdffc9"}, + {file = "langsmith-0.1.117-py3-none-any.whl", hash = "sha256:e936ee9bcf8293b0496df7ba462a3702179fbe51f9dc28744b0fbec0dbf206ae"}, + {file = "langsmith-0.1.117.tar.gz", hash = "sha256:a1b532f49968b9339bcaff9118d141846d52ed3d803f342902e7448edf1d662b"}, ] [package.dependencies] @@ -1735,13 +1713,13 @@ files = [ [[package]] name = "openai" -version = "1.43.1" +version = "1.44.1" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.43.1-py3-none-any.whl", hash = "sha256:23ed3aa71e89cf644c911f7ab80087d08c0bf46ce6b75d9a811fc7942cff85c2"}, - {file = "openai-1.43.1.tar.gz", hash = "sha256:b64843711b7c92ded36795062ea1f8cad84ec6c2848646f2a786ac4617a6b9f5"}, + {file = "openai-1.44.1-py3-none-any.whl", hash = "sha256:07e2c2758d1c94151c740b14dab638ba0d04bcb41a2e397045c90e7661cdf741"}, + {file = "openai-1.44.1.tar.gz", hash = "sha256:e0ffdab601118329ea7529e684b606a72c6c9d4f05be9ee1116255fcf5593874"}, ] [package.dependencies] @@ -2156,19 +2134,19 @@ tests-min = ["defusedxml", "packaging", "pytest"] [[package]] name = "platformdirs" -version = "4.2.2" +version = "4.3.2" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, - {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, + {file = "platformdirs-4.3.2-py3-none-any.whl", hash = "sha256:eb1c8582560b34ed4ba105009a4badf7f6f85768b30126f351328507b2beb617"}, + {file = "platformdirs-4.3.2.tar.gz", hash = "sha256:9e5e27a08aa095dd127b9f2e764d74254f482fef22b0970773bfba79d091ab8c"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] -type = ["mypy (>=1.8)"] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] [[package]] name = "pluggy" @@ -2425,13 +2403,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pydantic-settings" -version = "2.4.0" +version = "2.5.0" description = "Settings management using Pydantic" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_settings-2.4.0-py3-none-any.whl", hash = "sha256:bb6849dc067f1687574c12a639e231f3a6feeed0a12d710c1382045c5db1c315"}, - {file = "pydantic_settings-2.4.0.tar.gz", hash = "sha256:ed81c3a0f46392b4d7c0a565c05884e6e54b3456e6f0fe4d8814981172dc9a88"}, + {file = "pydantic_settings-2.5.0-py3-none-any.whl", hash = "sha256:eae04a3dd9adf521a4c959dcefb984e0f3b1d841999daf02f961dcc4d31d2f7f"}, + {file = "pydantic_settings-2.5.0.tar.gz", hash = "sha256:204828c02481a2e7135466b26a7d65d9e15a17d52d1d8f59cacdf9ad625e1140"}, ] [package.dependencies] @@ -2924,13 +2902,13 @@ tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asy [[package]] name = "rich" -version = "13.8.0" +version = "13.8.1" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.7.0" files = [ - {file = "rich-13.8.0-py3-none-any.whl", hash = "sha256:2e85306a063b9492dffc86278197a60cbece75bcb766022f3436f567cae11bdc"}, - {file = "rich-13.8.0.tar.gz", hash = "sha256:a5ac1f1cd448ade0d59cc3356f7db7a7ccda2c8cbae9c7a90c28ff463d3e91f4"}, + {file = "rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06"}, + {file = "rich-13.8.1.tar.gz", hash = "sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a"}, ] [package.dependencies] @@ -3457,13 +3435,13 @@ files = [ [[package]] name = "types-requests" -version = "2.32.0.20240905" +version = "2.32.0.20240907" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" files = [ - {file = "types-requests-2.32.0.20240905.tar.gz", hash = "sha256:e97fd015a5ed982c9ddcd14cc4afba9d111e0e06b797c8f776d14602735e9bd6"}, - {file = "types_requests-2.32.0.20240905-py3-none-any.whl", hash = "sha256:f46ecb55f5e1a37a58be684cf3f013f166da27552732ef2469a0cc8e62a72881"}, + {file = "types-requests-2.32.0.20240907.tar.gz", hash = "sha256:ff33935f061b5e81ec87997e91050f7b4af4f82027a7a7a9d9aaea04a963fdf8"}, + {file = "types_requests-2.32.0.20240907-py3-none-any.whl", hash = "sha256:1d1e79faeaf9d42def77f3c304893dea17a97cae98168ac69f3cb465516ee8da"}, ] [package.dependencies] @@ -3532,13 +3510,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.26.3" +version = "20.26.4" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.3-py3-none-any.whl", hash = "sha256:8cc4a31139e796e9a7de2cd5cf2489de1217193116a8fd42328f1bd65f434589"}, - {file = "virtualenv-20.26.3.tar.gz", hash = "sha256:4c43a2a236279d9ea36a0d76f98d84bd6ca94ac4e0f4a3b9d46d05e10fea542a"}, + {file = "virtualenv-20.26.4-py3-none-any.whl", hash = "sha256:48f2695d9809277003f30776d155615ffc11328e6a0a8c1f0ec80188d7874a55"}, + {file = "virtualenv-20.26.4.tar.gz", hash = "sha256:c17f4e0f3e6036e9f26700446f85c76ab11df65ff6d8a9cbfad9f71aabfcf23c"}, ] [package.dependencies] @@ -3625,4 +3603,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "f9ebed539e44012292a6637d32a8a649dd44ad37f1eab9fb41b12493c700cdc0" +content-hash = "bead91bd0ca1f1b9ecca03980370fbf63bcd345599e89bbd4b5b412c53de3b9f" diff --git a/pyproject.toml b/pyproject.toml index db53410a..d3fe0be1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ pillow-heif = "^0.16.0" pytube = "15.0.0" anthropic = "^0.31.0" pydantic = "2.7.4" -eva-decord = "^0.6.1" av = "^11.0.0" [tool.poetry.group.dev.dependencies] diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 24bd259f..ba5b989e 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -22,6 +22,7 @@ grounding_sam, ixc25_image_vqa, ixc25_video_vqa, + ixc25_temporal_localization, loca_visual_prompt_counting, loca_zero_shot_counting, ocr, @@ -238,6 +239,17 @@ def test_ixc25_video_vqa() -> None: assert "cat" in result.strip() +def test_ixc25_temporal_localization() -> None: + frames = [ + np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10) + ] + result = ixc25_temporal_localization( + prompt="What animal is in this video?", + frames=frames, + ) + assert result == [True] * 10 + + def test_ocr() -> None: img = ski.data.page() result = ocr( diff --git a/tests/unit/tools/test_video.py b/tests/unit/tools/test_video.py index 4dfcf54b..2ef1fe21 100644 --- a/tests/unit/tools/test_video.py +++ b/tests/unit/tools/test_video.py @@ -6,5 +6,5 @@ def test_extract_frames_from_video(): video_path = "tests/data/video/test.mp4" # there are 48 frames at 24 fps in this video file - res = extract_frames_from_video(video_path) - assert len(res) == 2 + res = extract_frames_from_video(video_path, fps=24) + assert len(res) == 48 diff --git a/vision_agent/agent/__init__.py b/vision_agent/agent/__init__.py index 2164d688..e1478a38 100644 --- a/vision_agent/agent/__init__.py +++ b/vision_agent/agent/__init__.py @@ -2,6 +2,7 @@ from .vision_agent import VisionAgent from .vision_agent_coder import ( AzureVisionAgentCoder, + ClaudeVisionAgentCoder, OllamaVisionAgentCoder, VisionAgentCoder, ) diff --git a/vision_agent/agent/agent_utils.py b/vision_agent/agent/agent_utils.py index eb951ccc..2a193a4a 100644 --- a/vision_agent/agent/agent_utils.py +++ b/vision_agent/agent/agent_utils.py @@ -14,6 +14,10 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]: if match: json_str = match.group() try: + # remove trailing comma + trailing_bracket_pattern = r",\s+\}" + json_str = re.sub(trailing_bracket_pattern, "}", json_str, flags=re.DOTALL) + json_dict = json.loads(json_str) return json_dict # type: ignore except json.JSONDecodeError: @@ -21,29 +25,37 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]: return None +def _find_markdown_json(json_str: str) -> str: + pattern = r"```json(.*?)```" + match = re.search(pattern, json_str, re.DOTALL) + if match: + return match.group(1).strip() + return json_str + + +def _strip_markdown_code(inp_str: str) -> str: + pattern = r"```python.*?```" + cleaned_str = re.sub(pattern, "", inp_str, flags=re.DOTALL) + return cleaned_str + + def extract_json(json_str: str) -> Dict[str, Any]: + json_str = json_str.replace("\n", " ").strip() + try: - json_str = json_str.replace("\n", " ") - json_dict = json.loads(json_str) + return json.loads(json_str) # type: ignore except json.JSONDecodeError: - if "```json" in json_str: - json_str = json_str[json_str.find("```json") + len("```json") :] - json_str = json_str[: json_str.find("```")] - elif "```" in json_str: - json_str = json_str[json_str.find("```") + len("```") :] - # get the last ``` not one from an intermediate string - json_str = json_str[: json_str.find("}```")] - try: - json_dict = json.loads(json_str) - except json.JSONDecodeError as e: - json_dict = _extract_sub_json(json_str) - if json_dict is not None: - return json_dict # type: ignore - error_msg = f"Could not extract JSON from the given str: {json_str}" + json_orig = json_str + json_str = _strip_markdown_code(json_str) + json_str = _find_markdown_json(json_str) + json_dict = _extract_sub_json(json_str) + + if json_dict is None: + error_msg = f"Could not extract JSON from the given str: {json_orig}" _LOGGER.exception(error_msg) - raise ValueError(error_msg) from e + raise ValueError(error_msg) - return json_dict # type: ignore + return json_dict def extract_code(code: str) -> str: diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 8b6f9032..edf249ac 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -27,7 +27,14 @@ TEST_PLANS, USER_REQ, ) -from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM +from vision_agent.lmm import ( + LMM, + AzureOpenAILMM, + ClaudeSonnetLMM, + Message, + OllamaLMM, + OpenAILMM, +) from vision_agent.tools.meta_tools import get_diff from vision_agent.utils import CodeInterpreterFactory, Execution from vision_agent.utils.execute import CodeInterpreter @@ -167,9 +174,10 @@ def pick_plan( } ) tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code)) - tool_output_str = "" - if len(tool_output.logs.stdout) > 0: - tool_output_str = tool_output.logs.stdout[0] + # Because of the way we trace function calls the trace information ends up in the + # results. We don't want to show this info to the LLM so we don't include it in the + # tool_output_str. + tool_output_str = tool_output.text(include_results=False).strip() if verbosity == 2: _print_code("Initial code and tests:", code) @@ -196,7 +204,7 @@ def pick_plan( docstring=tool_info, plans=plan_str, previous_attempts=PREVIOUS_FAILED.format( - code=code, error=tool_output.text() + code=code, error="\n".join(tool_output_str.splitlines()[-50:]) ), media=media, ) @@ -225,11 +233,11 @@ def pick_plan( "status": "completed" if tool_output.success else "failed", } ) - tool_output_str = tool_output.text().strip() + tool_output_str = tool_output.text(include_results=False).strip() if verbosity == 2: _print_code("Code and test after attempted fix:", code) - _LOGGER.info(f"Code execution result after attempt {count}") + _LOGGER.info(f"Code execution result after attempt {count + 1}") count += 1 @@ -387,7 +395,6 @@ def write_and_test_code( "code": DefaultImports.prepend_imports(code), "payload": { "test": test, - # "result": result.to_json(), }, } ) @@ -406,6 +413,7 @@ def write_and_test_code( working_memory, debugger, code_interpreter, + tool_info, code, test, result, @@ -431,6 +439,7 @@ def debug_code( working_memory: List[Dict[str, str]], debugger: LMM, code_interpreter: CodeInterpreter, + tool_info: str, code: str, test: str, result: Execution, @@ -451,17 +460,38 @@ def debug_code( count = 0 while not success and count < 3: try: - fixed_code_and_test = extract_json( - debugger( # type: ignore - FIX_BUG.format( - code=code, - tests=test, - result="\n".join(result.text().splitlines()[-50:]), - feedback=format_memory(working_memory + new_working_memory), + # LLMs write worse code when it's in JSON, so we have it write JSON + # followed by code each wrapped in markdown blocks. + fixed_code_and_test_str = debugger( + FIX_BUG.format( + docstring=tool_info, + code=code, + tests=test, + # Because of the way we trace function calls the trace information + # ends up in the results. We don't want to show this info to the + # LLM so we don't include it in the tool_output_str. + result="\n".join( + result.text(include_results=False).splitlines()[-50:] ), - stream=False, - ) + feedback=format_memory(working_memory + new_working_memory), + ), + stream=False, ) + fixed_code_and_test_str = cast(str, fixed_code_and_test_str) + fixed_code_and_test = extract_json(fixed_code_and_test_str) + code = extract_code(fixed_code_and_test_str) + if ( + "which_code" in fixed_code_and_test + and fixed_code_and_test["which_code"] == "test" + ): + fixed_code_and_test["code"] = "" + fixed_code_and_test["test"] = code + else: # for everything else always assume it's updating code + fixed_code_and_test["code"] = code + fixed_code_and_test["test"] = "" + if "which_code" in fixed_code_and_test: + del fixed_code_and_test["which_code"] + success = True except Exception as e: _LOGGER.exception(f"Error while extracting JSON: {e}") @@ -472,9 +502,9 @@ def debug_code( old_test = test if fixed_code_and_test["code"].strip() != "": - code = extract_code(fixed_code_and_test["code"]) + code = fixed_code_and_test["code"] if fixed_code_and_test["test"].strip() != "": - test = extract_code(fixed_code_and_test["test"]) + test = fixed_code_and_test["test"] new_working_memory.append( { @@ -628,9 +658,7 @@ def __init__( ) self.coder = OpenAILMM(temperature=0.0) if coder is None else coder self.tester = OpenAILMM(temperature=0.0) if tester is None else tester - self.debugger = ( - OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger - ) + self.debugger = OpenAILMM(temperature=0.0) if debugger is None else debugger self.verbosity = verbosity if self.verbosity > 0: _LOGGER.setLevel(logging.INFO) @@ -876,6 +904,40 @@ def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None: ) +class ClaudeVisionAgentCoder(VisionAgentCoder): + def __init__( + self, + planner: Optional[LMM] = None, + coder: Optional[LMM] = None, + tester: Optional[LMM] = None, + debugger: Optional[LMM] = None, + tool_recommender: Optional[Sim] = None, + verbosity: int = 0, + report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_sandbox_runtime: Optional[str] = None, + ) -> None: + # NOTE: Claude doesn't have an official JSON mode + self.planner = ClaudeSonnetLMM(temperature=0.0) if planner is None else planner + self.coder = ClaudeSonnetLMM(temperature=0.0) if coder is None else coder + self.tester = ClaudeSonnetLMM(temperature=0.0) if tester is None else tester + self.debugger = ( + ClaudeSonnetLMM(temperature=0.0) if debugger is None else debugger + ) + self.verbosity = verbosity + if self.verbosity > 0: + _LOGGER.setLevel(logging.INFO) + + # Anthropic does not offer any embedding models and instead recomends Voyage, + # we're using OpenAI's embedder for now. + self.tool_recommender = ( + Sim(T.TOOLS_DF, sim_key="desc") + if tool_recommender is None + else tool_recommender + ) + self.report_progress_callback = report_progress_callback + self.code_sandbox_runtime = code_sandbox_runtime + + class OllamaVisionAgentCoder(VisionAgentCoder): """VisionAgentCoder that uses Ollama models for planning, coding, testing. @@ -920,7 +982,7 @@ def __init__( else tester ), debugger=( - OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True) + OllamaLMM(model_name="llama3.1", temperature=0.0) if debugger is None else debugger ), @@ -983,9 +1045,7 @@ def __init__( coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder, tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester, debugger=( - AzureOpenAILMM(temperature=0.0, json_mode=True) - if debugger is None - else debugger + AzureOpenAILMM(temperature=0.0) if debugger is None else debugger ), tool_recommender=( AzureSim(T.TOOLS_DF, sim_key="desc") diff --git a/vision_agent/agent/vision_agent_coder_prompts.py b/vision_agent/agent/vision_agent_coder_prompts.py index df68372c..6d7b18d6 100644 --- a/vision_agent/agent/vision_agent_coder_prompts.py +++ b/vision_agent/agent/vision_agent_coder_prompts.py @@ -63,6 +63,7 @@ **Plans**: {plans} +**Previous Attempts**: {previous_attempts} **Instructions**: @@ -108,16 +109,27 @@ - Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video. plan3: - Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool. -- Use the 'countgd_counting' tool with the prompt 'person' to detect where the people are in the video. +- Use the 'florence2_sam2_video_tracking' tool with the prompt 'person' to detect where the people are in the video. ```python -from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, countgd_counting +import numpy as np +from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, florence2_sam2_video_tracking # sample at 1 FPS and use the first 10 frames to reduce processing time frames = extract_frames("video.mp4", 1) frames = [f[0] for f in frames][:10] +def remove_arrays(o): + if isinstance(o, list): + return [remove_arrays(e) for e in o] + elif isinstance(o, dict): + return {{k: remove_arrays(v) for k, v in o.items()}} + elif isinstance(o, np.ndarray): + return "array: " + str(o.shape) + else: + return o + # plan1 owl_v2_out = [owl_v2_image("person", f) for f in frames] @@ -125,9 +137,10 @@ florence2_out = [florence2_phrase_grounding("person", f) for f in frames] # plan3 -countgd_out = [countgd_counting(f) for f in frames] +f2s2_tracking_out = florence2_sam2_video_tracking("person", frames) +remove_arrays(f2s2_tracking_out) -final_out = {{"owl_v2_image": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}} +final_out = {{"owl_v2_image": owl_v2_out, "florence2_phrase_grounding": florence2_out, "florence2_sam2_video_tracking": f2s2_tracking_out}} print(final_out) ``` """ @@ -161,9 +174,10 @@ **Instructions**: 1. Given the plans, image, and tool outputs, decide which plan is the best to achieve the user request. -2. Try solving the problem yourself given the image and pick the plan that matches your solution the best. +2. Solve the problem yourself given the image and pick the plan that matches your solution the best. 3. Output a JSON object with the following format: {{ + "predicted_answer": str # the answer you would expect from the best plan "thoughts": str # your thought process for choosing the best plan "best_plan": str # the best plan you have chosen }} @@ -311,6 +325,11 @@ def find_text(image_path: str, text: str) -> str: FIX_BUG = """ **Role** As a coder, your job is to find the error in the code and fix it. You are running in a notebook setting so you can run !pip install to install missing packages. +**Documentation**: +This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task. They are available through importing `from vision_agent.tools import *`. + +{docstring} + **Instructions**: Please re-complete the code to fix the error message. Here is the previous version: ```python @@ -323,17 +342,24 @@ def find_text(image_path: str, text: str) -> str: ``` It raises this error: +``` {result} +``` This is previous feedback provided on the code: {feedback} -Please fix the bug by follow the error information and return a JSON object with the following format: +Please fix the bug by correcting the error. Return the following JSON object followed by the fixed code in the below format: +```json {{ "reflections": str # any thoughts you have about the bug and how you fixed it - "code": str # the fixed code if any, else an empty string - "test": str # the fixed test code if any, else an empty string + "which_code": str # the code that was fixed, can only be 'code' or 'test' }} +``` + +```python +# Your fixed code here +``` """ diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 4f42380c..d075dad5 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -371,7 +371,7 @@ class ClaudeSonnetLMM(LMM): def __init__( self, api_key: Optional[str] = None, - model_name: str = "claude-3-sonnet-20240229", + model_name: str = "claude-3-5-sonnet-20240620", max_tokens: int = 4096, **kwargs: Any, ): diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index f7b1e4c0..a401fb46 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -37,6 +37,7 @@ grounding_dino, grounding_sam, ixc25_image_vqa, + ixc25_temporal_localization, ixc25_video_vqa, load_image, loca_visual_prompt_counting, diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 534bc078..1b85446b 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -1,7 +1,7 @@ -from base64 import b64encode import inspect import logging import os +from base64 import b64encode from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple import pandas as pd diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 7092a646..63927f01 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -468,7 +468,7 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]: pil_image = Image.fromarray(image).convert("RGB") image_size = pil_image.size[::-1] - if image_size[0] < 1 and image_size[1] < 1: + if image_size[0] < 1 or image_size[1] < 1: return [] image_buffer = io.BytesIO() pil_image.save(image_buffer, format="PNG") @@ -781,6 +781,44 @@ def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str: return cast(str, data["answer"]) +def ixc25_temporal_localization(prompt: str, frames: List[np.ndarray]) -> List[bool]: + """'ixc25_temporal_localization' uses ixc25_video_vqa to temporally segment a video + given a prompt that can be other an object or a phrase. It returns a list of + boolean values indicating whether the object or phrase is present in the + corresponding frame. + + Parameters: + prompt (str): The question about the video + frames (List[np.ndarray]): The reference frames used for the question + + Returns: + List[bool]: A list of boolean values indicating whether the object or phrase is + present in the corresponding frame. + + Example + ------- + >>> output = ixc25_temporal_localization('soccer goal', frames) + >>> print(output) + [False, False, False, True, True, True, False, False, False, False] + >>> save_video([f for i, f in enumerate(frames) if output[i]], 'output.mp4') + """ + + buffer_bytes = frames_to_bytes(frames) + files = [("video", buffer_bytes)] + payload = { + "prompt": prompt, + "chunk_length": 2, + "function_name": "ixc25_temporal_localization", + } + data: List[int] = send_inference_request( + payload, "video-temporal-localization", files=files, v2=True + ) + chunk_size = round(len(frames) / len(data)) + data_explode = [[elt] * chunk_size for elt in data] + data_bool = [bool(elt) for sublist in data_explode for elt in sublist] + return data_bool[: len(frames)] + + def gpt4o_image_vqa(prompt: str, image: np.ndarray) -> str: """'gpt4o_image_vqa' is a tool that can answer any questions about arbitrary images including regular images or images of documents or presentations. It returns text @@ -1112,6 +1150,8 @@ def florence2_ocr(image: np.ndarray) -> List[Dict[str, Any]]: """ image_size = image.shape[:2] + if image_size[0] < 1 or image_size[1] < 1: + return [] image_b64 = convert_to_b64(image) data = { "image": image_b64, @@ -1467,7 +1507,7 @@ def extract_frames( Parameters: video_uri (Union[str, Path]): The path to the video file, url or youtube link fps (float, optional): The frame rate per second to extract the frames. Defaults - to 10. + to 1. Returns: List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 7967367e..c2e0e652 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -292,7 +292,7 @@ class Config: error: Optional[Error] = None "Error object if an error occurred, None otherwise." - def text(self, include_logs: bool = True) -> str: + def text(self, include_logs: bool = True, include_results: bool = True) -> str: """Returns the text representation of this object, i.e. including the main result or the error traceback, optionally along with the logs (stdout, stderr). """ @@ -300,15 +300,17 @@ def text(self, include_logs: bool = True) -> str: if self.error: return prefix + "\n----- Error -----\n" + self.error.traceback - result_str = [ - ( - f"----- Final output -----\n{res.text}" - if res.is_main_result - else f"----- Intermediate output-----\n{res.text}" - ) - for res in self.results - ] - return prefix + "\n" + "\n".join(result_str) + if include_results: + result_str = [ + ( + f"----- Final output -----\n{res.text}" + if res.is_main_result + else f"----- Intermediate output-----\n{res.text}" + ) + for res in self.results + ] + return prefix + "\n" + "\n".join(result_str) + return prefix @property def success(self) -> bool: diff --git a/vision_agent/utils/video.py b/vision_agent/utils/video.py index d306f295..ba6b0c76 100644 --- a/vision_agent/utils/video.py +++ b/vision_agent/utils/video.py @@ -7,7 +7,6 @@ import av # type: ignore import cv2 import numpy as np -from decord import VideoReader # type: ignore _LOGGER = logging.getLogger(__name__) # The maximum length of the clip to extract frames from, in seconds @@ -103,7 +102,7 @@ def frames_to_bytes( def extract_frames_from_video( video_uri: str, fps: float = 1.0 ) -> List[Tuple[np.ndarray, float]]: - """Extract frames from a video + """Extract frames from a video along with the timestamp in seconds. Parameters: video_uri (str): the path to the video file or a video file url @@ -115,12 +114,24 @@ def extract_frames_from_video( from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order. """ - vr = VideoReader(video_uri) - orig_fps = vr.get_avg_fps() - if fps > orig_fps: - fps = orig_fps - - s = orig_fps / fps - samples = [(int(i * s), int(i * s) / orig_fps) for i in range(int(len(vr) / s))] - frames = vr.get_batch([s[0] for s in samples]).asnumpy() - return [(frames[i, :, :, :], samples[i][1]) for i in range(len(samples))] + + cap = cv2.VideoCapture(video_uri) + orig_fps = cap.get(cv2.CAP_PROP_FPS) + orig_frame_time = 1 / orig_fps + targ_frame_time = 1 / fps + frames: List[Tuple[np.ndarray, float]] = [] + i = 0 + elapsed_time = 0.0 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + elapsed_time += orig_frame_time + if elapsed_time >= targ_frame_time: + frames.append((cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), i / orig_fps)) + elapsed_time -= targ_frame_time + + i += 1 + cap.release() + return frames