Skip to content

Commit

Permalink
fixed type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 9, 2024
1 parent f0a2b95 commit ec93b88
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
dir=WORKSPACE,
conversation=conversation,
)
return extract_json(orch([{"role": "user", "content": prompt}]))
return extract_json(orch([{"role": "user", "content": prompt}], stream=False)) # type: ignore


def run_code_action(code: str, code_interpreter: CodeInterpreter) -> str:
Expand Down
17 changes: 9 additions & 8 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def write_plans(
context = USER_REQ.format(user_request=user_request)
prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
chat[-1]["content"] = prompt
return extract_json(model.chat(chat))
return extract_json(model(chat, stream=False)) # type: ignore


def pick_plan(
Expand Down Expand Up @@ -160,7 +160,7 @@ def pick_plan(
docstring=tool_info, plans=plan_str, previous_attempts="", media=media
)

code = extract_code(model(prompt))
code = extract_code(model(prompt, stream=False)) # type: ignore
log_progress(
{
"type": "log",
Expand Down Expand Up @@ -211,7 +211,7 @@ def pick_plan(
"code": DefaultImports.prepend_imports(code),
}
)
code = extract_code(model(prompt))
code = extract_code(model(prompt, stream=False)) # type: ignore
tool_output = code_interpreter.exec_isolation(
DefaultImports.prepend_imports(code)
)
Expand Down Expand Up @@ -251,7 +251,7 @@ def pick_plan(
tool_output=tool_output_str[:20_000],
)
chat[-1]["content"] = prompt
best_plan = extract_json(model(chat))
best_plan = extract_json(model(chat, stream=False)) # type: ignore

if verbosity >= 1:
_LOGGER.info(f"Best plan:\n{best_plan}")
Expand Down Expand Up @@ -286,7 +286,7 @@ def write_code(
feedback=feedback,
)
chat[-1]["content"] = prompt
return extract_code(coder(chat))
return extract_code(coder(chat, stream=False)) # type: ignore


def write_test(
Expand All @@ -310,7 +310,7 @@ def write_test(
media=media,
)
chat[-1]["content"] = prompt
return extract_code(tester(chat))
return extract_code(tester(chat, stream=False)) # type: ignore


def write_and_test_code(
Expand Down Expand Up @@ -439,13 +439,14 @@ def debug_code(
while not success and count < 3:
try:
fixed_code_and_test = extract_json(
debugger(
debugger( # type: ignore
FIX_BUG.format(
code=code,
tests=test,
result="\n".join(result.text().splitlines()[-50:]),
feedback=format_memory(working_memory + new_working_memory),
)
),
stream=False,
)
)
success = True
Expand Down
18 changes: 10 additions & 8 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union, cast
from typing import Any, Callable, Dict, Iterator, List, Optional, Union, cast

import anthropic
import requests
Expand Down Expand Up @@ -58,22 +58,24 @@ def encode_media(media: Union[str, Path]) -> str:
class LMM(ABC):
@abstractmethod
def generate(
self, prompt: str, media: Optional[List[Union[str, Path]]] = None
) -> str:
self, prompt: str, media: Optional[List[Union[str, Path]]] = None, **kwargs: Any
) -> Union[str, Iterator[Optional[str]]]:
pass

@abstractmethod
def chat(
self,
chat: List[Message],
) -> str:
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
pass

@abstractmethod
def __call__(
self,
input: Union[str, List[Message]],
) -> str:
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
pass


Expand Down Expand Up @@ -150,7 +152,7 @@ def chat(
)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
for chunk in response:
chunk_message = chunk.choices[0].delta.content
chunk_message = chunk.choices[0].delta.content # type: ignore
yield chunk_message
else:
return cast(str, response.choices[0].message.content)
Expand Down Expand Up @@ -189,8 +191,8 @@ def generate(
)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
for chunk in response:
chunk_message = chunk.choices[0].delta.content
yield chunk_message # type: ignore
chunk_message = chunk.choices[0].delta.content # type: ignore
yield chunk_message
else:
return cast(str, response.choices[0].message.content)

Expand Down

0 comments on commit ec93b88

Please sign in to comment.