Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Claude Sonnet 3.5 VisionAgentCoder #231

Merged
merged 22 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 35 additions & 57 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ pillow-heif = "^0.16.0"
pytube = "15.0.0"
anthropic = "^0.31.0"
pydantic = "2.7.4"
eva-decord = "^0.6.1"
av = "^11.0.0"

[tool.poetry.group.dev.dependencies]
Expand Down
12 changes: 12 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
ixc25_temporal_localization,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
Expand Down Expand Up @@ -238,6 +239,17 @@ def test_ixc25_video_vqa() -> None:
assert "cat" in result.strip()


def test_ixc25_temporal_localization() -> None:
frames = [
np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10)
]
result = ixc25_temporal_localization(
prompt="What animal is in this video?",
frames=frames,
)
assert result == [True] * 10


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/tools/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ def test_extract_frames_from_video():
video_path = "tests/data/video/test.mp4"

# there are 48 frames at 24 fps in this video file
res = extract_frames_from_video(video_path)
assert len(res) == 2
res = extract_frames_from_video(video_path, fps=24)
assert len(res) == 48
1 change: 1 addition & 0 deletions vision_agent/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .vision_agent import VisionAgent
from .vision_agent_coder import (
AzureVisionAgentCoder,
ClaudeVisionAgentCoder,
OllamaVisionAgentCoder,
VisionAgentCoder,
)
48 changes: 30 additions & 18 deletions vision_agent/agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,48 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
if match:
json_str = match.group()
try:
# remove trailing comma
trailing_bracket_pattern = r",\s+\}"
json_str = re.sub(trailing_bracket_pattern, "}", json_str, flags=re.DOTALL)

json_dict = json.loads(json_str)
return json_dict # type: ignore
except json.JSONDecodeError:
return None
return None


def _find_markdown_json(json_str: str) -> str:
pattern = r"```json(.*?)```"
match = re.search(pattern, json_str, re.DOTALL)
if match:
return match.group(1).strip()
return json_str


def _strip_markdown_code(inp_str: str) -> str:
pattern = r"```python.*?```"
cleaned_str = re.sub(pattern, "", inp_str, flags=re.DOTALL)
return cleaned_str


def extract_json(json_str: str) -> Dict[str, Any]:
dillonalaird marked this conversation as resolved.
Show resolved Hide resolved
json_str = json_str.replace("\n", " ").strip()

try:
json_str = json_str.replace("\n", " ")
json_dict = json.loads(json_str)
return json.loads(json_str) # type: ignore
except json.JSONDecodeError:
if "```json" in json_str:
json_str = json_str[json_str.find("```json") + len("```json") :]
json_str = json_str[: json_str.find("```")]
elif "```" in json_str:
json_str = json_str[json_str.find("```") + len("```") :]
# get the last ``` not one from an intermediate string
json_str = json_str[: json_str.find("}```")]
try:
json_dict = json.loads(json_str)
except json.JSONDecodeError as e:
json_dict = _extract_sub_json(json_str)
if json_dict is not None:
return json_dict # type: ignore
error_msg = f"Could not extract JSON from the given str: {json_str}"
json_orig = json_str
json_str = _strip_markdown_code(json_str)
json_str = _find_markdown_json(json_str)
json_dict = _extract_sub_json(json_str)

if json_dict is None:
error_msg = f"Could not extract JSON from the given str: {json_orig}"
_LOGGER.exception(error_msg)
raise ValueError(error_msg) from e
raise ValueError(error_msg)

return json_dict # type: ignore
return json_dict


def extract_code(code: str) -> str:
Expand Down
Loading
Loading