Skip to content

Commit

Permalink
Add Claude Sonnet 3.5 VisionAgentCoder (#231)
Browse files Browse the repository at this point in the history
* added ClaudeVisionAgentCoder and fixed json parser

* export ClaudeVisionAgentCoder

* fix prompts

* fixed type errors

* fixed erorr in default param for extract frames

* added function to strip results from remote code execution call

* allow debugger to see docs in case the bug was a missing import

* isort

* minor fix to prompt

* fix edge case for OCR

* debugger should not be json mode

* debugger should not be json mode

* update prompt to do better eval of different plans

* spelling mistake

* fixed prompt

* Add Temporal Localization & Fix Video Reading (#233)

* fixed issue with video reader

* added temporal localization

* fix video reader

* remove decord

* fix type error

* fix format issue

* fix test case

* type error

* add test case

* update version of claude
  • Loading branch information
dillonalaird authored Sep 11, 2024
1 parent 0805a40 commit 777b4d5
Show file tree
Hide file tree
Showing 14 changed files with 279 additions and 137 deletions.
92 changes: 35 additions & 57 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ pillow-heif = "^0.16.0"
pytube = "15.0.0"
anthropic = "^0.31.0"
pydantic = "2.7.4"
eva-decord = "^0.6.1"
av = "^11.0.0"

[tool.poetry.group.dev.dependencies]
Expand Down
12 changes: 12 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
ixc25_temporal_localization,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
Expand Down Expand Up @@ -238,6 +239,17 @@ def test_ixc25_video_vqa() -> None:
assert "cat" in result.strip()


def test_ixc25_temporal_localization() -> None:
frames = [
np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10)
]
result = ixc25_temporal_localization(
prompt="What animal is in this video?",
frames=frames,
)
assert result == [True] * 10


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/tools/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ def test_extract_frames_from_video():
video_path = "tests/data/video/test.mp4"

# there are 48 frames at 24 fps in this video file
res = extract_frames_from_video(video_path)
assert len(res) == 2
res = extract_frames_from_video(video_path, fps=24)
assert len(res) == 48
1 change: 1 addition & 0 deletions vision_agent/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .vision_agent import VisionAgent
from .vision_agent_coder import (
AzureVisionAgentCoder,
ClaudeVisionAgentCoder,
OllamaVisionAgentCoder,
VisionAgentCoder,
)
48 changes: 30 additions & 18 deletions vision_agent/agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,48 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
if match:
json_str = match.group()
try:
# remove trailing comma
trailing_bracket_pattern = r",\s+\}"
json_str = re.sub(trailing_bracket_pattern, "}", json_str, flags=re.DOTALL)

json_dict = json.loads(json_str)
return json_dict # type: ignore
except json.JSONDecodeError:
return None
return None


def _find_markdown_json(json_str: str) -> str:
pattern = r"```json(.*?)```"
match = re.search(pattern, json_str, re.DOTALL)
if match:
return match.group(1).strip()
return json_str


def _strip_markdown_code(inp_str: str) -> str:
pattern = r"```python.*?```"
cleaned_str = re.sub(pattern, "", inp_str, flags=re.DOTALL)
return cleaned_str


def extract_json(json_str: str) -> Dict[str, Any]:
json_str = json_str.replace("\n", " ").strip()

try:
json_str = json_str.replace("\n", " ")
json_dict = json.loads(json_str)
return json.loads(json_str) # type: ignore
except json.JSONDecodeError:
if "```json" in json_str:
json_str = json_str[json_str.find("```json") + len("```json") :]
json_str = json_str[: json_str.find("```")]
elif "```" in json_str:
json_str = json_str[json_str.find("```") + len("```") :]
# get the last ``` not one from an intermediate string
json_str = json_str[: json_str.find("}```")]
try:
json_dict = json.loads(json_str)
except json.JSONDecodeError as e:
json_dict = _extract_sub_json(json_str)
if json_dict is not None:
return json_dict # type: ignore
error_msg = f"Could not extract JSON from the given str: {json_str}"
json_orig = json_str
json_str = _strip_markdown_code(json_str)
json_str = _find_markdown_json(json_str)
json_dict = _extract_sub_json(json_str)

if json_dict is None:
error_msg = f"Could not extract JSON from the given str: {json_orig}"
_LOGGER.exception(error_msg)
raise ValueError(error_msg) from e
raise ValueError(error_msg)

return json_dict # type: ignore
return json_dict


def extract_code(code: str) -> str:
Expand Down
Loading

0 comments on commit 777b4d5

Please sign in to comment.