From ec93b8832441a28547ca98c8736e676f62b4bc16 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 9 Aug 2024 08:01:54 -0700 Subject: [PATCH] fixed type errors --- vision_agent/agent/vision_agent.py | 2 +- vision_agent/agent/vision_agent_coder.py | 17 +++++++++-------- vision_agent/lmm/lmm.py | 18 ++++++++++-------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 9090b706..a39fe208 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -63,7 +63,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]: dir=WORKSPACE, conversation=conversation, ) - return extract_json(orch([{"role": "user", "content": prompt}])) + return extract_json(orch([{"role": "user", "content": prompt}], stream=False)) # type: ignore def run_code_action(code: str, code_interpreter: CodeInterpreter) -> str: diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 4ef6b07e..5a7f9a2e 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -129,7 +129,7 @@ def write_plans( context = USER_REQ.format(user_request=user_request) prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory) chat[-1]["content"] = prompt - return extract_json(model.chat(chat)) + return extract_json(model(chat, stream=False)) # type: ignore def pick_plan( @@ -160,7 +160,7 @@ def pick_plan( docstring=tool_info, plans=plan_str, previous_attempts="", media=media ) - code = extract_code(model(prompt)) + code = extract_code(model(prompt, stream=False)) # type: ignore log_progress( { "type": "log", @@ -211,7 +211,7 @@ def pick_plan( "code": DefaultImports.prepend_imports(code), } ) - code = extract_code(model(prompt)) + code = extract_code(model(prompt, stream=False)) # type: ignore tool_output = code_interpreter.exec_isolation( DefaultImports.prepend_imports(code) ) @@ -251,7 +251,7 @@ def pick_plan( tool_output=tool_output_str[:20_000], ) chat[-1]["content"] = prompt - best_plan = extract_json(model(chat)) + best_plan = extract_json(model(chat, stream=False)) # type: ignore if verbosity >= 1: _LOGGER.info(f"Best plan:\n{best_plan}") @@ -286,7 +286,7 @@ def write_code( feedback=feedback, ) chat[-1]["content"] = prompt - return extract_code(coder(chat)) + return extract_code(coder(chat, stream=False)) # type: ignore def write_test( @@ -310,7 +310,7 @@ def write_test( media=media, ) chat[-1]["content"] = prompt - return extract_code(tester(chat)) + return extract_code(tester(chat, stream=False)) # type: ignore def write_and_test_code( @@ -439,13 +439,14 @@ def debug_code( while not success and count < 3: try: fixed_code_and_test = extract_json( - debugger( + debugger( # type: ignore FIX_BUG.format( code=code, tests=test, result="\n".join(result.text().splitlines()[-50:]), feedback=format_memory(working_memory + new_working_memory), - ) + ), + stream=False, ) ) success = True diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index ba35829a..1f49f8ba 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -5,7 +5,7 @@ import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union, cast +from typing import Any, Callable, Dict, Iterator, List, Optional, Union, cast import anthropic import requests @@ -58,22 +58,24 @@ def encode_media(media: Union[str, Path]) -> str: class LMM(ABC): @abstractmethod def generate( - self, prompt: str, media: Optional[List[Union[str, Path]]] = None - ) -> str: + self, prompt: str, media: Optional[List[Union[str, Path]]] = None, **kwargs: Any + ) -> Union[str, Iterator[Optional[str]]]: pass @abstractmethod def chat( self, chat: List[Message], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: pass @abstractmethod def __call__( self, input: Union[str, List[Message]], - ) -> str: + **kwargs: Any, + ) -> Union[str, Iterator[Optional[str]]]: pass @@ -150,7 +152,7 @@ def chat( ) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: for chunk in response: - chunk_message = chunk.choices[0].delta.content + chunk_message = chunk.choices[0].delta.content # type: ignore yield chunk_message else: return cast(str, response.choices[0].message.content) @@ -189,8 +191,8 @@ def generate( ) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: for chunk in response: - chunk_message = chunk.choices[0].delta.content - yield chunk_message # type: ignore + chunk_message = chunk.choices[0].delta.content # type: ignore + yield chunk_message else: return cast(str, response.choices[0].message.content)