Skip to content

Commit

Permalink
Better code sandbox lifecycle management (#110)
Browse files Browse the repository at this point in the history
* Support local image visualization; Better code sandbox lifecycle management

* Fix lint errors

* Fix format

* Print logs from execution result

* Better format for execution outputs

* More consistent format for execution outputs
  • Loading branch information
humpydonkey authored Jun 4, 2024
1 parent 282a2ac commit 32e0156
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 135 deletions.
238 changes: 126 additions & 112 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
from vision_agent.utils.sim import Sim

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80
_EXECUTE = CodeInterpreterFactory.get_default_instance()
_CONSOLE = Console()
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)

Expand Down Expand Up @@ -122,6 +123,7 @@ def write_and_test_code(
coder: LLM,
tester: LLM,
debugger: LLM,
code_interpreter: CodeInterpreter,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
max_retries: int = 3,
Expand Down Expand Up @@ -158,7 +160,7 @@ def write_and_test_code(
},
}
)
result = _EXECUTE.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
log_progress(
{
"type": "code",
Expand All @@ -173,7 +175,7 @@ def write_and_test_code(
if verbosity == 2:
_print_code("Initial code and tests:", code, test)
_LOGGER.info(
f"Initial code execution result:\n{result.text(include_logs=False)}"
f"Initial code execution result:\n{result.text(include_logs=True)}"
)

count = 0
Expand Down Expand Up @@ -210,7 +212,7 @@ def write_and_test_code(
{"code": f"{code}\n{test}", "feedback": fixed_code_and_test["reflections"]}
)

result = _EXECUTE.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}")
log_progress(
{
"type": "code",
Expand All @@ -228,7 +230,7 @@ def write_and_test_code(
)
_print_code("Code and test after attempted fix:", code, test)
_LOGGER.info(
f"Code execution result after attempted fix: {result.text(include_logs=False)}"
f"Code execution result after attempted fix: {result.text(include_logs=True)}"
)
count += 1

Expand Down Expand Up @@ -377,6 +379,7 @@ def chat_with_workflow(
chat: List[Dict[str, str]],
media: Optional[Union[str, Path]] = None,
self_reflection: bool = False,
display_visualization: bool = False,
) -> Dict[str, Any]:
"""Chat with Vision Agent and return intermediate information regarding the task.
Expand All @@ -385,6 +388,7 @@ def chat_with_workflow(
[{"role": "user", "content": "describe your task here..."}].
media (Optional[Union[str, Path]]): The media file to be used in the task.
self_reflection (bool): Whether to reflect on the task and debug the code.
            display_visualization (bool): If True, it opens a new window locally to show the image(s) created by the visualization code (if there are any).
Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
Expand All @@ -394,127 +398,137 @@ def chat_with_workflow(
if not chat:
raise ValueError("Chat cannot be empty.")

if media is not None:
media = _EXECUTE.upload_file(media)
for chat_i in chat:
if chat_i["role"] == "user":
chat_i["content"] += f" Image name {media}"

# re-grab custom tools
global _DEFAULT_IMPORT
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)

code = ""
test = ""
working_memory: List[Dict[str, str]] = []
results = {"code": "", "test": "", "plan": []}
plan = []
success = False
retries = 0

while not success and retries < self.max_retries:
self.log_progress(
{
"type": "plans",
"status": "started",
}
)
plan_i = write_plan(
chat,
T.TOOL_DESCRIPTIONS,
format_memory(working_memory),
self.planner,
media=[media] if media else None,
)
plan_i_str = "\n-".join([e["instructions"] for e in plan_i])

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plan_i,
}
)
if self.verbosity >= 1:

_LOGGER.info(
f"""
{tabulate(tabular_data=plan_i, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)

tool_info = retrieve_tools(
plan_i,
self.tool_recommender,
self.log_progress,
self.verbosity,
)
results = write_and_test_code(
FULL_TASK.format(user_request=chat[0]["content"], subtasks=plan_i_str),
tool_info,
T.UTILITIES_DOCSTRING,
format_memory(working_memory),
self.coder,
self.tester,
self.debugger,
self.log_progress,
verbosity=self.verbosity,
input_media=media,
)
success = cast(bool, results["success"])
code = cast(str, results["code"])
test = cast(str, results["test"])
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

if self_reflection:
# NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
with CodeInterpreterFactory.new_instance() as code_interpreter:
if media is not None:
media = code_interpreter.upload_file(media)
for chat_i in chat:
if chat_i["role"] == "user":
chat_i["content"] += f" Image name {media}"

# re-grab custom tools
global _DEFAULT_IMPORT
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)

code = ""
test = ""
working_memory: List[Dict[str, str]] = []
results = {"code": "", "test": "", "plan": []}
plan = []
success = False
retries = 0

while not success and retries < self.max_retries:
self.log_progress(
{
"type": "self_reflection",
"type": "plans",
"status": "started",
}
)
reflection = reflect(
plan_i = write_plan(
chat,
FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
code,
T.TOOL_DESCRIPTIONS,
format_memory(working_memory),
self.planner,
media=[media] if media else None,
)
if self.verbosity > 0:
_LOGGER.info(f"Reflection: {reflection}")
feedback = cast(str, reflection["feedback"])
success = cast(bool, reflection["success"])
plan_i_str = "\n-".join([e["instructions"] for e in plan_i])

self.log_progress(
{
"type": "self_reflection",
"status": "completed" if success else "failed",
"payload": reflection,
"type": "plans",
"status": "completed",
"payload": plan_i,
}
)
working_memory.append({"code": f"{code}\n{test}", "feedback": feedback})

retries += 1
if self.verbosity >= 1:
_LOGGER.info(
f"\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_info = retrieve_tools(
plan_i,
self.tool_recommender,
self.log_progress,
self.verbosity,
)
results = write_and_test_code(
task=FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
tool_info=tool_info,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=format_memory(working_memory),
coder=self.coder,
tester=self.tester,
debugger=self.debugger,
code_interpreter=code_interpreter,
log_progress=self.log_progress,
verbosity=self.verbosity,
input_media=media,
)
success = cast(bool, results["success"])
code = cast(str, results["code"])
test = cast(str, results["test"])
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

if self_reflection:
self.log_progress(
{
"type": "self_reflection",
"status": "started",
}
)
reflection = reflect(
chat,
FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
code,
self.planner,
)
if self.verbosity > 0:
_LOGGER.info(f"Reflection: {reflection}")
feedback = cast(str, reflection["feedback"])
success = cast(bool, reflection["success"])
self.log_progress(
{
"type": "self_reflection",
"status": "completed" if success else "failed",
"payload": reflection,
}
)
working_memory.append(
{"code": f"{code}\n{test}", "feedback": feedback}
)

retries += 1

execution_result = cast(Execution, results["test_result"])
self.log_progress(
{
"type": "final_code",
"status": "completed" if success else "failed",
"payload": {
"code": code,
"test": test,
"result": execution_result.to_json(),
},
}
)

self.log_progress(
{
"type": "final_code",
"status": "completed" if success else "failed",
"payload": {
"code": code,
"test": test,
"result": cast(Execution, results["test_result"]).to_json(),
},
if display_visualization:
for res in execution_result.results:
if res.png:
b64_to_pil(res.png).show()
return {
"code": code,
"test": test,
"test_result": execution_result,
"plan": plan,
"working_memory": working_memory,
}
)

return {
"code": code,
"test": test,
"test_result": results["test_result"],
"plan": plan,
"working_memory": working_memory,
}

def log_progress(self, data: Dict[str, Any]) -> None:
if self.report_progress_callback is not None:
Expand Down
12 changes: 6 additions & 6 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def save_image(image: np.ndarray) -> str:
def overlay_bounding_boxes(
image: np.ndarray, bboxes: List[Dict[str, Any]]
) -> np.ndarray:
"""'display_bounding_boxes' is a utility function that displays bounding boxes on
"""'overlay_bounding_boxes' is a utility function that displays bounding boxes on
an image.
Parameters:
Expand All @@ -537,7 +537,7 @@ def overlay_bounding_boxes(
Example
-------
>>> image_with_bboxes = display_bounding_boxes(
>>> image_with_bboxes = overlay_bounding_boxes(
image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
)
"""
Expand Down Expand Up @@ -583,7 +583,7 @@ def overlay_bounding_boxes(
def overlay_segmentation_masks(
image: np.ndarray, masks: List[Dict[str, Any]]
) -> np.ndarray:
"""'display_segmentation_masks' is a utility function that displays segmentation
"""'overlay_segmentation_masks' is a utility function that displays segmentation
masks.
Parameters:
Expand All @@ -595,7 +595,7 @@ def overlay_segmentation_masks(
Example
-------
>>> image_with_masks = display_segmentation_masks(
>>> image_with_masks = overlay_segmentation_masks(
image,
[{
'score': 0.99,
Expand Down Expand Up @@ -633,7 +633,7 @@ def overlay_segmentation_masks(
def overlay_heat_map(
image: np.ndarray, heat_map: Dict[str, Any], alpha: float = 0.8
) -> np.ndarray:
"""'display_heat_map' is a utility function that displays a heat map on an image.
"""'overlay_heat_map' is a utility function that displays a heat map on an image.
Parameters:
image (np.ndarray): The image to display the heat map on.
Expand All @@ -646,7 +646,7 @@ def overlay_heat_map(
Example
-------
>>> image_with_heat_map = display_heat_map(
>>> image_with_heat_map = overlay_heat_map(
image,
{
'heat_map': array([[0, 0, 0, ..., 0, 0, 0],
Expand Down
Loading

0 comments on commit 32e0156

Please sign in to comment.