Add Claude Sonnet 3.5 VisionAgentCoder #231

Merged
merged 22 commits on Sep 11, 2024
Changes from 13 commits
1 change: 1 addition & 0 deletions vision_agent/agent/__init__.py
@@ -2,6 +2,7 @@
from .vision_agent import VisionAgent
from .vision_agent_coder import (
AzureVisionAgentCoder,
ClaudeVisionAgentCoder,
OllamaVisionAgentCoder,
VisionAgentCoder,
)
48 changes: 30 additions & 18 deletions vision_agent/agent/agent_utils.py
@@ -14,36 +14,48 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
if match:
json_str = match.group()
try:
# remove trailing comma
trailing_bracket_pattern = r",\s+\}"
json_str = re.sub(trailing_bracket_pattern, "}", json_str, flags=re.DOTALL)

json_dict = json.loads(json_str)
return json_dict # type: ignore
except json.JSONDecodeError:
return None
return None


def _find_markdown_json(json_str: str) -> str:
pattern = r"```json(.*?)```"
match = re.search(pattern, json_str, re.DOTALL)
if match:
return match.group(1).strip()
return json_str


def _strip_markdown_code(inp_str: str) -> str:
pattern = r"```python.*?```"
cleaned_str = re.sub(pattern, "", inp_str, flags=re.DOTALL)
return cleaned_str


def extract_json(json_str: str) -> Dict[str, Any]:
json_str = json_str.replace("\n", " ").strip()

try:
json_str = json_str.replace("\n", " ")
json_dict = json.loads(json_str)
return json.loads(json_str) # type: ignore
except json.JSONDecodeError:
if "```json" in json_str:
json_str = json_str[json_str.find("```json") + len("```json") :]
json_str = json_str[: json_str.find("```")]
elif "```" in json_str:
json_str = json_str[json_str.find("```") + len("```") :]
# get the last ``` not one from an intermediate string
json_str = json_str[: json_str.find("}```")]
try:
json_dict = json.loads(json_str)
except json.JSONDecodeError as e:
json_dict = _extract_sub_json(json_str)
if json_dict is not None:
return json_dict # type: ignore
error_msg = f"Could not extract JSON from the given str: {json_str}"
json_orig = json_str
json_str = _strip_markdown_code(json_str)
json_str = _find_markdown_json(json_str)
json_dict = _extract_sub_json(json_str)

if json_dict is None:
error_msg = f"Could not extract JSON from the given str: {json_orig}"
_LOGGER.exception(error_msg)
raise ValueError(error_msg) from e
raise ValueError(error_msg)

return json_dict # type: ignore
return json_dict


def extract_code(code: str) -> str:
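A rough sketch (not part of the diff) of what the reworked extract_json is expected to handle: an LLM reply that mixes prose, a stray ```python block, and a ```json block. The import path is assumed from this file's location (vision_agent/agent/agent_utils.py) and the sample reply is made up.

from vision_agent.agent.agent_utils import extract_json

llm_reply = """Here is my plan.
```python
print("scratch work the parser should ignore")
```
```json
{"thoughts": "plan1 matches the counts best", "best_plan": "plan1"}
```"""

# json.loads fails on the raw reply, so extract_json falls back to
# _strip_markdown_code (drops the ```python block), _find_markdown_json
# (pulls out the ```json body), and finally _extract_sub_json before raising.
plan = extract_json(llm_reply)
print(plan["best_plan"])  # plan1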
112 changes: 86 additions & 26 deletions vision_agent/agent/vision_agent_coder.py
@@ -27,7 +27,14 @@
TEST_PLANS,
USER_REQ,
)
from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
from vision_agent.lmm import (
LMM,
AzureOpenAILMM,
ClaudeSonnetLMM,
Message,
OllamaLMM,
OpenAILMM,
)
from vision_agent.tools.meta_tools import get_diff
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
@@ -167,9 +174,10 @@ def pick_plan(
}
)
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
tool_output_str = tool_output.logs.stdout[0]
# Because of the way we trace function calls the trace information ends up in the
# results. We don't want to show this info to the LLM so we don't include it in the
# tool_output_str.
tool_output_str = tool_output.text(include_results=False).strip()

if verbosity == 2:
_print_code("Initial code and tests:", code)
@@ -196,7 +204,7 @@
docstring=tool_info,
plans=plan_str,
previous_attempts=PREVIOUS_FAILED.format(
code=code, error=tool_output.text()
code=code, error="\n".join(tool_output_str.splitlines()[-50:])
),
media=media,
)
@@ -225,11 +233,11 @@
"status": "completed" if tool_output.success else "failed",
}
)
tool_output_str = tool_output.text().strip()
tool_output_str = tool_output.text(include_results=False).strip()

if verbosity == 2:
_print_code("Code and test after attempted fix:", code)
_LOGGER.info(f"Code execution result after attempt {count}")
_LOGGER.info(f"Code execution result after attempt {count + 1}")

count += 1

@@ -387,7 +395,6 @@ def write_and_test_code(
"code": DefaultImports.prepend_imports(code),
"payload": {
"test": test,
# "result": result.to_json(),
},
}
)
@@ -406,6 +413,7 @@
working_memory,
debugger,
code_interpreter,
tool_info,
code,
test,
result,
@@ -431,6 +439,7 @@ def debug_code(
working_memory: List[Dict[str, str]],
debugger: LMM,
code_interpreter: CodeInterpreter,
tool_info: str,
code: str,
test: str,
result: Execution,
@@ -451,17 +460,38 @@
count = 0
while not success and count < 3:
try:
fixed_code_and_test = extract_json(
debugger( # type: ignore
FIX_BUG.format(
code=code,
tests=test,
result="\n".join(result.text().splitlines()[-50:]),
feedback=format_memory(working_memory + new_working_memory),
# LLMs write worse code when it's in JSON, so we have it write JSON
# followed by code each wrapped in markdown blocks.
fixed_code_and_test_str = debugger(
FIX_BUG.format(
docstring=tool_info,
code=code,
tests=test,
# Because of the way we trace function calls the trace information
# ends up in the results. We don't want to show this info to the
# LLM so we don't include it in the tool_output_str.
result="\n".join(
result.text(include_results=False).splitlines()[-50:]
),
stream=False,
)
feedback=format_memory(working_memory + new_working_memory),
),
stream=False,
)
fixed_code_and_test_str = cast(str, fixed_code_and_test_str)
fixed_code_and_test = extract_json(fixed_code_and_test_str)
code = extract_code(fixed_code_and_test_str)
if (
"which_code" in fixed_code_and_test
and fixed_code_and_test["which_code"] == "test"
):
fixed_code_and_test["code"] = ""
fixed_code_and_test["test"] = code
else: # for everything else always assume it's updating code
fixed_code_and_test["code"] = code
fixed_code_and_test["test"] = ""
if "which_code" in fixed_code_and_test:
del fixed_code_and_test["which_code"]

success = True
except Exception as e:
_LOGGER.exception(f"Error while extracting JSON: {e}")
@@ -472,9 +502,9 @@
old_test = test

if fixed_code_and_test["code"].strip() != "":
code = extract_code(fixed_code_and_test["code"])
code = fixed_code_and_test["code"]
if fixed_code_and_test["test"].strip() != "":
test = extract_code(fixed_code_and_test["test"])
test = fixed_code_and_test["test"]

new_working_memory.append(
{
@@ -628,9 +658,9 @@ def __init__(
)
self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
self.debugger = (
OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger
)
self.debugger = OpenAILMM(temperature=0.0) if debugger is None else debugger
self.verbosity = verbosity
if self.verbosity > 0:
_LOGGER.setLevel(logging.INFO)
@@ -876,6 +904,40 @@ def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None:
)


class ClaudeVisionAgentCoder(VisionAgentCoder):
def __init__(
self,
planner: Optional[LMM] = None,
coder: Optional[LMM] = None,
tester: Optional[LMM] = None,
debugger: Optional[LMM] = None,
tool_recommender: Optional[Sim] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
) -> None:
# NOTE: Claude doesn't have an official JSON mode
self.planner = ClaudeSonnetLMM(temperature=0.0) if planner is None else planner
self.coder = ClaudeSonnetLMM(temperature=0.0) if coder is None else coder
self.tester = ClaudeSonnetLMM(temperature=0.0) if tester is None else tester
self.debugger = (
ClaudeSonnetLMM(temperature=0.0) if debugger is None else debugger
)
self.verbosity = verbosity
if self.verbosity > 0:
_LOGGER.setLevel(logging.INFO)

# Anthropic does not offer any embedding models and instead recommends Voyage,
# we're using OpenAI's embedder for now.
self.tool_recommender = (
Sim(T.TOOLS_DF, sim_key="desc")
if tool_recommender is None
else tool_recommender
)
self.report_progress_callback = report_progress_callback
self.code_sandbox_runtime = code_sandbox_runtime


class OllamaVisionAgentCoder(VisionAgentCoder):
"""VisionAgentCoder that uses Ollama models for planning, coding, testing.

@@ -920,7 +982,7 @@ def __init__(
else tester
),
debugger=(
OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True)
OllamaLMM(model_name="llama3.1", temperature=0.0)
if debugger is None
else debugger
),
@@ -983,9 +1045,7 @@ def __init__(
coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder,
tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester,
debugger=(
AzureOpenAILMM(temperature=0.0, json_mode=True)
if debugger is None
else debugger
AzureOpenAILMM(temperature=0.0) if debugger is None else debugger
),
tool_recommender=(
AzureSim(T.TOOLS_DF, sim_key="desc")
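A minimal usage sketch for the new class (not part of the diff). It assumes ANTHROPIC_API_KEY is set for the default ClaudeSonnetLMM models, OPENAI_API_KEY is set for the OpenAI-backed Sim tool recommender, and that ClaudeVisionAgentCoder is called the same way as the existing VisionAgentCoder.

from vision_agent.agent import ClaudeVisionAgentCoder

# Planner, coder, tester, and debugger all default to ClaudeSonnetLMM; the tool
# recommender still embeds tool descriptions with OpenAI, per the note above.
agent = ClaudeVisionAgentCoder(verbosity=1)
code = agent("Count the number of people in this image", media="people.jpg")
print(code)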
42 changes: 34 additions & 8 deletions vision_agent/agent/vision_agent_coder_prompts.py
@@ -63,6 +63,7 @@
**Plans**:
{plans}

**Previous Attempts**:
{previous_attempts}

**Instructions**:
@@ -108,26 +109,38 @@
- Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video.
plan3:
- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
- Use the 'countgd_counting' tool with the prompt 'person' to detect where the people are in the video.
- Use the 'florence2_sam2_video_tracking' tool with the prompt 'person' to detect where the people are in the video.


```python
from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, countgd_counting
import numpy as np
from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, florence2_sam2_video_tracking

# sample at 1 FPS and use the first 10 frames to reduce processing time
frames = extract_frames("video.mp4", 1)
frames = [f[0] for f in frames][:10]

def remove_arrays(o):
if isinstance(o, list):
return [remove_arrays(e) for e in o]
elif isinstance(o, dict):
return {{k: remove_arrays(v) for k, v in o.items()}}
elif isinstance(o, np.ndarray):
return "array: " + str(o.shape)
else:
return o

# plan1
owl_v2_out = [owl_v2_image("person", f) for f in frames]

# plan2
florence2_out = [florence2_phrase_grounding("person", f) for f in frames]

# plan3
countgd_out = [countgd_counting(f) for f in frames]
f2s2_tracking_out = florence2_sam2_video_tracking("person", frames)
remove_arrays(f2s2_tracking_out)

final_out = {{"owl_v2_image": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
final_out = {{"owl_v2_image": owl_v2_out, "florence2_phrase_grounding": florence2_out, "florence2_sam2_video_tracking": f2s2_tracking_out}}
print(final_out)
```
"""
@@ -161,9 +174,10 @@

**Instructions**:
1. Given the plans, image, and tool outputs, decide which plan is the best to achieve the user request.
2. Try solving the problem yourself given the image and pick the plan that matches your solution the best.
2. Solve the problem yourself given the image and pick the plan that matches your solution the best.
3. Output a JSON object with the following format:
{{
"predicted_answer": str # the answer you would expect from the best plan
"thoughts": str # your thought process for choosing the best plan
"best_plan": str # the best plan you have chosen
}}
@@ -311,6 +325,11 @@ def find_text(image_path: str, text: str) -> str:
FIX_BUG = """
**Role** As a coder, your job is to find the error in the code and fix it. You are running in a notebook setting so you can run !pip install to install missing packages.

**Documentation**:
This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task. They are available through importing `from vision_agent.tools import *`.

{docstring}

**Instructions**:
Please re-complete the code to fix the error message. Here is the previous version:
```python
@@ -323,17 +342,24 @@ def find_text(image_path: str, text: str) -> str:
```

It raises this error:
```
{result}
```

This is previous feedback provided on the code:
{feedback}

Please fix the bug by follow the error information and return a JSON object with the following format:
Please fix the bug by correcting the error. Return the following JSON object followed by the fixed code in the below format:
```json
{{
"reflections": str # any thoughts you have about the bug and how you fixed it
"code": str # the fixed code if any, else an empty string
"test": str # the fixed test code if any, else an empty string
"which_code": str # which code you fixed, can either be 'code' or 'test'
}}
```

```python
# Your fixed code here
```
"""


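Because Claude has no official JSON mode, FIX_BUG now asks for a JSON block followed by a fenced code block, and debug_code splits the two with extract_json and extract_code. A rough sketch of that split on a made-up reply (not part of the diff; imports assumed from agent_utils, and count_people in the sample is a hypothetical helper):

from vision_agent.agent.agent_utils import extract_code, extract_json

reply = """```json
{"reflections": "the test asserted the wrong count", "which_code": "test"}
```
```python
assert count_people("people.jpg") == 4
```"""

fixed = extract_json(reply)   # the leading JSON block
patch = extract_code(reply)   # the fenced Python block
# Mirror of the new debug_code logic: route the patch to either code or test.
if fixed.get("which_code") == "test":
    fixed["code"], fixed["test"] = "", patch
else:
    fixed["code"], fixed["test"] = patch, ""
fixed.pop("which_code", None)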
2 changes: 1 addition & 1 deletion vision_agent/tools/tool_utils.py
@@ -1,7 +1,7 @@
from base64 import b64encode
import inspect
import logging
import os
from base64 import b64encode
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

import pandas as pd