Skip to content

Commit

Permalink
feat: error handling for conversation agent (#246)
Browse files Browse the repository at this point in the history
* feat: error handling for conversation agent

* address comment

* fix the issue last user message is not legit

* fix lint
  • Loading branch information
wuyiqunLu authored Sep 24, 2024
1 parent a663a74 commit f80241b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 30 deletions.
92 changes: 62 additions & 30 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
if str(WORKSPACE) != "":
os.environ["PYTHONPATH"] = f"{WORKSPACE}:{os.getenv('PYTHONPATH', '')}"

STUCK_IN_LOOP_ERROR_MESSAGE = {
"name": "Error when running conversation agent",
"value": "Agent is stuck in conversation loop, exited",
"traceback_raw": [],
}


class BoilerplateCode:
pre_code = [
Expand Down Expand Up @@ -229,7 +235,7 @@ def chat_with_code(
) as code_interpreter:
orig_chat = copy.deepcopy(chat)
int_chat = copy.deepcopy(chat)
last_user_message_content = chat[-1].get("content")
last_user_message = chat[-1]
media_list = []
for chat_i in int_chat:
if "media" in chat_i:
Expand Down Expand Up @@ -278,32 +284,9 @@ def chat_with_code(
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.streaming_message({"role": "observation", "content": artifacts_loaded})

if int_chat[-1]["role"] == "user":
last_user_message_content = cast(str, int_chat[-1].get("content", ""))
user_code_action = parse_execution(last_user_message_content, False)
if user_code_action is not None:
user_result, user_obs = run_code_action(
user_code_action, code_interpreter, str(remote_artifacts_path)
)
if self.verbosity >= 1:
_LOGGER.info(user_obs)
int_chat.append({"role": "observation", "content": user_obs})
orig_chat.append(
{
"role": "observation",
"content": user_obs,
"execution": user_result,
}
)
self.streaming_message(
{
"role": "observation",
"content": user_obs,
"execution": user_result,
"finished": True,
}
)
finished = True
finished = self.execute_user_code_action(
last_user_message, code_interpreter, remote_artifacts_path
)

while not finished and iterations < self.max_iterations:
response = run_conversation(self.agent, int_chat)
Expand All @@ -316,17 +299,37 @@ def chat_with_code(
if last_response == response:
response["let_user_respond"] = True
self.streaming_message(
{"role": "assistant", "error": "Stuck in loop"}
{
"role": "assistant",
"content": "{}",
"error": STUCK_IN_LOOP_ERROR_MESSAGE,
}
)
else:
self.streaming_message({"role": "assistant", "content": response})

finished = response["let_user_respond"]

code_action = parse_execution(
response["response"], test_multi_plan, customized_tool_names
)

if last_response == response:
self.streaming_message(
{
"role": "assistant",
"content": "{}",
"error": STUCK_IN_LOOP_ERROR_MESSAGE,
"finished": finished and code_action is None,
}
)
else:
self.streaming_message(
{
"role": "assistant",
"content": response,
"finished": finished and code_action is None,
}
)

if code_action is not None:
result, obs = run_code_action(
code_action, code_interpreter, str(remote_artifacts_path)
Expand All @@ -353,6 +356,7 @@ def chat_with_code(
"role": "observation",
"content": obs,
"execution": result,
"finished": finished,
}
)

Expand All @@ -367,6 +371,34 @@ def chat_with_code(
artifacts.save()
return orig_chat, artifacts

def execute_user_code_action(
self,
last_user_message: Message,
code_interpreter: CodeInterpreter,
remote_artifacts_path: Path,
) -> bool:
if last_user_message["role"] != "user":
return False
user_code_action = parse_execution(
cast(str, last_user_message.get("content", "")), False
)
if user_code_action is not None:
user_result, user_obs = run_code_action(
user_code_action, code_interpreter, str(remote_artifacts_path)
)
if self.verbosity >= 1:
_LOGGER.info(user_obs)
self.streaming_message(
{
"role": "observation",
"content": user_obs,
"execution": user_result,
"finished": True,
}
)
return True
return False

def streaming_message(self, message: Dict[str, Any]) -> None:
if self.callback_message:
self.callback_message(message)
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def detect_dogs(image_path: str):

agent = va.agent.VisionAgentCoder()
if name not in artifacts:
print(f"[Artifact {name} does not exist]")
return f"[Artifact {name} does not exist]"

code = artifacts[name]
Expand Down

0 comments on commit f80241b

Please sign in to comment.