Skip to content

Commit

Permalink
fixed bug with edit vision code
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 26, 2024
1 parent c21ecfb commit e827bde
Showing 1 changed file with 39 additions and 16 deletions.
55 changes: 39 additions & 16 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
from IPython.display import display

import vision_agent as va
Expand All @@ -17,7 +18,8 @@
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
from vision_agent.utils.execute import Execution, MimeType
from vision_agent.utils.image_utils import convert_to_b64
from vision_agent.utils.image_utils import convert_to_b64, numpy_to_bytes
from vision_agent.utils.video import frames_to_bytes

# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

Expand Down Expand Up @@ -432,12 +434,14 @@ def detect_dogs(image_path: str):

# Append latest code to second to last message from assistant
fixed_chat_history: List[Message] = []
user_message = "Previous user requests:"
for i, chat in enumerate(chat_history):
if i == 0:
fixed_chat_history.append({"role": "user", "content": chat, "media": media})
elif i > 0 and i < len(chat_history) - 1:
fixed_chat_history.append({"role": "user", "content": chat})
elif i == len(chat_history) - 1:
if i < len(chat_history) - 1:
user_message += " " + chat
else:
fixed_chat_history.append(
{"role": "user", "content": user_message, "media": media}
)
fixed_chat_history.append({"role": "assistant", "content": code})
fixed_chat_history.append({"role": "user", "content": chat})

Expand Down Expand Up @@ -467,17 +471,34 @@ def detect_dogs(image_path: str):
return view_lines(code_lines, 0, total_lines, name, total_lines)


def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
def write_media_artifact(
artifacts: Artifacts,
name: str,
media: Union[str, np.ndarray, List[np.ndarray]],
fps: Optional[float] = None,
) -> str:
"""Writes a media file to the artifacts object.
Parameters:
artifacts (Artifacts): The artifacts object to save the media to.
local_path (str): The local path to the media file.
name (str): The name of the media artifact to save.
media (Union[str, np.ndarray, List[np.ndarray]]): The media to save, can either
be a file path, single image or list of frames for a video.
fps (Optional[float]): The frames per second if you are writing a video.
"""
with open(local_path, "rb") as f:
media = f.read()
artifacts[Path(local_path).name] = media
return f"[Media {Path(local_path).name} saved]"
if isinstance(media, str):
with open(media, "rb") as f:
media_bytes = f.read()
elif isinstance(media, list):
media_bytes = frames_to_bytes(media, fps=fps if fps is not None else 1.0)
elif isinstance(media, np.ndarray):
media_bytes = numpy_to_bytes(media)
else:
print(f"[Invalid media type {type(media)}]")
return f"[Invalid media type {type(media)}]"
artifacts[name] = media_bytes
print(f"[Media {name} saved]")
return f"[Media {name} saved]"


def list_artifacts(artifacts: Artifacts) -> str:
Expand All @@ -500,7 +521,8 @@ def check_and_load_image(code: str) -> List[str]:


def view_media_artifact(artifacts: Artifacts, name: str) -> str:
"""Views the image artifact with the given name.
"""Allows you to view the media artifact with the given name. This does not show
the media to the user, the user can already see all media saved in the artifacts.
Parameters:
artifacts (Artifacts): The artifacts object to show the image from.
Expand Down Expand Up @@ -684,9 +706,10 @@ def use_object_detection_fine_tuning(
pattern_with_fine_tune_id, replacer_with_fine_tune_id, new_code
)
else:
(pattern_without_fine_tune_id, replacer_without_fine_tune_id) = (
patterns_without_fine_tune_id[index]
)
(
pattern_without_fine_tune_id,
replacer_without_fine_tune_id,
) = patterns_without_fine_tune_id[index]
new_code = re.sub(
pattern_without_fine_tune_id, replacer_without_fine_tune_id, new_code
)
Expand Down

0 comments on commit e827bde

Please sign in to comment.