Skip to content

Commit

Permalink
added ClaudeVisionAgentCoder and fixed json parser
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 9, 2024
1 parent a7dd110 commit 04161d2
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 35 deletions.
48 changes: 30 additions & 18 deletions vision_agent/agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,48 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
if match:
json_str = match.group()
try:
# remove trailing comma
trailing_bracket_pattern = r",\s+\}"
json_str = re.sub(trailing_bracket_pattern, "}", json_str, flags=re.DOTALL)

json_dict = json.loads(json_str)
return json_dict # type: ignore
except json.JSONDecodeError:
return None
return None


def _find_markdown_json(json_str: str) -> str:
pattern = r"```json(.*?)```"
match = re.search(pattern, json_str, re.DOTALL)
if match:
return match.group(1).strip()
return json_str


def _strip_markdown_code(inp_str: str) -> str:
pattern = r"```python.*?```"
cleaned_str = re.sub(pattern, "", inp_str, flags=re.DOTALL)
return cleaned_str


def extract_json(json_str: str) -> Dict[str, Any]:
json_str = json_str.replace("\n", " ").strip()

try:
json_str = json_str.replace("\n", " ")
json_dict = json.loads(json_str)
return json.loads(json_str)
except json.JSONDecodeError:
if "```json" in json_str:
json_str = json_str[json_str.find("```json") + len("```json") :]
json_str = json_str[: json_str.find("```")]
elif "```" in json_str:
json_str = json_str[json_str.find("```") + len("```") :]
# get the last ``` not one from an intermediate string
json_str = json_str[: json_str.find("}```")]
try:
json_dict = json.loads(json_str)
except json.JSONDecodeError as e:
json_dict = _extract_sub_json(json_str)
if json_dict is not None:
return json_dict # type: ignore
error_msg = f"Could not extract JSON from the given str: {json_str}"
json_orig = json_str
json_str = _strip_markdown_code(json_str)
json_str = _find_markdown_json(json_str)
json_dict = _extract_sub_json(json_str)

if json_dict is None:
error_msg = f"Could not extract JSON from the given str: {json_orig}"
_LOGGER.exception(error_msg)
raise ValueError(error_msg) from e
raise ValueError(error_msg)

return json_dict # type: ignore
return json_dict # type: ignore


def extract_code(code: str) -> str:
Expand Down
89 changes: 72 additions & 17 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@
TEST_PLANS,
USER_REQ,
)
from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
from vision_agent.lmm import (
LMM,
AzureOpenAILMM,
ClaudeSonnetLMM,
Message,
OllamaLMM,
OpenAILMM,
)
from vision_agent.tools.meta_tools import get_diff
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
Expand Down Expand Up @@ -168,8 +175,8 @@ def pick_plan(
)
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
tool_output_str = tool_output.logs.stdout[0]
if len(tool_output.text().strip()) > 0:
tool_output_str = tool_output.text().strip()

if verbosity == 2:
_print_code("Initial code and tests:", code)
Expand Down Expand Up @@ -229,7 +236,7 @@ def pick_plan(

if verbosity == 2:
_print_code("Code and test after attempted fix:", code)
_LOGGER.info(f"Code execution result after attempt {count}")
_LOGGER.info(f"Code execution result after attempt {count + 1}")

count += 1

Expand Down Expand Up @@ -387,7 +394,6 @@ def write_and_test_code(
"code": DefaultImports.prepend_imports(code),
"payload": {
"test": test,
# "result": result.to_json(),
},
}
)
Expand Down Expand Up @@ -451,17 +457,32 @@ def debug_code(
count = 0
while not success and count < 3:
try:
fixed_code_and_test = extract_json(
debugger( # type: ignore
FIX_BUG.format(
code=code,
tests=test,
result="\n".join(result.text().splitlines()[-50:]),
feedback=format_memory(working_memory + new_working_memory),
),
stream=False,
)
# LLMs write worse code when it's in JSON, so we have it write JSON
# followed by code each wrapped in markdown blocks.
fixed_code_and_test_str = debugger( # type: ignore
FIX_BUG.format(
code=code,
tests=test,
result="\n".join(result.text().splitlines()[-50:]),
feedback=format_memory(working_memory + new_working_memory),
),
stream=False,
)
fixed_code_and_test_str = cast(str, fixed_code_and_test_str)
fixed_code_and_test = extract_json(fixed_code_and_test_str)
code = extract_code(fixed_code_and_test_str)
if (
"which_code" in fixed_code_and_test
and fixed_code_and_test["which_code"] == "test"
):
fixed_code_and_test["code"] = ""
fixed_code_and_test["test"] = code
else: # for everything else always assume it's updating code
fixed_code_and_test["code"] = code
fixed_code_and_test["test"] = ""
if "which_code" in fixed_code_and_test:
del fixed_code_and_test["which_code"]

success = True
except Exception as e:
_LOGGER.exception(f"Error while extracting JSON: {e}")
Expand All @@ -472,9 +493,9 @@ def debug_code(
old_test = test

if fixed_code_and_test["code"].strip() != "":
code = extract_code(fixed_code_and_test["code"])
code = fixed_code_and_test["code"]
if fixed_code_and_test["test"].strip() != "":
test = extract_code(fixed_code_and_test["test"])
test = fixed_code_and_test["test"]

new_working_memory.append(
{
Expand Down Expand Up @@ -876,6 +897,40 @@ def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None:
)


class ClaudeVisionAgentCoder(VisionAgentCoder):
def __init__(
self,
planner: Optional[LMM] = None,
coder: Optional[LMM] = None,
tester: Optional[LMM] = None,
debugger: Optional[LMM] = None,
tool_recommender: Optional[Sim] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
) -> None:
# NOTE: Claude doesn't have an official JSON mode
self.planner = ClaudeSonnetLMM(temperature=0.0) if planner is None else planner
self.coder = ClaudeSonnetLMM(temperature=0.0) if coder is None else coder
self.tester = ClaudeSonnetLMM(temperature=0.0) if tester is None else tester
self.debugger = (
ClaudeSonnetLMM(temperature=0.0) if debugger is None else debugger
)
self.verbosity = verbosity
if self.verbosity > 0:
_LOGGER.setLevel(logging.INFO)

# Anthropic does not offer any embedding models and instead recomends Voyage,
# we're using OpenAI's embedder for now.
self.tool_recommender = (
Sim(T.TOOLS_DF, sim_key="desc")
if tool_recommender is None
else tool_recommender
)
self.report_progress_callback = report_progress_callback
self.code_sandbox_runtime = code_sandbox_runtime


class OllamaVisionAgentCoder(VisionAgentCoder):
"""VisionAgentCoder that uses Ollama models for planning, coding, testing.
Expand Down

0 comments on commit 04161d2

Please sign in to comment.