Skip to content

Commit

Permalink
Minor Improvements to performance (#85)
Browse files Browse the repository at this point in the history
* added different verbosity levels, better json parsing

* fix typing error

* fixed issues with agent coder

* add save json

* add thresh to top k

* fix bug, add thresh for top k tools

* update prompts

* black and isort

* fix type errors

* added thresh doc
  • Loading branch information
dillonalaird authored May 16, 2024
1 parent 3630983 commit e64083d
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 12 deletions.
24 changes: 19 additions & 5 deletions vision_agent/agent/agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from pathlib import Path
from typing import Dict, List, Optional, Union

from rich.console import Console
from rich.syntax import Syntax

from vision_agent.agent import Agent
from vision_agent.agent.agent_coder_prompts import (
DEBUG,
Expand Down Expand Up @@ -40,6 +43,7 @@
logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)
_EXECUTE = Execute()
_CONSOLE = Console()


def write_tests(question: str, code: str, model: LLM) -> str:
Expand Down Expand Up @@ -103,7 +107,7 @@ def run_visual_tests(


def fix_bugs(code: str, tests: str, result: str, feedback: str, model: LLM) -> str:
prompt = FIX_BUG.format(completion=code, test_case=tests, result=result)
prompt = FIX_BUG.format(code=code, tests=tests, result=result, feedback=feedback)
completion = model(prompt)
return preprocess_data(completion)

Expand Down Expand Up @@ -139,7 +143,8 @@ def __init__(
else visual_tester_agent
)
self.max_turns = 3
if verbose:
self.verbose = verbose
if self.verbose:
_LOGGER.setLevel(logging.INFO)

def __call__(
Expand All @@ -164,9 +169,15 @@ def chat(
feedback = ""
for _ in range(self.max_turns):
code = write_program(question, feedback, self.coder_agent)
_LOGGER.info(f"code:\n{code}")
if self.verbose:
_CONSOLE.print(
Syntax(code, "python", theme="gruvbox-dark", line_numbers=True)
)
debug = write_debug(question, code, feedback, self.tester_agent)
_LOGGER.info(f"debug:\n{debug}")
if self.verbose:
_CONSOLE.print(
Syntax(debug, "python", theme="gruvbox-dark", line_numbers=True)
)
results = execute_tests(code, debug)
_LOGGER.info(
f"execution results: passed: {results['passed']}\n{results['result']}"
Expand All @@ -176,7 +187,10 @@ def chat(
code = fix_bugs(
code, debug, results["result"].strip(), feedback, self.coder_agent # type: ignore
)
_LOGGER.info(f"fixed code:\n{code}")
if self.verbose:
_CONSOLE.print(
Syntax(code, "python", theme="gruvbox-dark", line_numbers=True)
)
else:
# TODO: Sometimes it prints nothing, so we need to handle that case
# TODO: The visual agent reflection does not work very well, needs more testing
Expand Down
7 changes: 5 additions & 2 deletions vision_agent/agent/vision_agent_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def run_plan(
f"""
{tabulate(tabular_data=[task], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)
tools = tool_recommender.top_k(task["instruction"])
tools = tool_recommender.top_k(task["instruction"], thresh=0.3)
tool_info = "\n".join([e["doc"] for e in tools])

if verbosity == 2:
Expand Down Expand Up @@ -288,6 +288,7 @@ class VisionAgentV2(Agent):
solve vision tasks. It is inspired by MetaGPT's Data Interpreter
https://arxiv.org/abs/2402.18679. Vision Agent has several key features to help it
generate code:
- A planner to generate a plan of tasks to solve a user requirement. The planner
can output code tasks or test tasks, where test tasks are used to verify the code.
- Automatic debugging, if a task fails, the agent will attempt to debug the code
Expand Down Expand Up @@ -381,7 +382,9 @@ def chat_with_workflow(
self.long_term_memory,
self.verbosity,
)
success = all(task["success"] for task in plan)
success = all(
task["success"] if "success" in task else False for task in plan
)
working_memory.update(working_memory_i)

if not success:
Expand Down
7 changes: 4 additions & 3 deletions vision_agent/agent/vision_agent_v2_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# Task:
Based on the context and the tools you have available, write a plan of subtasks to achieve the user request that adhere to the following requirements:
- For each subtask, you should provide a short instruction on what to do. Ensure the subtasks are large enough to be meaningful, encompassing multiple lines of code.
- For each subtask, you should provide instructions on what to do. Write detailed subtasks, ensure they are large enough to be meaningful, encompassing multiple lines of code.
- You do not need to have the agent rewrite any tool functionality you already have, you should instead instruct it to utilize one or more of those tools in each subtask.
- You can have agents either write coding tasks, to code some functionality or testing tasks to test previous functionality.
- If a current plan exists, examine each item in the plan to determine if it was successful. If there was an item that failed, i.e. 'success': False, then you should rewrite that item and all subsequent items to ensure that the rewritten plan is successful.
Expand Down Expand Up @@ -73,9 +73,10 @@
{code}
# Constraints
- Write a function that accomplishes the 'User Requirement'. You are supplied code from a previous task under 'Previous Code', feel free to copy over that code into your own implementation if you need it.
- Always prioritize using pre-defined tools or code for the same functionality from 'Tool Info for Current Subtask'. You have access to all these tools through the `from vision_agent.tools.tools_v2 import *` import.
- Write a function that accomplishes the 'Current Subtask'. You are supplied code from a previous task under 'Previous Code', do not delete or change previous code unless it contains a bug or it is necessary to complete the 'Current Subtask'.
- Always prioritize using pre-defined tools or code for the same functionality from 'Tool Info' when working on 'Current Subtask'. You have access to all these tools through the `from vision_agent.tools.tools_v2 import *` import.
- You may recieve previous trials and errors under 'Previous Task', this is code, output and reflections from previous tasks. You can use these to avoid running in to the same issues when writing your code.
- Use the `save_json` function from `vision_agent.tools.tools_v2` to save your output as a json file.
- Write clean, readable, and well-documented code.
# Output
Expand Down
29 changes: 28 additions & 1 deletion vision_agent/tools/tools_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import io
import json
import logging
import tempfile
from importlib import resources
Expand Down Expand Up @@ -285,6 +286,31 @@ def closest_box_distance(box1: List[float], box2: List[float]) -> float:
# Utility and visualization functions


def save_json(data: Any, file_path: str) -> None:
"""'save_json' is a utility function that saves data as a JSON file. It is helpful
for saving data that contains NumPy arrays which are not JSON serializable.
Parameters:
data (Any): The data to save.
file_path (str): The path to save the JSON file.
Example
-------
>>> save_json(data, "path/to/file.json")
"""

class NumpyEncoder(json.JSONEncoder):
def default(self, obj: Any): # type: ignore
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.bool_):
return bool(obj)
return json.JSONEncoder.default(self, obj)

with open(file_path, "w") as f:
json.dump(data, f, cls=NumpyEncoder)


def load_image(image_path: str) -> np.ndarray:
"""'load_image' is a utility function that loads an image from the given path.
Expand Down Expand Up @@ -480,6 +506,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
ocr,
closest_mask_distance,
closest_box_distance,
save_json,
load_image,
save_image,
overlay_bounding_boxes,
Expand All @@ -489,5 +516,5 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
UTILITIES_DOCSTRING = get_tool_documentation(
[load_image, save_image, overlay_bounding_boxes]
[save_json, load_image, save_image, overlay_bounding_boxes]
)
7 changes: 6 additions & 1 deletion vision_agent/utils/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,15 @@ def save(self, sim_file: Union[str, Path]) -> None:
df = df.drop("embs", axis=1)
df.to_csv(sim_file / "df.csv", index=False)

def top_k(self, query: str, k: int = 5) -> Sequence[Dict]:
def top_k(
self, query: str, k: int = 5, thresh: Optional[float] = None
) -> Sequence[Dict]:
"""Returns the top k most similar items to the query.
Parameters:
query: str: The query to compare to.
k: int: The number of items to return.
thresh: Optional[float]: The minimum similarity threshold.
Returns:
Sequence[Dict]: The top k most similar items.
Expand All @@ -70,6 +73,8 @@ def top_k(self, query: str, k: int = 5) -> Sequence[Dict]:
embedding = get_embedding(self.client, query, model=self.model)
self.df["sim"] = self.df.embs.apply(lambda x: 1 - cosine(x, embedding))
res = self.df.sort_values("sim", ascending=False).head(k)
if thresh is not None:
res = res[res.sim > thresh]
return res[[c for c in res.columns if c != "embs"]].to_dict(orient="records")


Expand Down

0 comments on commit e64083d

Please sign in to comment.