Full Claude Sonnet 3.5 Support #234

Merged 30 commits on Sep 23, 2024
Changes from 29 commits
Commits
30 commits
1b3d8a8
resize image for claude
dillonalaird Sep 11, 2024
2b112b5
only resize if above size
dillonalaird Sep 11, 2024
9e71697
renamed claude to anthropic for consistency
dillonalaird Sep 11, 2024
e4485fa
added openai classes and made anthropic default
dillonalaird Sep 11, 2024
bd8d245
add ability to view images
dillonalaird Sep 11, 2024
d64f86d
add florence2 fine tune to owl_v2 args
dillonalaird Sep 11, 2024
ad54abb
added fine tune id for florence2sam2
dillonalaird Sep 11, 2024
7432e10
add generic OD fine tuning
dillonalaird Sep 11, 2024
7d27d63
fixed type error
dillonalaird Sep 11, 2024
e13d019
added comment
dillonalaird Sep 11, 2024
c3c210b
fix prompt for florence2 sam2 video tracking
dillonalaird Sep 11, 2024
39a8548
fixed import bug
dillonalaird Sep 12, 2024
30f00f7
updated fine tuning names in prompts
dillonalaird Sep 12, 2024
51ca06b
improve json parsing
dillonalaird Sep 17, 2024
891def5
update json extract, add tests
dillonalaird Sep 17, 2024
0d9c00b
removed old code
dillonalaird Sep 19, 2024
54785de
minor improvements to prompt to improve benchmark
dillonalaird Sep 19, 2024
fb5cfc3
pass plan thoughts to coder
dillonalaird Sep 19, 2024
64df1e8
fixed comments
dillonalaird Sep 19, 2024
14fc101
fix type and lint errors
dillonalaird Sep 19, 2024
957ed56
update tests
dillonalaird Sep 19, 2024
152ac13
make imports easier, pass more code info
dillonalaird Sep 22, 2024
c4ee089
update prompts
dillonalaird Sep 22, 2024
94f9501
standardize fps to 1
dillonalaird Sep 22, 2024
85e2e8a
rename functions to make them easier to understand by llm
dillonalaird Sep 22, 2024
9b26db3
add openai vision agent coder
dillonalaird Sep 22, 2024
921d3b7
fix complexity
dillonalaird Sep 22, 2024
4d37e30
fix type issue
dillonalaird Sep 22, 2024
b11cb88
fix lmm version
dillonalaird Sep 22, 2024
2c9c5c5
updated readme
dillonalaird Sep 23, 2024
43 changes: 41 additions & 2 deletions tests/integ/test_tools.py
@@ -21,8 +21,8 @@
grounding_dino,
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
ixc25_temporal_localization,
ixc25_video_vqa,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
@@ -33,6 +33,8 @@
vit_nsfw_classification,
)

FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"


def test_grounding_dino():
img = ski.data.coins()
@@ -65,6 +67,18 @@ def test_owl_v2_image():
assert [res["label"] for res in result] == ["coin"] * len(result)


def test_owl_v2_fine_tune_id():
img = ski.data.coins()
result = owl_v2_image(
prompt="coin",
image=img,
fine_tune_id=FINE_TUNE_ID,
)
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)


def test_owl_v2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
@@ -78,7 +92,7 @@ def test_owl_v2_video():
assert 24 <= len([res["label"] for res in result[0]]) <= 26


def test_object_detection():
def test_florence2_phrase_grounding():
img = ski.data.coins()
result = florence2_phrase_grounding(
image=img,
@@ -88,6 +102,18 @@ def test_object_detection():
assert [res["label"] for res in result] == ["coin"] * 25


def test_florence2_phrase_grounding_fine_tune_id():
img = ski.data.coins()
result = florence2_phrase_grounding(
prompt="coin",
image=img,
fine_tune_id=FINE_TUNE_ID,
)
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)


def test_template_match():
img = ski.data.coins()
result = template_match(
@@ -119,6 +145,19 @@ def test_florence2_sam2_image():
assert len([res["mask"] for res in result]) == 25


def test_florence2_sam2_image_fine_tune_id():
img = ski.data.coins()
result = florence2_sam2_image(
prompt="coin",
image=img,
fine_tune_id=FINE_TUNE_ID,
)
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert len([res["mask"] for res in result]) == len(result)


def test_florence2_sam2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
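The three new tests above exercise the same optional fine_tune_id argument on three different tools. A condensed sketch of that pattern outside pytest (hedged: it assumes the tools are imported from vision_agent.tools as in this test module, and that a real fine-tune job id replaces the placeholder):

```python
import skimage as ski

from vision_agent.tools import (
    florence2_phrase_grounding,
    florence2_sam2_image,
    owl_v2_image,
)

FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"  # placeholder id copied from the tests

img = ski.data.coins()

# Passing fine_tune_id routes each call to a fine-tuned Florence-2 model instead of the default.
for tool in (owl_v2_image, florence2_phrase_grounding, florence2_sam2_image):
    detections = tool(prompt="coin", image=img, fine_tune_id=FINE_TUNE_ID)
    print(tool.__name__, [d["label"] for d in detections])
```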
45 changes: 45 additions & 0 deletions tests/unit/test_utils.py
@@ -0,0 +1,45 @@
from vision_agent.agent.agent_utils import extract_code, extract_json


def test_basic_json_extract():
a = '{"a": 1, "b": 2}'
assert extract_json(a) == {"a": 1, "b": 2}


def test_side_case_quotes_json_extract():
a = "{'0': 'no', '3': 'no', '6': 'no', '9': 'yes', '12': 'no', '15': 'no'}"
a_json = extract_json(a)
assert len(a_json) == 6


def test_side_case_bool_json_extract():
a = "{'0': False, '3': False, '6': False, '9': True, '12': False, '15': False}"
a_json = extract_json(a)
assert len(a_json) == 6


def test_complicated_case_json_extract_1():
a = """```json { "plan1": { "thoughts": "This plan uses the owl_v2_video tool to detect the truck and then uses ocr to read the USDOT and trailer numbers. This approach is efficient as it can process the entire video at once for truck detection.", "instructions": [ "Use extract_frames to get frames from truck1.mp4", "Use owl_v2_video with prompt 'truck' to detect if a truck is present in the video", "If a truck is detected, use ocr on relevant frames to read the USDOT and trailer numbers", "Process the OCR results to extract the USDOT and trailer numbers", "Compile results into JSON format and save using save_json" ] }, "plan2": { "thoughts": "This plan uses florence2_sam2_video_tracking to segment and track the truck, then uses florence2_ocr for text detection. This approach might be more accurate for text detection as it can focus on the relevant parts of the truck.", "instructions": [ "Use extract_frames to get frames from truck1.mp4", "Use florence2_sam2_video_tracking with prompt 'truck' to segment and track the truck", "If a truck is segmented, use florence2_ocr on the segmented area to detect text", "Process the OCR results to extract the USDOT and trailer numbers", "Compile results into JSON format and save using save_json" ] }, "plan3": { "thoughts": "This plan uses ixc25_video_vqa to directly ask questions about the truck, USDOT number, and trailer number. This approach leverages the model's ability to understand and answer complex questions about video content.", "instructions": [ "Use extract_frames to get frames from truck1.mp4", "Use ixc25_video_vqa with the question 'Is there a truck in this video?' to detect the presence of a truck", "If a truck is present, use ixc25_video_vqa with the question 'What is the USDOT number on the truck?'", "Use ixc25_video_vqa with the question 'What is the trailer number on the truck?'", "Process the answers to extract the required information", "Compile results into JSON format and save using save_json" ] } } ```"""

a_json = extract_json(a)
assert len(a_json) == 3
assert "plan1" in a_json


def test_complicated_case_json_extract_2():
a = """{\n "predicted_answer": "2",\n "thoughts": "After analyzing the image and the tool outputs, I can see that there are indeed 2 dogs in the image. One is a small grey dog on the grass, and the other is a larger white dog on the patio. All three plans correctly identified 2 dogs, but I believe plan2 using the countgd_counting tool is the best choice for this task. Here\'s why:\n\n 1. Accuracy: The countgd_counting tool provided high confidence scores (0.92 and 0.9) for both dogs, which aligns with what I can see in the image.\n \n 2. Precision: The bounding boxes from the countgd_counting tool seem to be more precise and tightly fit around the dogs compared to the other tools.\n \n 3. Simplicity: While plan3 offers a more complex approach with additional verification, it\'s not necessary in this case as the dogs are clearly visible and easily identifiable. The extra steps in plan3 would add unnecessary complexity and potential for errors.\n \n 4. Efficiency: Plan2 is more straightforward and efficient than plan3, while potentially offering better accuracy than plan1 (owl_v2_image tool had lower confidence scores).",\n "best_plan": "plan2"\n}"""
a_json = extract_json(a)
assert len(a_json) == 3
assert "predicted_answer" in a_json


def test_basic_code_extract():
a = """```python
def test_basic_json_extract():
a = '{"a": 1, "b": 2}'
assert extract_json(a) == {"a": 1, "b": 2}
```
"""
a_code = extract_code(a)
assert "def test_basic_json_extract():" in a_code
assert "assert extract_json(a) == {" in a_code
3 changes: 2 additions & 1 deletion vision_agent/agent/__init__.py
@@ -1,8 +1,9 @@
from .agent import Agent
from .vision_agent import VisionAgent
from .vision_agent_coder import (
AnthropicVisionAgentCoder,
AzureVisionAgentCoder,
ClaudeVisionAgentCoder,
OllamaVisionAgentCoder,
OpenAIVisionAgentCoder,
VisionAgentCoder,
)
10 changes: 8 additions & 2 deletions vision_agent/agent/agent_utils.py
@@ -40,12 +40,18 @@ def _strip_markdown_code(inp_str: str) -> str:


def extract_json(json_str: str) -> Dict[str, Any]:
json_str = json_str.replace("\n", " ").strip()
json_str_mod = json_str.replace("\n", " ").strip()
json_str_mod = json_str_mod.replace("'", '"')
json_str_mod = json_str_mod.replace(": True", ": true").replace(
": False", ": false"
)

try:
return json.loads(json_str) # type: ignore
return json.loads(json_str_mod) # type: ignore
except json.JSONDecodeError:
json_orig = json_str
# don't replace quotes here or booleans since it can also introduce errors
json_str = json_str.replace("\n", " ").strip()
json_str = _strip_markdown_code(json_str)
json_str = _find_markdown_json(json_str)
json_dict = _extract_sub_json(json_str)
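A quick illustration of the behavior this hunk adds (a hedged sketch, not part of the diff; it only exercises the quote and boolean normalization visible above):

```python
from vision_agent.agent.agent_utils import extract_json

# Single-quoted keys/values are normalized to double quotes before json.loads is attempted.
assert extract_json("{'a': 1, 'b': 2}") == {"a": 1, "b": 2}

# Python-style booleans are lowered to JSON booleans on the same pass.
assert extract_json("{'ok': True, 'done': False}") == {"ok": True, "done": False}
```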
114 changes: 97 additions & 17 deletions vision_agent/agent/vision_agent.py
@@ -3,18 +3,23 @@
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from vision_agent.agent import Agent
from vision_agent.agent.agent_utils import extract_json
from vision_agent.agent.vision_agent_prompts import (
EXAMPLES_CODE1,
EXAMPLES_CODE2,
EXAMPLES_CODE3,
VA_CODE,
)
from vision_agent.lmm import LMM, Message, OpenAILMM
from vision_agent.lmm import LMM, AnthropicLMM, Message, OpenAILMM
from vision_agent.tools import META_TOOL_DOCSTRING
from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args
from vision_agent.tools.meta_tools import (
Artifacts,
check_and_load_image,
use_extra_vision_agent_args,
)
from vision_agent.utils import CodeInterpreterFactory
from vision_agent.utils.execute import CodeInterpreter, Execution

@@ -30,7 +35,7 @@ class BoilerplateCode:
pre_code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, florence2_fine_tuning, use_florence2_fine_tuning",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, view_media_artifact, object_detection_fine_tuning, use_object_detection_fine_tuning",
"artifacts = Artifacts('{remote_path}')",
"artifacts.load('{remote_path}')",
]
@@ -68,10 +73,18 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:

prompt = VA_CODE.format(
documentation=META_TOOL_DOCSTRING,
examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}",
examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}\n{EXAMPLES_CODE3}",
conversation=conversation,
)
return extract_json(orch([{"role": "user", "content": prompt}], stream=False)) # type: ignore
message: Message = {"role": "user", "content": prompt}
# only add recent media so we don't overload the model with old images
if (
chat[-1]["role"] == "observation"
and "media" in chat[-1]
and len(chat[-1]["media"]) > 0 # type: ignore
):
message["media"] = chat[-1]["media"]
return extract_json(orch([message], stream=False)) # type: ignore


def run_code_action(
@@ -136,10 +149,8 @@ def __init__(
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use.
"""

self.agent = (
OpenAILMM(temperature=0.0, json_mode=True) if agent is None else agent
)
self.max_iterations = 100
self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent
self.max_iterations = 12
self.verbosity = verbosity
self.code_sandbox_runtime = code_sandbox_runtime
self.callback_message = callback_message
@@ -267,7 +278,8 @@ def chat_with_code(
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.streaming_message({"role": "observation", "content": artifacts_loaded})

if isinstance(last_user_message_content, str):
if int_chat[-1]["role"] == "user":
last_user_message_content = cast(str, int_chat[-1].get("content", ""))
user_code_action = parse_execution(last_user_message_content, False)
if user_code_action is not None:
user_result, user_obs = run_code_action(
@@ -309,8 +321,7 @@ def chat_with_code(
else:
self.streaming_message({"role": "assistant", "content": response})

if response["let_user_respond"]:
break
finished = response["let_user_respond"]

code_action = parse_execution(
response["response"], test_multi_plan, customized_tool_names
@@ -321,13 +332,22 @@
code_action, code_interpreter, str(remote_artifacts_path)
)

media_obs = check_and_load_image(code_action)

if self.verbosity >= 1:
_LOGGER.info(obs)

chat_elt: Message = {"role": "observation", "content": obs}
if media_obs and result.success:
chat_elt["media"] = [
Path(code_interpreter.remote_path) / media_ob
for media_ob in media_obs
]

# don't add execution results to internal chat
int_chat.append({"role": "observation", "content": obs})
orig_chat.append(
{"role": "observation", "content": obs, "execution": result}
)
int_chat.append(chat_elt)
chat_elt["execution"] = result
orig_chat.append(chat_elt)
self.streaming_message(
{
"role": "observation",
@@ -353,3 +373,63 @@ def streaming_message(self, message: Dict[str, Any]) -> None:

def log_progress(self, data: Dict[str, Any]) -> None:
pass


class OpenAIVisionAgent(VisionAgent):
def __init__(
self,
agent: Optional[LMM] = None,
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
code_sandbox_runtime: Optional[str] = None,
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
) -> None:
"""Initialize the VisionAgent using OpenAI LMMs.

Parameters:
agent (Optional[LMM]): The agent to use for conversation and orchestration
of other agents.
verbosity (int): The verbosity level of the agent.
local_artifacts_path (Optional[Union[str, Path]]): The path to the local
artifacts file.
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use.
"""

agent = OpenAILMM(temperature=0.0, json_mode=True) if agent is None else agent
super().__init__(
agent,
verbosity,
local_artifacts_path,
code_sandbox_runtime,
callback_message,
)


class AnthropicVisionAgent(VisionAgent):
def __init__(
self,
agent: Optional[LMM] = None,
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
code_sandbox_runtime: Optional[str] = None,
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
) -> None:
"""Initialize the VisionAgent using Anthropic LMMs.

Parameters:
agent (Optional[LMM]): The agent to use for conversation and orchestration
of other agents.
verbosity (int): The verbosity level of the agent.
local_artifacts_path (Optional[Union[str, Path]]): The path to the local
artifacts file.
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use.
"""

agent = AnthropicLMM(temperature=0.0) if agent is None else agent
super().__init__(
agent,
verbosity,
local_artifacts_path,
code_sandbox_runtime,
callback_message,
)
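
A minimal usage sketch of the two new wrapper classes (not part of the PR; it assumes only the constructors shown above and an appropriate ANTHROPIC_API_KEY or OPENAI_API_KEY in the environment):

```python
from vision_agent.agent.vision_agent import AnthropicVisionAgent, OpenAIVisionAgent

# Each subclass only swaps the default orchestration LMM; chat_with_code and the
# artifact handling are inherited unchanged from VisionAgent.
anthropic_agent = AnthropicVisionAgent(verbosity=1)  # AnthropicLMM(temperature=0.0) by default
openai_agent = OpenAIVisionAgent(verbosity=1)        # OpenAILMM(temperature=0.0, json_mode=True) by default

# Hypothetical single-turn conversation; Message dicts use the same "role"/"content"
# shape seen in chat_with_code above.
# response = anthropic_agent.chat_with_code([{"role": "user", "content": "Count the coins in coins.jpg"}])
```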