diff --git a/vision_agent/agent/agent.py b/vision_agent/agent/agent.py index 6b11f297..ca2cf181 100644 --- a/vision_agent/agent/agent.py +++ b/vision_agent/agent/agent.py @@ -11,7 +11,7 @@ def __call__( self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None, - ) -> str: + ) -> Union[str, List[Message]]: pass @abstractmethod diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 17fe347d..6399016e 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -38,7 +38,7 @@ class BoilerplateCode: ] @staticmethod - def add_boilerplate(code: str, **format) -> str: + def add_boilerplate(code: str, **format: Any) -> str: """Run this method to prepend the default imports to the code. NOTE: be sure to run this method after the custom tools have been registered. """ @@ -131,10 +131,13 @@ def __init__( self.code_sandbox_runtime = code_sandbox_runtime if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) - self.local_artifacts_path = ( - Path(local_artifacts_path) - if local_artifacts_path is not None - else "artifacts.pkl" + self.local_artifacts_path = cast( + str, + ( + Path(local_artifacts_path) + if local_artifacts_path is not None + else "artifacts.pkl" + ), ) def __call__( @@ -160,7 +163,7 @@ def __call__( if media is not None: input[0]["media"] = [media] results = self.chat_with_code(input, artifacts) - return results # type: ignore + return results def chat_with_code( self, diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 7856bdb8..cc0711b6 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -725,7 +725,7 @@ def chat_with_workflow( else code_interpreter.upload_file(media) ) chat_i["content"] += f" Media name {media}" # type: ignore - media_list.append(media) + media_list.append(str(media)) int_chat = cast( List[Message], diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index bc3e3058..89c2dbdd 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -46,7 +46,7 @@ class Artifacts: def __init__(self, remote_save_path: Union[str, Path]) -> None: self.remote_save_path = Path(remote_save_path) - self.artifacts = {} + self.artifacts: Dict[str, Any] = {} self.code_sandbox_runtime = None @@ -81,10 +81,10 @@ def save(self, local_path: Optional[Union[str, Path]] = None) -> None: with open(save_path, "wb") as f: pkl.dump(self.artifacts, f) - def __iter__(self): + def __iter__(self) -> Any: return iter(self.artifacts) - def __getitem__(self, name: str) -> str: + def __getitem__(self, name: str) -> Any: return self.artifacts[name] def __setitem__(self, name: str, value: str) -> None: @@ -201,7 +201,7 @@ def edit_artifact( cur_line = start + len(content.split("\n")) // 2 with tempfile.NamedTemporaryFile(delete=True) as f: - with open(f.name, "w") as f: + with open(f.name, "w") as f: # type: ignore f.writelines(edited_lines) process = subprocess.Popen( diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 08924875..15e4f9b9 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -410,7 +410,7 @@ def upload_file(self, file: Union[str, Path]) -> Path: # Default behavior is a no-op (for local code interpreter) return Path(file) - def download_file(self, remote_file_path: str, local_file_path: str) -> Path: + def download_file(self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]) -> Path: # Default behavior is a no-op (for local code interpreter) return Path(local_file_path) @@ -528,9 +528,9 @@ def upload_file(self, file: Union[str, Path]) -> Path: _LOGGER.info(f"File ({file}) is uploaded to: {str(self.remote_path)}") return self.remote_path - def download_file(self, remote_file_path: str, local_file_path: str) -> Path: + def download_file(self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]) -> Path: with open(local_file_path, "w+b") as f: - f.write(self.interpreter.files.read(path=remote_file_path, format="bytes")) + f.write(self.interpreter.files.read(path=str(remote_file_path), format="bytes")) _LOGGER.info(f"File ({remote_file_path}) is downloaded to: {local_file_path}") return Path(local_file_path) @@ -616,7 +616,7 @@ def exec_cell(self, code: str) -> Execution: traceback_raw = traceback.format_exc().splitlines() return Execution.from_exception(e, traceback_raw) - def upload_file(self, file_path: str) -> Path: + def upload_file(self, file_path: Union[str, Path]) -> Path: with open(file_path, "rb") as f: contents = f.read() with open(self.remote_path / Path(file_path).name, "wb") as f: @@ -625,7 +625,7 @@ def upload_file(self, file_path: str) -> Path: return Path(self.remote_path / file_path) - def download_file(self, remote_file_path: str, local_file_path: str) -> Path: + def download_file(self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]) -> Path: with open(self.remote_path / remote_file_path, "rb") as f: contents = f.read() with open(local_file_path, "wb") as f: