Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LMM planner and Updated OCR #108

Merged
merged 9 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/custom_tools/run_custom_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from template_match import template_matching_with_rotation

import vision_agent as va
from template_match import template_matching_with_rotation
from vision_agent.utils.image_utils import get_image_size, normalize_bbox


Expand Down
34 changes: 30 additions & 4 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import json
import logging
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast

from PIL import Image
from rich.console import Console
from rich.style import Style
from rich.syntax import Syntax
Expand Down Expand Up @@ -76,12 +78,35 @@ def extract_json(json_str: str) -> Dict[str, Any]:
return json_dict # type: ignore


def extract_image(
    media: Optional[Sequence[Union[str, Path]]]
) -> Optional[Sequence[Union[str, Path]]]:
    """Reduce a list of media paths to a list of still-image paths.

    Image files are passed through (as ``Path`` objects). For video files the
    first extracted frame is written to a temporary PNG and that PNG's path is
    returned in the video's place. Files with any other extension are dropped.

    Parameters:
        media (Optional[Sequence[Union[str, Path]]]): Paths to image or video
            files, or None.

    Returns:
        Optional[Sequence[Union[str, Path]]]: The image paths, or None when
            ``media`` is None or no usable image could be produced.

    Note:
        Temporary PNGs are created with ``delete=False`` so they outlive this
        call; the caller is responsible for removing them.
    """
    if media is None:
        return None

    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp"}
    video_suffixes = {".mp4", ".mov"}

    new_media: List[Union[str, Path]] = []
    for m in media:
        path = Path(m)
        # Compare case-insensitively so files such as "photo.JPG" or
        # "clip.MOV" are not silently discarded.
        extension = path.suffix.lower()
        if extension in image_suffixes:
            new_media.append(path)
        elif extension in video_suffixes:
            frames = T.extract_frames(path)
            # Check for frames *before* creating the temp file; otherwise an
            # empty, never-deleted file is left behind for frameless videos.
            if len(frames) == 0:
                continue
            # Only reserve a unique file name here. The handle is closed before
            # saving because an open NamedTemporaryFile cannot be reopened by
            # name on Windows.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                frame_path = Path(tmp.name)
            # NOTE(review): frames[0][0] is presumably the pixel array of the
            # first extracted frame -- confirm against T.extract_frames.
            Image.fromarray(frames[0][0]).save(frame_path)
            new_media.append(frame_path)

    return new_media if new_media else None


def write_plan(
chat: List[Dict[str, str]],
tool_desc: str,
working_memory: str,
model: Union[LLM, LMM],
media: Optional[List[Union[str, Path]]] = None,
media: Optional[Sequence[Union[str, Path]]] = None,
) -> List[Dict[str, str]]:
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
Expand All @@ -92,6 +117,7 @@ def write_plan(
prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
chat[-1]["content"] = prompt
if isinstance(model, OpenAILMM):
media = extract_image(media)
return extract_json(model.chat(chat, images=media))["plan"] # type: ignore
else:
return extract_json(model.chat(chat))["plan"] # type: ignore
Expand All @@ -101,7 +127,7 @@ def reflect(
chat: List[Dict[str, str]],
plan: str,
code: str,
model: LLM,
model: Union[LLM, LMM],
) -> Dict[str, Union[str, bool]]:
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
Expand Down Expand Up @@ -306,7 +332,7 @@ class VisionAgent(Agent):

def __init__(
self,
planner: Optional[LLM] = None,
planner: Optional[Union[LLM, LMM]] = None,
coder: Optional[LLM] = None,
tester: Optional[LLM] = None,
debugger: Optional[LLM] = None,
Expand Down
7 changes: 5 additions & 2 deletions vision_agent/agent/vision_agent_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@
{feedback}

**Instructions**:
Based on the context and tools you have available, write a plan of subtasks to achieve the user request utilizing given tools when necessary. Output a list of jsons in the following format:
1. Based on the context and tools you have available, write a plan of subtasks to achieve the user request.
2. Go over the users request step by step and ensure each step is represented as a clear subtask in your plan.

Output a list of jsons in the following format

```json
{{
"plan":
[
{{
"instructions": str # what you should do in this task, one short phrase or sentence
"instructions": str # what you should do in this task associated with a tool
}}
]
}}
Expand Down
7 changes: 4 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,15 @@ def extract_frames(

def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
"""'ocr' extracts text from an image. It returns a list of detected text, bounding
boxes, and confidence scores. The results are sorted from top-left to bottom right
boxes with normalized coordinates, and confidence scores. The results are sorted
from top-left to bottom right.

Parameters:
image (np.ndarray): The image to extract text from.

Returns:
List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox,
and confidence score.
List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox
with normalized coordinates, and confidence score.

Example
-------
Expand Down
Loading