LMM planner and Updated OCR (#108)
* added LMM planner

* updated OCR tool desc

* isort

* fixed type error

* moved back to llm for planner

* small fixes to prompt

* updated types

* revert tools doc

* added bmp
dillonalaird authored Jun 5, 2024
1 parent 1b928a3 commit 5c1ee70
Showing 4 changed files with 40 additions and 10 deletions.
2 changes: 1 addition & 1 deletion examples/custom_tools/run_custom_tool.py
@@ -1,7 +1,7 @@
import numpy as np
from template_match import template_matching_with_rotation

import vision_agent as va
from template_match import template_matching_with_rotation
from vision_agent.utils.image_utils import get_image_size, normalize_bbox


34 changes: 30 additions & 4 deletions vision_agent/agent/vision_agent.py
@@ -2,9 +2,11 @@
import json
import logging
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast

from PIL import Image
from rich.console import Console
from rich.style import Style
from rich.syntax import Syntax
@@ -78,12 +80,35 @@ def extract_json(json_str: str) -> Dict[str, Any]:
    return json_dict  # type: ignore


def extract_image(
    media: Optional[Sequence[Union[str, Path]]]
) -> Optional[Sequence[Union[str, Path]]]:
    if media is None:
        return None

    new_media = []
    for m in media:
        m = Path(m)
        extension = m.suffix
        if extension in [".jpg", ".jpeg", ".png", ".bmp"]:
            new_media.append(m)
        elif extension in [".mp4", ".mov"]:
            frames = T.extract_frames(m)
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                if len(frames) > 0:
                    Image.fromarray(frames[0][0]).save(tmp.name)
                    new_media.append(Path(tmp.name))
    if len(new_media) == 0:
        return None
    return new_media


def write_plan(
    chat: List[Dict[str, str]],
    tool_desc: str,
    working_memory: str,
    model: Union[LLM, LMM],
    media: Optional[List[Union[str, Path]]] = None,
    media: Optional[Sequence[Union[str, Path]]] = None,
) -> List[Dict[str, str]]:
    chat = copy.deepcopy(chat)
    if chat[-1]["role"] != "user":
@@ -94,6 +119,7 @@ def write_plan(
    prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
    chat[-1]["content"] = prompt
    if isinstance(model, OpenAILMM):
        media = extract_image(media)
        return extract_json(model.chat(chat, images=media))["plan"]  # type: ignore
    else:
        return extract_json(model.chat(chat))["plan"]  # type: ignore
@@ -103,7 +129,7 @@ def reflect(
    chat: List[Dict[str, str]],
    plan: str,
    code: str,
    model: LLM,
    model: Union[LLM, LMM],
) -> Dict[str, Union[str, bool]]:
    chat = copy.deepcopy(chat)
    if chat[-1]["role"] != "user":
@@ -309,7 +335,7 @@ class VisionAgent(Agent):

    def __init__(
        self,
        planner: Optional[LLM] = None,
        planner: Optional[Union[LLM, LMM]] = None,
        coder: Optional[LLM] = None,
        tester: Optional[LLM] = None,
        debugger: Optional[LLM] = None,
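Taken together, these changes let the planner be either a text-only LLM or a multimodal LMM. Below is a minimal sketch of the new dispatch; only `write_plan` and `OpenAILMM` appear in this diff, so the import locations and the `OpenAILLM` class name are assumptions for illustration:

```python
# Hedged sketch of the planner dispatch added above. write_plan() lives in
# vision_agent/agent/vision_agent.py per this diff; the other import paths and
# the OpenAILLM name are assumptions, not confirmed by the diff.
from vision_agent.agent.vision_agent import write_plan
from vision_agent.llm import OpenAILLM   # assumed text-only planner
from vision_agent.lmm import OpenAILMM   # multimodal planner checked via isinstance()

chat = [{"role": "user", "content": "Count the cars in this video"}]
media = ["traffic.mp4", "notes.txt"]  # hypothetical files

# With an OpenAILMM planner, extract_image() first narrows media to
# .jpg/.jpeg/.png/.bmp files (a .mp4/.mov contributes its first frame saved to a
# temporary .png), and the surviving paths go to model.chat(chat, images=media).
plan = write_plan(chat, tool_desc="<tool docs>", working_memory="", model=OpenAILMM(), media=media)

# With a plain LLM planner the same call ignores media and plans from text alone.
plan = write_plan(chat, tool_desc="<tool docs>", working_memory="", model=OpenAILLM(), media=media)
```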
7 changes: 5 additions & 2 deletions vision_agent/agent/vision_agent_prompts.py
@@ -29,14 +29,17 @@
{feedback}
**Instructions**:
Based on the context and tools you have available, write a plan of subtasks to achieve the user request utilizing given tools when necessary. Output a list of jsons in the following format:
1. Based on the context and tools you have available, write a plan of subtasks to achieve the user request.
2. Go over the user's request step by step and ensure each step is represented as a clear subtask in your plan.
Output a list of jsons in the following format
```json
{{
"plan":
[
{{
"instructions": str # what you should do in this task, one short phrase or sentence
"instructions": str # what you should do in this task associated with a tool
}}
]
}}
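For illustration, this is the kind of plan the updated prompt asks the model to produce, once parsed by `extract_json` (the task wording is invented, not taken from the diff):

```python
# Invented example of a parsed plan matching the updated prompt: each subtask maps
# one step of the user's request to a tool-oriented instruction.
plan = {
    "plan": [
        {"instructions": "Load the user's image"},
        {"instructions": "Run the ocr tool on the image to read all visible text"},
        {"instructions": "Filter the OCR results down to the totals the user asked about"},
    ]
}
assert all("instructions" in step for step in plan["plan"])
```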
7 changes: 4 additions & 3 deletions vision_agent/tools/tools.py
@@ -199,14 +199,15 @@ def extract_frames(

def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
    """'ocr' extracts text from an image. It returns a list of detected text, bounding
    boxes, and confidence scores. The results are sorted from top-left to bottom right
    boxes with normalized coordinates, and confidence scores. The results are sorted
    from top-left to bottom right.
    Parameters:
        image (np.ndarray): The image to extract text from.
    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox,
            and confidence score.
        List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox
            with normalized coordinates, and confidence score.
    Example
    -------
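A hedged usage sketch for the updated `ocr` docstring follows; the dictionary keys (`"label"`, `"bbox"`, `"score"`) and the `[x1, y1, x2, y2]` bbox layout are assumptions, since the diff only states that the coordinates are normalized:

```python
# Hedged sketch only: key names and bbox layout are assumptions not shown in the diff.
import numpy as np
from PIL import Image

import vision_agent.tools as T  # assumed home of ocr(), matching the `T` alias above

image = np.array(Image.open("receipt.png"))  # hypothetical file
h, w = image.shape[:2]
for det in T.ocr(image):
    text, bbox, score = det["label"], det["bbox"], det["score"]
    # Per the updated docstring the bbox is normalized, so scale back to pixels.
    x1, y1, x2, y2 = bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h
    print(f"{text!r} at ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}), score={score:.2f}")
```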
