Skip to content

Commit

Permalink
fixed user exec obs
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 26, 2024
1 parent 2f05f18 commit 6ca3a56
Showing 1 changed file with 47 additions and 38 deletions.
85 changes: 47 additions & 38 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
return extract_json(orch([message], stream=False)) # type: ignore


def run_code_action(
def execute_code_action(
code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
) -> Tuple[Execution, str]:
result = code_interpreter.exec_isolation(
Expand Down Expand Up @@ -115,10 +115,33 @@ def parse_execution(
return code


def execute_user_code_action(
    last_user_message: Message,
    code_interpreter: CodeInterpreter,
    artifact_remote_path: str,
) -> Tuple[Optional[Execution], Optional[str]]:
    """Run the code action embedded in the last user message, if there is one.

    Parameters:
        last_user_message: The message to inspect. Only messages whose
            ``"role"`` is ``"user"`` are considered.
        code_interpreter: Interpreter forwarded to ``execute_code_action``.
        artifact_remote_path: Remote artifact path forwarded to
            ``execute_code_action``.

    Returns:
        ``(execution, observation)`` when ``parse_execution`` finds a code
        action in the message content, otherwise ``(None, None)``. When the
        execution reports an error, that error is appended to the observation
        on a new line.
    """
    # Anything that is not a user message carries no user code action.
    if last_user_message["role"] != "user":
        return None, None

    content = cast(str, last_user_message.get("content", ""))
    code = parse_execution(content, False)
    if code is None:
        return None, None

    execution, observation = execute_code_action(
        code, code_interpreter, artifact_remote_path
    )
    # Surface the execution error alongside the regular observation text.
    if execution.error:
        observation += f"\n{execution.error}"
    return execution, observation


class VisionAgent(Agent):
"""Vision Agent is an agent that can chat with the user and call tools or other
agents to generate code for it. Vision Agent uses python code to execute actions
for the user. Vision Agent is inspired by by OpenDev
for the user. Vision Agent is inspired by OpenDevin
https://github.com/OpenDevin/OpenDevin and CodeAct https://arxiv.org/abs/2402.01030
Example
Expand Down Expand Up @@ -278,9 +301,23 @@ def chat_with_code(
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.streaming_message({"role": "observation", "content": artifacts_loaded})

finished = self.execute_user_code_action(
last_user_message, code_interpreter, remote_artifacts_path
user_result, user_obs = execute_user_code_action(
last_user_message, code_interpreter, str(remote_artifacts_path)
)
finished = user_result is not None and user_obs is not None
if user_result is not None and user_obs is not None:
chat_elt: Message = {"role": "observation", "content": user_obs}
int_chat.append(chat_elt)
chat_elt["execution"] = user_result
orig_chat.append(chat_elt)
self.streaming_message(
{
"role": "observation",
"content": user_obs,
"execution": user_result,
"finished": finished,
}
)

while not finished and iterations < self.max_iterations:
response = run_conversation(self.agent, int_chat)
Expand Down Expand Up @@ -322,7 +359,7 @@ def chat_with_code(
)

if code_action is not None:
result, obs = run_code_action(
result, obs = execute_code_action(
code_action, code_interpreter, str(remote_artifacts_path)
)

Expand All @@ -331,17 +368,17 @@ def chat_with_code(
if self.verbosity >= 1:
_LOGGER.info(obs)

chat_elt: Message = {"role": "observation", "content": obs}
obs_chat_elt: Message = {"role": "observation", "content": obs}
if media_obs and result.success:
chat_elt["media"] = [
obs_chat_elt["media"] = [
Path(code_interpreter.remote_path) / media_ob
for media_ob in media_obs
]

# don't add execution results to internal chat
int_chat.append(chat_elt)
chat_elt["execution"] = result
orig_chat.append(chat_elt)
int_chat.append(obs_chat_elt)
obs_chat_elt["execution"] = result
orig_chat.append(obs_chat_elt)
self.streaming_message(
{
"role": "observation",
Expand All @@ -362,34 +399,6 @@ def chat_with_code(
artifacts.save()
return orig_chat, artifacts

def execute_user_code_action(
    self,
    last_user_message: Message,
    code_interpreter: CodeInterpreter,
    remote_artifacts_path: Path,
) -> bool:
    """Execute a code action found in the last user message and stream it.

    Parameters:
        last_user_message: The message to inspect. Only messages whose
            ``"role"`` is ``"user"`` are considered.
        code_interpreter: Interpreter forwarded to ``run_code_action``.
        remote_artifacts_path: Remote artifact path; converted to ``str``
            before being passed to ``run_code_action``.

    Returns:
        ``True`` when a code action was parsed from the message, executed and
        streamed as a finished observation; ``False`` otherwise.
    """
    if last_user_message["role"] != "user":
        return False

    content = cast(str, last_user_message.get("content", ""))
    code = parse_execution(content, False)
    if code is None:
        return False

    execution, observation = run_code_action(
        code, code_interpreter, str(remote_artifacts_path)
    )
    if self.verbosity >= 1:
        _LOGGER.info(observation)

    # The streamed observation is marked finished unconditionally.
    payload = {
        "role": "observation",
        "content": observation,
        "execution": execution,
        "finished": True,
    }
    self.streaming_message(payload)
    return True

def streaming_message(self, message: Dict[str, Any]) -> None:
if self.callback_message:
self.callback_message(message)
Expand Down

0 comments on commit 6ca3a56

Please sign in to comment.