Skip to content

Commit

Permalink
feat: when code edited by code executed directly (#242)
Browse files Browse the repository at this point in the history
* feat: when code edited by code executed directly

* fix lint

* fix lint

* address comment
  • Loading branch information
wuyiqunLu authored Sep 22, 2024
1 parent bf87168 commit 2126b06
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
27 changes: 27 additions & 0 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def chat_with_code(
) as code_interpreter:
orig_chat = copy.deepcopy(chat)
int_chat = copy.deepcopy(chat)
last_user_message_content = chat[-1].get("content")
media_list = []
for chat_i in int_chat:
if "media" in chat_i:
Expand Down Expand Up @@ -266,6 +267,32 @@ def chat_with_code(
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.streaming_message({"role": "observation", "content": artifacts_loaded})

if isinstance(last_user_message_content, str):
user_code_action = parse_execution(last_user_message_content, False)
if user_code_action is not None:
user_result, user_obs = run_code_action(
user_code_action, code_interpreter, str(remote_artifacts_path)
)
if self.verbosity >= 1:
_LOGGER.info(user_obs)
int_chat.append({"role": "observation", "content": user_obs})
orig_chat.append(
{
"role": "observation",
"content": user_obs,
"execution": user_result,
}
)
self.streaming_message(
{
"role": "observation",
"content": user_obs,
"execution": user_result,
"finished": True,
}
)
finished = True

while not finished and iterations < self.max_iterations:
response = run_conversation(self.agent, int_chat)
if self.verbosity >= 1:
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def edit_code_artifact(

total_lines = len(artifacts[name].splitlines())
if start < 0 or end < 0 or start > end or end > total_lines:
print("[Invalid line range]")
return "[Invalid line range]"
if start == end:
end += 1
Expand Down

0 comments on commit 2126b06

Please sign in to comment.