Skip to content

Commit

Permalink
Fix missing typing module error and improve load_image() (#116)
Browse files Browse the repository at this point in the history
* Minor improvements

* Updates

* Update the final code

* Update the printed code

* Update the printed code

* Update the printed code

* Update return code
  • Loading branch information
humpydonkey authored Jun 5, 2024
1 parent d937ca0 commit daa4831
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 49 deletions.
110 changes: 64 additions & 46 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,25 @@
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80
_CONSOLE = Console()
_DEFAULT_IMPORT = "\n".join(T.__new_tools__) + "\n".join(
[


class DefaultImports:
"""Container for default imports used in the code execution."""

common_imports = [
"from typing import *",
]
)

@staticmethod
def to_code_string() -> str:
    """Render the default imports as newline-separated source text.

    The entries of ``common_imports`` come first, followed by the entries
    of ``T.__new_tools__``. The list is read at call time, so the result
    reflects whatever ``T.__new_tools__`` holds when this is invoked.
    """
    import_lines = DefaultImports.common_imports + T.__new_tools__
    return "\n".join(import_lines)

@staticmethod
def prepend_imports(code: str) -> str:
    """Return ``code`` with the default imports placed in front of it.

    The import header and ``code`` are separated by one blank line.

    NOTE: call this only after the custom tools have been registered; the
    header is built from ``T.__new_tools__`` at call time (presumably
    populated during tool registration -- confirm against the tools module).
    """
    header = DefaultImports.to_code_string()
    return f"{header}\n\n{code}"


def get_diff(before: str, after: str) -> str:
Expand Down Expand Up @@ -202,18 +216,20 @@ def write_and_test_code(
"type": "code",
"status": "running",
"payload": {
"code": code,
"code": DefaultImports.prepend_imports(code),
"test": test,
},
}
)
result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
result = code_interpreter.exec_isolation(
f"{DefaultImports.to_code_string()}\n{code}\n{test}"
)
log_progress(
{
"type": "code",
"status": "completed" if result.success else "failed",
"payload": {
"code": code,
"code": DefaultImports.prepend_imports(code),
"test": test,
"result": result.to_json(),
},
Expand Down Expand Up @@ -264,19 +280,21 @@ def write_and_test_code(
"type": "code",
"status": "running",
"payload": {
"code": code,
"code": DefaultImports.prepend_imports(code),
"test": test,
},
}
)

result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
result = code_interpreter.exec_isolation(
f"{DefaultImports.to_code_string()}\n{code}\n{test}"
)
log_progress(
{
"type": "code",
"status": "completed" if result.success else "failed",
"payload": {
"code": code,
"code": DefaultImports.prepend_imports(code),
"test": test,
"result": result.to_json(),
},
Expand Down Expand Up @@ -307,7 +325,14 @@ def write_and_test_code(
def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
_CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
_CONSOLE.print("=" * 30 + " Code " + "=" * 30)
_CONSOLE.print(Syntax(code, "python", theme="gruvbox-dark", line_numbers=True))
_CONSOLE.print(
Syntax(
DefaultImports.prepend_imports(code),
"python",
theme="gruvbox-dark",
line_numbers=True,
)
)
if test:
_CONSOLE.print("=" * 30 + " Test " + "=" * 30)
_CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True))
Expand Down Expand Up @@ -464,10 +489,6 @@ def chat_with_workflow(
if chat_i["role"] == "user":
chat_i["content"] += f" Image name {media}"

# re-grab custom tools
global _DEFAULT_IMPORT
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)

code = ""
test = ""
working_memory: List[Dict[str, str]] = []
Expand Down Expand Up @@ -531,38 +552,35 @@ def chat_with_workflow(
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

if self_reflection:
self.log_progress(
{
"type": "self_reflection",
"status": "started",
}
)
reflection = reflect(
chat,
FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
code,
self.planner,
)
if self.verbosity > 0:
_LOGGER.info(f"Reflection: {reflection}")
feedback = cast(str, reflection["feedback"])
success = cast(bool, reflection["success"])
self.log_progress(
{
"type": "self_reflection",
"status": "completed" if success else "failed",
"payload": reflection,
}
)
working_memory.append(
{"code": f"{code}\n{test}", "feedback": feedback}
)
else:
if not self_reflection:
break

self.log_progress(
{
"type": "self_reflection",
"status": "started",
}
)
reflection = reflect(
chat,
FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
code,
self.planner,
)
if self.verbosity > 0:
_LOGGER.info(f"Reflection: {reflection}")
feedback = cast(str, reflection["feedback"])
success = cast(bool, reflection["success"])
self.log_progress(
{
"type": "self_reflection",
"status": "completed" if success else "failed",
"payload": reflection,
}
)
working_memory.append({"code": f"{code}\n{test}", "feedback": feedback})
retries += 1

execution_result = cast(Execution, results["test_result"])
Expand All @@ -571,7 +589,7 @@ def chat_with_workflow(
"type": "final_code",
"status": "completed" if success else "failed",
"payload": {
"code": code,
"code": DefaultImports.prepend_imports(code),
"test": test,
"result": execution_result.to_json(),
},
Expand All @@ -586,7 +604,7 @@ def chat_with_workflow(
play_video(res.mp4)

return {
"code": code,
"code": DefaultImports.prepend_imports(code),
"test": test,
"test_result": execution_result,
"plan": plan,
Expand Down
8 changes: 5 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def extract_frames(
Returns:
List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
and the timestamp in seconds.
as a numpy array and the timestamp in seconds.
Example
-------
Expand Down Expand Up @@ -515,7 +515,7 @@ def default(self, obj: Any): # type: ignore


def load_image(image_path: str) -> np.ndarray:
"""'load_image' is a utility function that loads an image from the given path.
"""'load_image' is a utility function that loads an image from the given file path string.
Parameters:
image_path (str): The path to the image.
Expand All @@ -527,7 +527,9 @@ def load_image(image_path: str) -> np.ndarray:
-------
>>> load_image("path/to/image.jpg")
"""

# NOTE: sometimes the generated code passes in a NumPy array
if isinstance(image_path, np.ndarray):
return image_path
image = Image.open(image_path).convert("RGB")
return np.array(image)

Expand Down

0 comments on commit daa4831

Please sign in to comment.