Skip to content

Commit

Permalink
Support local image visualization; Better code sandbox lifecycle mana…
Browse files Browse the repository at this point in the history
…gement
  • Loading branch information
humpydonkey committed Jun 3, 2024
1 parent 282a2ac commit 8087d8a
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 128 deletions.
235 changes: 125 additions & 110 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
from vision_agent.utils.sim import Sim

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80
_EXECUTE = CodeInterpreterFactory.get_default_instance()
_CONSOLE = Console()
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)

Expand Down Expand Up @@ -122,6 +123,7 @@ def write_and_test_code(
coder: LLM,
tester: LLM,
debugger: LLM,
code_interpreter: CodeInterpreter,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
max_retries: int = 3,
Expand Down Expand Up @@ -158,7 +160,7 @@ def write_and_test_code(
},
}
)
result = _EXECUTE.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
log_progress(
{
"type": "code",
Expand All @@ -175,6 +177,7 @@ def write_and_test_code(
_LOGGER.info(
f"Initial code execution result:\n{result.text(include_logs=False)}"
)
breakpoint()

count = 0
new_working_memory = []
Expand Down Expand Up @@ -210,7 +213,7 @@ def write_and_test_code(
{"code": f"{code}\n{test}", "feedback": fixed_code_and_test["reflections"]}
)

result = _EXECUTE.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
log_progress(
{
"type": "code",
Expand Down Expand Up @@ -377,6 +380,7 @@ def chat_with_workflow(
chat: List[Dict[str, str]],
media: Optional[Union[str, Path]] = None,
self_reflection: bool = False,
display_visualization: bool = False,
) -> Dict[str, Any]:
"""Chat with Vision Agent and return intermediate information regarding the task.
Expand All @@ -385,6 +389,7 @@ def chat_with_workflow(
[{"role": "user", "content": "describe your task here..."}].
media (Optional[Union[str, Path]]): The media file to be used in the task.
self_reflection (bool): Whether to reflect on the task and debug the code.
display_visualization (bool): If True, opens a new window locally to show the image(s) created by the visualization code (if there are any).
Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
Expand All @@ -394,127 +399,137 @@ def chat_with_workflow(
if not chat:
raise ValueError("Chat cannot be empty.")

if media is not None:
media = _EXECUTE.upload_file(media)
for chat_i in chat:
if chat_i["role"] == "user":
chat_i["content"] += f" Image name {media}"

# re-grab custom tools
global _DEFAULT_IMPORT
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)

code = ""
test = ""
working_memory: List[Dict[str, str]] = []
results = {"code": "", "test": "", "plan": []}
plan = []
success = False
retries = 0

while not success and retries < self.max_retries:
self.log_progress(
{
"type": "plans",
"status": "started",
}
)
plan_i = write_plan(
chat,
T.TOOL_DESCRIPTIONS,
format_memory(working_memory),
self.planner,
media=[media] if media else None,
)
plan_i_str = "\n-".join([e["instructions"] for e in plan_i])

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plan_i,
}
)
if self.verbosity >= 1:

_LOGGER.info(
f"""
{tabulate(tabular_data=plan_i, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)

tool_info = retrieve_tools(
plan_i,
self.tool_recommender,
self.log_progress,
self.verbosity,
)
results = write_and_test_code(
FULL_TASK.format(user_request=chat[0]["content"], subtasks=plan_i_str),
tool_info,
T.UTILITIES_DOCSTRING,
format_memory(working_memory),
self.coder,
self.tester,
self.debugger,
self.log_progress,
verbosity=self.verbosity,
input_media=media,
)
success = cast(bool, results["success"])
code = cast(str, results["code"])
test = cast(str, results["test"])
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

if self_reflection:
# NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
with CodeInterpreterFactory.new_instance() as code_interpreter:
if media is not None:
media = code_interpreter.upload_file(media)
for chat_i in chat:
if chat_i["role"] == "user":
chat_i["content"] += f" Image name {media}"

# re-grab custom tools
global _DEFAULT_IMPORT
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)

code = ""
test = ""
working_memory: List[Dict[str, str]] = []
results = {"code": "", "test": "", "plan": []}
plan = []
success = False
retries = 0

while not success and retries < self.max_retries:
self.log_progress(
{
"type": "self_reflection",
"type": "plans",
"status": "started",
}
)
reflection = reflect(
plan_i = write_plan(
chat,
FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
code,
T.TOOL_DESCRIPTIONS,
format_memory(working_memory),
self.planner,
media=[media] if media else None,
)
if self.verbosity > 0:
_LOGGER.info(f"Reflection: {reflection}")
feedback = cast(str, reflection["feedback"])
success = cast(bool, reflection["success"])
plan_i_str = "\n-".join([e["instructions"] for e in plan_i])

self.log_progress(
{
"type": "self_reflection",
"status": "completed" if success else "failed",
"payload": reflection,
"type": "plans",
"status": "completed",
"payload": plan_i,
}
)
working_memory.append({"code": f"{code}\n{test}", "feedback": feedback})

retries += 1
if self.verbosity >= 1:
_LOGGER.info(
f"\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_info = retrieve_tools(
plan_i,
self.tool_recommender,
self.log_progress,
self.verbosity,
)
results = write_and_test_code(
task=FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
tool_info=tool_info,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=format_memory(working_memory),
coder=self.coder,
tester=self.tester,
debugger=self.debugger,
code_interpreter=code_interpreter,
log_progress=self.log_progress,
verbosity=self.verbosity,
input_media=media,
)
success = cast(bool, results["success"])
code = cast(str, results["code"])
test = cast(str, results["test"])
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

if self_reflection:
self.log_progress(
{
"type": "self_reflection",
"status": "started",
}
)
reflection = reflect(
chat,
FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
code,
self.planner,
)
if self.verbosity > 0:
_LOGGER.info(f"Reflection: {reflection}")
feedback = cast(str, reflection["feedback"])
success = cast(bool, reflection["success"])
self.log_progress(
{
"type": "self_reflection",
"status": "completed" if success else "failed",
"payload": reflection,
}
)
working_memory.append(
{"code": f"{code}\n{test}", "feedback": feedback}
)

retries += 1

execution_result = cast(Execution, results["test_result"])
self.log_progress(
{
"type": "final_code",
"status": "completed" if success else "failed",
"payload": {
"code": code,
"test": test,
"result": execution_result.to_json(),
},
}
)

self.log_progress(
{
"type": "final_code",
"status": "completed" if success else "failed",
"payload": {
"code": code,
"test": test,
"result": cast(Execution, results["test_result"]).to_json(),
},
if display_visualization:
for res in execution_result.results:
if res.png:
b64_to_pil(res.png).show()
return {
"code": code,
"test": test,
"test_result": execution_result,
"plan": plan,
"working_memory": working_memory,
}
)

return {
"code": code,
"test": test,
"test_result": results["test_result"],
"plan": plan,
"working_memory": working_memory,
}

def log_progress(self, data: Dict[str, Any]) -> None:
if self.report_progress_callback is not None:
Expand Down
Loading

0 comments on commit 8087d8a

Please sign in to comment.