Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verbosity levels & improved JSON parsing #83

Merged
merged 2 commits into from
May 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions vision_agent/agent/vision_agent_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@ def extract_code(code: str) -> str:
return code


def extract_json(json_str: str) -> Dict[str, Any]:
    """Parse a JSON object out of an LLM response string.

    The string is first parsed as-is. If that fails, the function looks for a
    Markdown code fence and parses the text inside it instead:

    * a fence opened with ```` ```json ```` is closed by the next ```` ``` ````;
    * a bare ```` ``` ```` fence is closed by the *last* ```` ``` ```` in the
      text, so backticks that appear inside a JSON string value do not
      truncate the payload.

    If the closing fence is missing, everything after the opening fence is
    parsed.

    Args:
        json_str: Raw text that is expected to contain a JSON document,
            optionally wrapped in a Markdown code fence.

    Returns:
        The decoded JSON value (expected to be a dictionary).

    Raises:
        json.JSONDecodeError: If neither the raw string nor the fenced
            content is valid JSON.
    """
    fence = "```"
    json_fence = "```json"
    try:
        json_dict = json.loads(json_str)
    except json.JSONDecodeError:
        if json_fence in json_str:
            json_str = json_str[json_str.find(json_fence) + len(json_fence) :]
            end = json_str.find(fence)
            # str.find returns -1 when the fence is never closed; slicing with
            # -1 would silently drop the last character, so only slice on a hit.
            if end != -1:
                json_str = json_str[:end]
        elif fence in json_str:
            json_str = json_str[json_str.find(fence) + len(fence) :]
            # Use the last ``` rather than one from an intermediate string.
            # Slicing at the fence itself keeps the closing brace of the JSON
            # object, and json.loads ignores the surrounding whitespace.
            end = json_str.rfind(fence)
            if end != -1:
                json_str = json_str[:end]
        json_dict = json.loads(json_str)
    return json_dict  # type: ignore


def write_plan(
chat: List[Dict[str, str]],
plan: Optional[List[Dict[str, Any]]],
Expand All @@ -65,8 +80,8 @@ def write_plan(
context = USER_REQ_CONTEXT.format(user_requirement=user_requirements)
prompt = PLAN.format(context=context, plan=str(plan), tool_desc=tool_desc)
chat[-1]["content"] = prompt
plan = json.loads(model.chat(chat).replace("```", "").strip())
return plan["user_req"], plan["plan"] # type: ignore
new_plan = extract_json(model.chat(chat))
return new_plan["user_req"], new_plan["plan"]


def write_code(
Expand Down Expand Up @@ -133,7 +148,7 @@ def debug_code(
{"role": "system", "content": DEBUG_SYS_MSG},
{"role": "user", "content": prompt},
]
code_and_ref = json.loads(model.chat(messages).replace("```", "").strip())
code_and_ref = extract_json(model.chat(messages))
if hasattr(model, "kwargs"):
del model.kwargs["response_format"]
return extract_code(code_and_ref["improved_impl"]), code_and_ref["reflection"]
Expand All @@ -149,7 +164,7 @@ def write_and_exec_code(
exec: Execute,
retrieved_ltm: str,
max_retry: int = 3,
verbose: bool = False,
verbosity: int = 0,
) -> Tuple[bool, str, str, Dict[str, List[str]]]:
success = False
counter = 0
Expand All @@ -159,6 +174,9 @@ def write_and_exec_code(
user_req, subtask, retrieved_ltm, tool_info, orig_code, model
)
success, result = exec.run_isolation(code)
if verbosity == 2:
_CONSOLE.print(Syntax(code, "python", theme="gruvbox-dark", line_numbers=True))
_LOGGER.info(f"\tCode success: {success}, result: {str(result)}")
working_memory: Dict[str, List[str]] = {}
while not success and counter < max_retry:
if subtask not in working_memory:
Expand All @@ -180,11 +198,11 @@ def write_and_exec_code(
)
success, result = exec.run_isolation(code)
counter += 1
if verbose:
if verbosity == 2:
_CONSOLE.print(
Syntax(code, "python", theme="gruvbox-dark", line_numbers=True)
)
_LOGGER.info(f"\tDebugging reflection, result: {reflection}, {result}")
_LOGGER.info(f"\tDebugging reflection: {reflection}, result: {result}")

if success:
working_memory[subtask].append(
Expand All @@ -204,7 +222,7 @@ def run_plan(
code: str,
tool_recommender: Sim,
long_term_memory: Optional[Sim] = None,
verbose: bool = False,
verbosity: int = 0,
) -> Tuple[str, str, List[Dict[str, Any]], Dict[str, List[str]]]:
active_plan = [e for e in plan if "success" not in e or not e["success"]]
current_code = code
Expand Down Expand Up @@ -235,7 +253,7 @@ def run_plan(
tool_info,
exec,
retrieved_ltm,
verbose=verbose,
verbosity=verbosity,
)
if task["type"] == "code":
current_code = code
Expand All @@ -244,11 +262,11 @@ def run_plan(

working_memory.update(working_memory_i)

if verbose:
if verbosity == 1:
_CONSOLE.print(
Syntax(code, "python", theme="gruvbox-dark", line_numbers=True)
)
_LOGGER.info(f"\tCode success, result: {success}, {str(result)}")
_LOGGER.info(f"\tCode success: {success} result: {str(result)}")

task["success"] = success
task["result"] = result
Expand Down Expand Up @@ -283,23 +301,23 @@ def __init__(
timeout: int = 600,
tool_recommender: Optional[Sim] = None,
long_term_memory: Optional[Sim] = None,
verbose: bool = False,
verbosity: int = 0,
) -> None:
self.planner = OpenAILLM(temperature=0.1, json_mode=True)
self.coder = OpenAILLM(temperature=0.1)
self.planner = OpenAILLM(temperature=0.0, json_mode=True)
self.coder = OpenAILLM(temperature=0.0)
self.exec = Execute(timeout=timeout)
if tool_recommender is None:
self.tool_recommender = Sim(TOOLS_DF, sim_key="desc")
else:
self.tool_recommender = tool_recommender
self.verbose = verbose
self.verbosity = verbosity
self._working_memory: Dict[str, List[str]] = {}
if long_term_memory is not None:
if "doc" not in long_term_memory.df.columns:
raise ValueError("Long term memory must have a 'doc' column.")
self.long_term_memory = long_term_memory
self.max_retries = 3
if self.verbose:
if self.verbosity:
_LOGGER.setLevel(logging.INFO)

def __call__(
Expand Down Expand Up @@ -355,7 +373,7 @@ def chat_with_workflow(
working_code,
self.tool_recommender,
self.long_term_memory,
self.verbose,
self.verbosity,
)
success = all(task["success"] for task in plan)
working_memory.update(working_memory_i)
Expand Down
Loading