Skip to content

Commit

Permalink
fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 29, 2024
1 parent 3e7cfd2 commit afc87c0
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion vision_agent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __call__(
self,
input: Union[str, List[Message]],
media: Optional[Union[str, Path]] = None,
) -> str:
) -> Union[str, List[Message]]:
pass

@abstractmethod
Expand Down
15 changes: 9 additions & 6 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class BoilerplateCode:
]

@staticmethod
def add_boilerplate(code: str, **format) -> str:
def add_boilerplate(code: str, **format: Any) -> str:
"""Run this method to prepend the default imports to the code.
NOTE: be sure to run this method after the custom tools have been registered.
"""
Expand Down Expand Up @@ -131,10 +131,13 @@ def __init__(
self.code_sandbox_runtime = code_sandbox_runtime
if self.verbosity >= 1:
_LOGGER.setLevel(logging.INFO)
self.local_artifacts_path = (
Path(local_artifacts_path)
if local_artifacts_path is not None
else "artifacts.pkl"
self.local_artifacts_path = cast(
str,
(
Path(local_artifacts_path)
if local_artifacts_path is not None
else "artifacts.pkl"
),
)

def __call__(
Expand All @@ -160,7 +163,7 @@ def __call__(
if media is not None:
input[0]["media"] = [media]
results = self.chat_with_code(input, artifacts)
return results # type: ignore
return results

def chat_with_code(
self,
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ def chat_with_workflow(
else code_interpreter.upload_file(media)
)
chat_i["content"] += f" Media name {media}" # type: ignore
media_list.append(media)
media_list.append(str(media))

int_chat = cast(
List[Message],
Expand Down
8 changes: 4 additions & 4 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Artifacts:

def __init__(self, remote_save_path: Union[str, Path]) -> None:
self.remote_save_path = Path(remote_save_path)
self.artifacts = {}
self.artifacts: Dict[str, Any] = {}

self.code_sandbox_runtime = None

Expand Down Expand Up @@ -81,10 +81,10 @@ def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
with open(save_path, "wb") as f:
pkl.dump(self.artifacts, f)

def __iter__(self):
def __iter__(self) -> Any:
return iter(self.artifacts)

def __getitem__(self, name: str) -> str:
def __getitem__(self, name: str) -> Any:
return self.artifacts[name]

def __setitem__(self, name: str, value: str) -> None:
Expand Down Expand Up @@ -201,7 +201,7 @@ def edit_artifact(

cur_line = start + len(content.split("\n")) // 2
with tempfile.NamedTemporaryFile(delete=True) as f:
with open(f.name, "w") as f:
with open(f.name, "w") as f: # type: ignore
f.writelines(edited_lines)

process = subprocess.Popen(
Expand Down
10 changes: 5 additions & 5 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def upload_file(self, file: Union[str, Path]) -> Path:
# Default behavior is a no-op (for local code interpreter)
return Path(file)

def download_file(self, remote_file_path: str, local_file_path: str) -> Path:
def download_file(self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]) -> Path:
# Default behavior is a no-op (for local code interpreter)
return Path(local_file_path)

Expand Down Expand Up @@ -528,9 +528,9 @@ def upload_file(self, file: Union[str, Path]) -> Path:
_LOGGER.info(f"File ({file}) is uploaded to: {str(self.remote_path)}")
return self.remote_path

def download_file(self, remote_file_path: str, local_file_path: str) -> Path:
def download_file(self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]) -> Path:
with open(local_file_path, "w+b") as f:
f.write(self.interpreter.files.read(path=remote_file_path, format="bytes"))
f.write(self.interpreter.files.read(path=str(remote_file_path), format="bytes"))
_LOGGER.info(f"File ({remote_file_path}) is downloaded to: {local_file_path}")
return Path(local_file_path)

Expand Down Expand Up @@ -616,7 +616,7 @@ def exec_cell(self, code: str) -> Execution:
traceback_raw = traceback.format_exc().splitlines()
return Execution.from_exception(e, traceback_raw)

def upload_file(self, file_path: str) -> Path:
def upload_file(self, file_path: Union[str, Path]) -> Path:
with open(file_path, "rb") as f:
contents = f.read()
with open(self.remote_path / Path(file_path).name, "wb") as f:
Expand All @@ -625,7 +625,7 @@ def upload_file(self, file_path: str) -> Path:

return Path(self.remote_path / file_path)

def download_file(self, remote_file_path: str, local_file_path: str) -> Path:
def download_file(self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]) -> Path:
with open(self.remote_path / remote_file_path, "rb") as f:
contents = f.read()
with open(local_file_path, "wb") as f:
Expand Down

0 comments on commit afc87c0

Please sign in to comment.