Skip to content

Commit

Permalink
Merge branch 'main' into full-claude-35-support
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird authored Sep 17, 2024
2 parents c86959e + a56927d commit 5a8f6a9
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "vision-agent"
version = "0.2.133"
version = "0.2.135"
description = "Toolset for Vision Agent"
authors = ["Landing AI <[email protected]>"]
readme = "README.md"
Expand Down
19 changes: 15 additions & 4 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
VA_CODE,
)
from vision_agent.lmm import LMM, AnthropicLMM, Message, OpenAILMM
from vision_agent.tools import META_TOOL_DOCSTRING, load_image, save_image
from vision_agent.tools import META_TOOL_DOCSTRING, extract_frames, load_image, save_image
from vision_agent.tools.meta_tools import (
Artifacts,
check_and_load_image,
Expand Down Expand Up @@ -235,9 +235,20 @@ def chat_with_code(
for media in chat_i["media"]:
if type(media) is str and media.startswith(("http", "https")):
# TODO: Ideally we should not call VA.tools here, we should come to revisit how to better support remote image later
file_path = Path(media).name
ndarray = load_image(media)
save_image(ndarray, file_path)
file_path = str(
Path(self.local_artifacts_path).parent
/ Path(media).name
)
if file_path.lower().endswith(
".mp4"
) or file_path.lower().endswith(".mov"):
video_frames = extract_frames(media)
save_video(
[frame for frame, _ in video_frames], file_path
)
else:
ndarray = load_image(media)
save_image(ndarray, file_path)
media = file_path
else:
media = cast(str, media)
Expand Down
33 changes: 19 additions & 14 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,27 @@ def redisplay_results(execution: Execution) -> None:
"""
for result in execution.results:
if result.text is not None:
display({MimeType.TEXT_PLAIN: result.text})
display({MimeType.TEXT_PLAIN: result.text}, raw=True)
if result.html is not None:
display({MimeType.TEXT_HTML: result.html})
display({MimeType.TEXT_HTML: result.html}, raw=True)
if result.markdown is not None:
display({MimeType.TEXT_MARKDOWN: result.markdown})
display({MimeType.TEXT_MARKDOWN: result.markdown}, raw=True)
if result.svg is not None:
display({MimeType.IMAGE_SVG: result.svg})
display({MimeType.IMAGE_SVG: result.svg}, raw=True)
if result.png is not None:
display({MimeType.IMAGE_PNG: result.png})
display({MimeType.IMAGE_PNG: result.png}, raw=True)
if result.jpeg is not None:
display({MimeType.IMAGE_JPEG: result.jpeg})
display({MimeType.IMAGE_JPEG: result.jpeg}, raw=True)
if result.mp4 is not None:
display({MimeType.VIDEO_MP4_B64: result.mp4})
display({MimeType.VIDEO_MP4_B64: result.mp4}, raw=True)
if result.latex is not None:
display({MimeType.TEXT_LATEX: result.latex})
display({MimeType.TEXT_LATEX: result.latex}, raw=True)
if result.json is not None:
display({MimeType.APPLICATION_JSON: result.json})
display({MimeType.APPLICATION_JSON: result.json}, raw=True)
if result.artifact_name is not None:
display({MimeType.TEXT_ARTIFACT_NAME: result.artifact_name}, raw=True)
if result.extra is not None:
display(result.extra)
display(result.extra, raw=True)


class Artifacts:
Expand Down Expand Up @@ -208,7 +210,7 @@ def create_code_artifact(artifacts: Artifacts, name: str) -> str:
return_str = f"[Artifact {name} created]"
print(return_str)

display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return return_str


Expand Down Expand Up @@ -292,7 +294,7 @@ def edit_code_artifact(

artifacts[name] = "".join(edited_lines)

display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return open_code_artifact(artifacts, name, cur_line)


Expand Down Expand Up @@ -348,7 +350,7 @@ def detect_dogs(image_path: str):
code_lines = code.splitlines(keepends=True)
total_lines = len(code_lines)

display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return view_lines(code_lines, 0, total_lines, name, total_lines)


Expand Down Expand Up @@ -413,7 +415,7 @@ def detect_dogs(image_path: str):
code_lines = code.splitlines(keepends=True)
total_lines = len(code_lines)

display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return view_lines(code_lines, 0, total_lines, name, total_lines)


Expand All @@ -427,6 +429,7 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
with open(local_path, "rb") as f:
media = f.read()
artifacts[Path(local_path).name] = media
display({MimeType.TEXT_ARTIFACT_NAME: Path(local_path).name}, raw=True)
return f"[Media {Path(local_path).name} saved]"


Expand Down Expand Up @@ -623,6 +626,8 @@ def use_object_detection_fine_tuning(

diff = get_diff_with_prompts(name, code, new_code)
print(diff)

display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return diff


Expand Down
22 changes: 21 additions & 1 deletion vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class MimeType(str, Enum):
TEXT_LATEX = "text/latex"
APPLICATION_JSON = "application/json"
APPLICATION_JAVASCRIPT = "application/javascript"
TEXT_ARTIFACT_NAME = "text/artifact/name"


class FileSerializer:
Expand Down Expand Up @@ -103,6 +104,7 @@ class Result:
latex: Optional[str] = None
json: Optional[Dict[str, Any]] = None
javascript: Optional[str] = None
artifact_name: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
"Extra data that can be included. Not part of the standard types."

Expand All @@ -127,6 +129,7 @@ def __init__(self, is_main_result: bool, data: Dict[str, Any]):
self.latex = data.pop(MimeType.TEXT_LATEX, None)
self.json = data.pop(MimeType.APPLICATION_JSON, None)
self.javascript = data.pop(MimeType.APPLICATION_JAVASCRIPT, None)
self.artifact_name = data.pop(MimeType.TEXT_ARTIFACT_NAME, None)
self.extra = data
# Only keeping the PNG representation if both PNG and JPEG are present
if self.png and self.jpeg:
Expand Down Expand Up @@ -204,6 +207,8 @@ def formats(self) -> Iterable[str]:
formats.append("javascript")
if self.mp4:
formats.append("mp4")
if self.artifact_name:
formats.append("artifact_name")
if self.extra:
formats.extend(iter(self.extra))
return formats
Expand Down Expand Up @@ -691,8 +696,9 @@ def new_instance(
if not code_sandbox_runtime:
code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local")
if code_sandbox_runtime == "e2b":
envs = _get_e2b_env()
instance: CodeInterpreter = E2BCodeInterpreter(
timeout=_SESSION_TIMEOUT, remote_path=remote_path
timeout=_SESSION_TIMEOUT, remote_path=remote_path, envs=envs
)
elif code_sandbox_runtime == "local":
instance = LocalCodeInterpreter(
Expand All @@ -705,6 +711,20 @@ def new_instance(
return instance


def _get_e2b_env() -> Union[Dict[str, str], None]:
openai_api_key = os.getenv("OPENAI_API_KEY", "")
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")
if openai_api_key or anthropic_api_key:
envs = {}
if openai_api_key:
envs["OPENAI_API_KEY"] = openai_api_key
if anthropic_api_key:
envs["ANTHROPIC_API_KEY"] = anthropic_api_key
else:
envs = None
return envs


def _parse_local_code_interpreter_outputs(outputs: List[Dict[str, Any]]) -> Execution:
"""Parse notebook cell outputs to Execution object. Output types:
https://nbformat.readthedocs.io/en/latest/format_description.html#code-cell-outputs
Expand Down

0 comments on commit 5a8f6a9

Please sign in to comment.