Skip to content

Commit

Permalink
add another prompt example, reformat to reduce complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 15, 2024
1 parent d92e192 commit 250bbaa
Showing 1 changed file with 48 additions and 27 deletions.
75 changes: 48 additions & 27 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
EXAMPLES_CODE1,
EXAMPLES_CODE2,
EXAMPLES_CODE3,
EXAMPLES_CODE3_EXTRA2,
VA_CODE,
)
from vision_agent.lmm import LMM, AnthropicLMM, Message, OpenAILMM
Expand Down Expand Up @@ -110,7 +111,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:

prompt = VA_CODE.format(
documentation=META_TOOL_DOCSTRING,
examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}\n{EXAMPLES_CODE3}",
examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}\n{EXAMPLES_CODE3}\n{EXAMPLES_CODE3_EXTRA2}",
conversation=conversation,
)
message: Message = {"role": "user", "content": prompt}
Expand Down Expand Up @@ -182,10 +183,46 @@ def execute_user_code_action(
)
if user_result.error:
user_obs += f"\n{user_result.error}"
extract_and_save_files_to_artifacts(artifacts, user_code_action, user_obs)
extract_and_save_files_to_artifacts(
artifacts, user_code_action, user_obs, user_result
)
return user_result, user_obs


def _add_media_obs(
    code_action: str,
    artifacts: Artifacts,
    result: Execution,
    obs: str,
    code_interpreter: CodeInterpreter,
    remote_artifacts_path: Path,
    local_artifacts_path: Path,
) -> Dict[str, Any]:
    """Build the observation chat message, attaching viewable media if any.

    The returned message always carries ``role="observation"`` and
    ``content=obs``. When ``check_and_load_image(code_action)`` reports media
    and ``result.success`` is truthy, the artifacts file is downloaded through
    ``code_interpreter`` and reloaded into ``artifacts``; every reported media
    entry that is present in ``artifacts.artifacts`` is then listed under the
    ``"media"`` key as a path inside ``local_artifacts_path.parent``.

    Args:
        code_action: The executed code action, scanned for media references.
        artifacts: Artifact store that is reloaded from the local copy.
        result: Execution result; media is only attached on success.
        obs: Observation text used as the message content.
        code_interpreter: Interpreter used to download the artifacts file.
        remote_artifacts_path: Remote artifacts path; only its file name is
            passed to ``download_file``.
        local_artifacts_path: Local destination of the artifacts file.

    Returns:
        The observation message, with a ``"media"`` entry only when at least
        one reported media item exists in the artifacts.
    """
    message: Message = {"role": "observation", "content": obs}
    # NOTE(review): presumably the media names referenced by the code action
    # -- confirm against check_and_load_image.
    requested_media = check_and_load_image(code_action)

    # Nothing to attach unless media was requested and the run succeeded.
    if not requested_media or not result.success:
        return message

    # for view_media_artifact, the media has to be available locally so the
    # conversation agent can actually see it
    remote_name = str(remote_artifacts_path.name)
    code_interpreter.download_file(remote_name, str(local_artifacts_path))

    local_dir = local_artifacts_path.parent
    artifacts.load(local_artifacts_path, local_dir)

    # keep only the media that really ended up in the artifacts
    available_media = [
        local_dir / media_name
        for media_name in requested_media
        if media_name in artifacts.artifacts
    ]
    if available_media:
        message["media"] = available_media

    return message


def add_step_descriptions(response: Dict[str, Any]) -> Dict[str, Any]:
response = copy.deepcopy(response)

Expand Down Expand Up @@ -544,35 +581,19 @@ def chat_with_artifacts(
code_interpreter,
str(remote_artifacts_path),
)

media_obs = check_and_load_image(code_action)
obs_chat_elt = _add_media_obs(
code_action,
artifacts,
result,
obs,
code_interpreter,
Path(remote_artifacts_path),
Path(self.local_artifacts_path),
)

if self.verbosity >= 1:
_LOGGER.info(obs)

obs_chat_elt: Message = {"role": "observation", "content": obs}
if media_obs and result.success:
# for view_media_artifact, we need to ensure the media is loaded
# locally so the conversation agent can actually see it
code_interpreter.download_file(
str(remote_artifacts_path.name),
str(self.local_artifacts_path),
)
artifacts.load(
self.local_artifacts_path,
Path(self.local_artifacts_path).parent,
)

# check if the media is actually in the artifacts
media_obs_chat = []
for media_ob in media_obs:
if media_ob in artifacts.artifacts:
media_obs_chat.append(
Path(self.local_artifacts_path).parent / media_ob
)
if len(media_obs_chat) > 0:
obs_chat_elt["media"] = media_obs_chat

# don't add execution results to internal chat
int_chat.append(obs_chat_elt)
obs_chat_elt["execution"] = result
Expand Down

0 comments on commit 250bbaa

Please sign in to comment.