Skip to content

Commit

Permalink
reduced code complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 16, 2024
1 parent 8161c48 commit dea8756
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 82 deletions.
127 changes: 47 additions & 80 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import pickle as pkl
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
Expand Down Expand Up @@ -122,7 +123,9 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
and "media" in chat[-1]
and len(chat[-1]["media"]) > 0 # type: ignore
):
message["media"] = chat[-1]["media"]
media_obs = [media for media in chat[-1]["media"] if Path(media).exists()] # type: ignore
if len(media_obs) > 0:
message["media"] = media_obs # type: ignore
conv_resp = cast(str, orch([message], stream=False))

# clean the response first, if we are executing code, do not respond or end
Expand All @@ -146,10 +149,11 @@ def execute_code_action(
artifacts: Artifacts,
code: str,
code_interpreter: CodeInterpreter,
artifact_remote_path: str,
) -> Tuple[Execution, str]:
result = code_interpreter.exec_isolation(
BoilerplateCode.add_boilerplate(code, remote_path=artifact_remote_path)
BoilerplateCode.add_boilerplate(
code, remote_path=str(artifacts.remote_save_path)
)
)

obs = str(result.logs)
Expand All @@ -163,7 +167,6 @@ def execute_user_code_action(
artifacts: Artifacts,
last_user_message: Message,
code_interpreter: CodeInterpreter,
artifact_remote_path: str,
) -> Tuple[Optional[Execution], Optional[str]]:
user_result = None
user_obs = None
Expand All @@ -180,50 +183,28 @@ def execute_user_code_action(
if user_code_action is not None:
user_code_action = use_extra_vision_agent_args(user_code_action, False)
user_result, user_obs = execute_code_action(
artifacts, user_code_action, code_interpreter, artifact_remote_path
artifacts, user_code_action, code_interpreter
)
if user_result.error:
user_obs += f"\n{user_result.error}"
extract_and_save_files_to_artifacts(
artifacts, user_code_action, user_obs, user_result
)
return user_result, user_obs


def _add_media_obs(
code_action: str,
artifacts: Artifacts,
result: Execution,
obs: str,
code_interpreter: CodeInterpreter,
remote_artifacts_path: Path,
local_artifacts_path: Path,
) -> Dict[str, Any]:
obs_chat_elt: Message = {"role": "observation", "content": obs}
media_obs = check_and_load_image(code_action)
if media_obs and result.success:
# for view_media_artifact, we need to ensure the media is loaded locally so
# the conversation agent can actually see it. We also download it here so we
# can check if it contains the actual media (note this is in addition to
# downloading it per turn).
def download_and_merge_artifacts(
code_interpreter: CodeInterpreter, artifacts: Artifacts
) -> None:
with tempfile.TemporaryFile() as temp_file:
code_interpreter.download_file(
str(remote_artifacts_path.name),
str(local_artifacts_path),
)
artifacts.load(
local_artifacts_path,
local_artifacts_path.parent,
str(artifacts.remote_save_path),
str(temp_file),
)

# check if the media is actually in the artifacts
media_obs_chat = []
for media_ob in media_obs:
if media_ob in artifacts.artifacts:
media_obs_chat.append(local_artifacts_path.parent / media_ob)
if len(media_obs_chat) > 0:
obs_chat_elt["media"] = media_obs_chat

return obs_chat_elt
temp_file.seek(0)
with open(str(temp_file), "rb") as f:
remote_artifacts = pkl.load(f)
merged_artifacts = {**artifacts.artifacts, **remote_artifacts}
artifacts.artifacts = merged_artifacts
artifacts.save()
artifacts.load(artifacts.local_save_path, artifacts.local_save_path.parent)


def add_step_descriptions(response: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -354,21 +335,15 @@ def __init__(
self.callback_message = callback_message
if self.verbosity >= 1:
_LOGGER.setLevel(logging.INFO)
self.local_artifacts_path = cast(
str,
(
Path(local_artifacts_path)
if local_artifacts_path is not None
else Path(tempfile.NamedTemporaryFile(delete=False).name)
),
self.local_artifacts_path = (
Path(local_artifacts_path)
if local_artifacts_path is not None
else Path(tempfile.NamedTemporaryFile(delete=False).name)
)
self.remote_artifacts_path = cast(
str,
(
Path(remote_artifacts_path)
if remote_artifacts_path is not None
else Path(WORKSPACE / "artifacts.pkl")
),
self.remote_artifacts_path = (
Path(remote_artifacts_path)
if remote_artifacts_path is not None
else Path(WORKSPACE / "artifacts.pkl")
)

def __call__(
Expand Down Expand Up @@ -455,8 +430,15 @@ def chat_with_artifacts(
and not isinstance(self.code_interpreter, str)
else CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_interpreter,
remote_path=self.remote_artifacts_path.parent,
)
)

if code_interpreter.remote_path != self.remote_artifacts_path.parent:
raise ValueError(
f"Code interpreter remote path {code_interpreter.remote_path} does not match {self.remote_artifacts_path.parent}"
)

with code_interpreter:
orig_chat = copy.deepcopy(chat)
int_chat = copy.deepcopy(chat)
Expand Down Expand Up @@ -501,9 +483,7 @@ def chat_with_artifacts(
# Upload artifacts to remote location and show where they are going
# to be loaded to. The actual loading happens in BoilerplateCode as
# part of the pre_code.
remote_artifacts_path = code_interpreter.upload_file(
self.local_artifacts_path
)
code_interpreter.upload_file(self.local_artifacts_path)
artifacts_loaded = artifacts.show(code_interpreter.remote_path)
int_chat.append({"role": "observation", "content": artifacts_loaded})
orig_chat.append({"role": "observation", "content": artifacts_loaded})
Expand All @@ -513,7 +493,6 @@ def chat_with_artifacts(
artifacts,
last_user_message,
code_interpreter,
str(remote_artifacts_path),
)
finished = user_result is not None and user_obs is not None
if user_result is not None and user_obs is not None:
Expand All @@ -537,6 +516,11 @@ def chat_with_artifacts(
code_interpreter.upload_file(self.local_artifacts_path)

response = run_conversation(self.agent, int_chat)
code_action = use_extra_vision_agent_args(
response.get("execute_python", None),
test_multi_plan,
custom_tool_names,
)
if self.verbosity >= 1:
_LOGGER.info(response)
int_chat.append(
Expand All @@ -562,12 +546,6 @@ def chat_with_artifacts(

finished = response.get("let_user_respond", False)

code_action = response.get("execute_python", None)
if code_action is not None:
code_action = use_extra_vision_agent_args(
code_action, test_multi_plan, custom_tool_names
)

if last_response == response:
self.streaming_message(
{
Expand Down Expand Up @@ -597,17 +575,11 @@ def chat_with_artifacts(
artifacts,
code_action,
code_interpreter,
str(remote_artifacts_path),
)
obs_chat_elt = _add_media_obs(
code_action,
artifacts,
result,
obs,
code_interpreter,
Path(remote_artifacts_path),
Path(self.local_artifacts_path),
)
obs_chat_elt: Message = {"role": "observation", "content": obs}
media_obs = check_and_load_image(code_action)
if media_obs and result.success:
obs_chat_elt["media"] = media_obs

if self.verbosity >= 1:
_LOGGER.info(obs)
Expand All @@ -629,12 +601,7 @@ def chat_with_artifacts(
last_response = response

# after each turn, download the artifacts locally
code_interpreter.download_file(
str(remote_artifacts_path.name), str(self.local_artifacts_path)
)
artifacts.load(
self.local_artifacts_path, Path(self.local_artifacts_path).parent
)
download_and_merge_artifacts(code_interpreter, artifacts)

return orig_chat, artifacts

Expand Down
6 changes: 4 additions & 2 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,10 +662,10 @@ def get_diff_with_prompts(name: str, before: str, after: str) -> str:


def use_extra_vision_agent_args(
code: str,
code: Optional[str],
test_multi_plan: bool = True,
custom_tool_names: Optional[List[str]] = None,
) -> str:
) -> Optional[str]:
"""This is for forcing arguments passed by the user to VisionAgent into the
VisionAgentCoder call.
Expand All @@ -677,6 +677,8 @@ def use_extra_vision_agent_args(
Returns:
Optional[str]: The edited code, or None if no code was provided.
"""
if code is None:
return None
red = RedBaron(code)
for node in red:
# seems to always be atomtrailers not call type
Expand Down

0 comments on commit dea8756

Please sign in to comment.