Support streaming chat logs of an agent (#47)
Add a callback for reporting the chat progress of an agent
humpydonkey authored and dillonalaird committed Apr 11, 2024
1 parent c27118c commit e8ec297
Showing 3 changed files with 121 additions and 82 deletions.
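The new behavior is easiest to see from the caller's side. Below is a minimal usage sketch, not part of this commit: it assumes `VisionAgent` is exported from `vision_agent.agent` and uses a hypothetical image path; any `Callable[[str], None]` can serve as the callback.

```python
from vision_agent.agent import VisionAgent  # import path assumed

# Any Callable[[str], None] works as a sink: a list append, a queue's
# put method, a websocket send function, etc.
progress_lines: list[str] = []

agent = VisionAgent(report_progress_callback=progress_lines.append)

# While the agent decomposes tasks, runs tools, and reflects, every
# progress message is logged and also pushed to the callback.
answer = agent.chat(
    [{"role": "user", "content": "How many dogs are in this image?"}],
    image="dogs.jpg",  # hypothetical path
)
```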
7 changes: 7 additions & 0 deletions vision_agent/agent/agent.py
@@ -11,3 +11,10 @@ def __call__(
image: Optional[Union[str, Path]] = None,
) -> str:
pass

@abstractmethod
def log_progress(self, description: str) -> None:
"""Log the progress of the agent.
This is a hook that is intended for reporting the progress of the agent.
"""
pass
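Because `log_progress` is declared `@abstractmethod`, every `Agent` subclass must now implement it. A hypothetical minimal subclass (the `EchoAgent` name and the `input` parameter name are illustrative assumptions, not from this commit):

```python
from pathlib import Path
from typing import Dict, List, Optional, Union

from vision_agent.agent.agent import Agent  # module path assumed from this diff


class EchoAgent(Agent):
    """Toy agent that satisfies the updated abstract interface."""

    def __call__(
        self,
        input: Union[List[Dict[str, str]], str],
        image: Optional[Union[str, Path]] = None,
    ) -> str:
        self.log_progress("EchoAgent: received a request")
        return input if isinstance(input, str) else str(input)

    def log_progress(self, description: str) -> None:
        # Simplest possible implementation of the hook: print to stdout.
        print(description)
```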
194 changes: 113 additions & 81 deletions vision_agent/agent/vision_agent.py
@@ -244,79 +244,6 @@ def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any:
return str(e)


def retrieval(
model: Union[LLM, LMM, Agent],
question: str,
tools: Dict[int, Any],
previous_log: str,
reflections: str,
) -> Tuple[Dict, str]:
tool_id = choose_tool(
model, question, {k: v["description"] for k, v in tools.items()}, reflections
)
if tool_id is None:
return {}, ""

tool_instructions = tools[tool_id]
tool_usage = tool_instructions["usage"]
tool_name = tool_instructions["name"]

parameters = choose_parameter(
model, question, tool_usage, previous_log, reflections
)
if parameters is None:
return {}, ""
tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}

_LOGGER.info(
f"""Going to run the following tool(s) in sequence:
{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
)

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
call_results: List[Any] = []
if isinstance(result["parameters"], Dict):
call_results.append(
function_call(tools[tool_id]["class"], result["parameters"])
)
elif isinstance(result["parameters"], List):
for parameters in result["parameters"]:
call_results.append(function_call(tools[tool_id]["class"], parameters))
return call_results

call_results = parse_tool_results(tool_results)
tool_results["call_results"] = call_results

call_results_str = str(call_results)
# _LOGGER.info(f"\tCall Results: {call_results_str}")
return tool_results, call_results_str


def create_tasks(
task_model: Union[LLM, LMM], question: str, tools: Dict[int, Any], reflections: str
) -> List[Dict]:
tasks = task_decompose(
task_model,
question,
{k: v["description"] for k, v in tools.items()},
reflections,
)
if tasks is not None:
task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
task_list = task_topology(task_model, question, task_list)
try:
task_list = topological_sort(task_list)
except Exception:
_LOGGER.error(f"Failed topological_sort on: {task_list}")
else:
task_list = []
_LOGGER.info(
f"""Planned tasks:
{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
)
return task_list


def self_reflect(
reflect_model: Union[LLM, LMM],
question: str,
@@ -350,7 +277,7 @@ def parse_reflect(reflect: str) -> bool:
def visualize_result(all_tool_results: List[Dict]) -> List[str]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
- if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]:
+ if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
continue

parameters = tool_result["parameters"]
@@ -420,7 +347,18 @@ def __init__(
reflect_model: Optional[Union[LLM, LMM]] = None,
max_retries: int = 2,
verbose: bool = False,
report_progress_callback: Optional[Callable[[str], None]] = None,
):
"""VisionAgent constructor.
Parameters
task_model: the model to use for task decomposition.
answer_model: the model to use for reasoning and concluding the answer.
reflect_model: the model to use for self reflection.
max_retries: maximum number of retries to attempt to complete the task.
verbose: whether to print more logs.
report_progress_callback: a callback to report the progress of the agent. This is useful for streaming logs in a web application where multiple VisionAgent instances are running in parallel, as it ensures each agent's progress reports are not mixed up.
"""
self.task_model = (
OpenAILLM(json_mode=True, temperature=0.1)
if task_model is None
@@ -433,8 +371,8 @@ def __init__(
OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
)
self.max_retries = max_retries

self.tools = TOOLS
self.report_progress_callback = report_progress_callback
if verbose:
_LOGGER.setLevel(logging.INFO)

@@ -457,6 +395,11 @@ def __call__(
input = [{"role": "user", "content": input}]
return self.chat(input, image=image)

def log_progress(self, description: str) -> None:
_LOGGER.info(description)
if self.report_progress_callback:
self.report_progress_callback(description)

def chat_with_workflow(
self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
) -> Tuple[str, List[Dict]]:
@@ -469,7 +412,9 @@ def chat_with_workflow(
all_tool_results: List[Dict] = []

for _ in range(self.max_retries):
task_list = create_tasks(self.task_model, question, self.tools, reflections)
task_list = self.create_tasks(
self.task_model, question, self.tools, reflections
)

task_depend = {"Original Question": question}
previous_log = ""
@@ -481,7 +426,7 @@ def chat_with_workflow(
for task in task_list:
task_str = task["task"]
previous_log = str(task_depend)
- tool_results, call_results = retrieval(
+ tool_results, call_results = self.retrieval(
self.task_model,
task_str,
self.tools,
@@ -495,8 +440,8 @@ def chat_with_workflow(
tool_results["answer"] = answer
all_tool_results.append(tool_results)

_LOGGER.info(f"\tCall Result: {call_results}")
_LOGGER.info(f"\tAnswer: {answer}")
self.log_progress(f"\tCall Result: {call_results}")
self.log_progress(f"\tAnswer: {answer}")
answers.append({"task": task_str, "answer": answer})
task_depend[task["id"]]["answer"] = answer # type: ignore
task_depend[task["id"]]["call_result"] = call_results # type: ignore
@@ -514,16 +459,103 @@ def chat_with_workflow(
final_answer,
visualized_images[0] if len(visualized_images) > 0 else image,
)
_LOGGER.info(f"Reflection: {reflection}")
self.log_progress(f"Reflection: {reflection}")
if parse_reflect(reflection):
break
else:
reflections += reflection

# '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
self.log_progress(
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</<ANSWER>"
)
return final_answer, all_tool_results

def chat(
self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
) -> str:
answer, _ = self.chat_with_workflow(chat, image=image)
return answer

def retrieval(
self,
model: Union[LLM, LMM, Agent],
question: str,
tools: Dict[int, Any],
previous_log: str,
reflections: str,
) -> Tuple[Dict, str]:
tool_id = choose_tool(
model,
question,
{k: v["description"] for k, v in tools.items()},
reflections,
)
if tool_id is None:
return {}, ""

tool_instructions = tools[tool_id]
tool_usage = tool_instructions["usage"]
tool_name = tool_instructions["name"]

parameters = choose_parameter(
model, question, tool_usage, previous_log, reflections
)
if parameters is None:
return {}, ""
tool_results = {
"task": question,
"tool_name": tool_name,
"parameters": parameters,
}

self.log_progress(
f"""Going to run the following tool(s) in sequence:
{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
)

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
call_results: List[Any] = []
if isinstance(result["parameters"], Dict):
call_results.append(
function_call(tools[tool_id]["class"], result["parameters"])
)
elif isinstance(result["parameters"], List):
for parameters in result["parameters"]:
call_results.append(
function_call(tools[tool_id]["class"], parameters)
)
return call_results

call_results = parse_tool_results(tool_results)
tool_results["call_results"] = call_results

call_results_str = str(call_results)
return tool_results, call_results_str

def create_tasks(
self,
task_model: Union[LLM, LMM],
question: str,
tools: Dict[int, Any],
reflections: str,
) -> List[Dict]:
tasks = task_decompose(
task_model,
question,
{k: v["description"] for k, v in tools.items()},
reflections,
)
if tasks is not None:
task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
task_list = task_topology(task_model, question, task_list)
try:
task_list = topological_sort(task_list)
except Exception:
_LOGGER.error(f"Failed topological_sort on: {task_list}")
else:
task_list = []
self.log_progress(
f"""Planned tasks:
{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
)
return task_list
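Taken together with the `<ANSWER>` sentinel logged above, the callback makes it straightforward to stream logs from agents running in parallel. A sketch under the same import assumption; the queue-per-agent design and the image path are illustrative, not something this commit prescribes:

```python
import queue
import threading

from vision_agent.agent import VisionAgent  # import path assumed


def run(question: str, out: "queue.Queue[str]") -> None:
    # One queue per agent instance keeps parallel streams from mixing.
    agent = VisionAgent(report_progress_callback=out.put)
    agent.chat([{"role": "user", "content": question}], image="photo.jpg")


q: "queue.Queue[str]" = queue.Queue()
threading.Thread(target=run, args=("Count the dogs.", q), daemon=True).start()

while True:
    line = q.get()
    print(line)  # forward to a websocket, an SSE stream, etc.
    if "<ANSWER>" in line:  # sentinel emitted at the end of the chat
        break
```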
2 changes: 1 addition & 1 deletion vision_agent/image_utils.py
@@ -79,7 +79,7 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
data = Image.open(data)
if isinstance(data, Image.Image):
buffer = BytesIO()
- data.save(buffer, format="PNG")
+ data.convert("RGB").save(buffer, format="JPEG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
arr_bytes = data.tobytes()
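The `image_utils.py` change trades lossless PNG encoding for a smaller JPEG payload, and the added `convert("RGB")` matters because JPEG has no alpha channel, so saving an RGBA image directly would raise an error. A quick round-trip check, assuming `convert_to_b64` is importable as shown:

```python
from base64 import b64decode
from io import BytesIO

from PIL import Image

from vision_agent.image_utils import convert_to_b64  # import path assumed

# Without convert("RGB"), Pillow would raise
# OSError: cannot write mode RGBA as JPEG.
rgba = Image.new("RGBA", (64, 64), (255, 0, 0, 128))

encoded = convert_to_b64(rgba)
roundtrip = Image.open(BytesIO(b64decode(encoded)))
print(roundtrip.format, roundtrip.mode)  # JPEG RGB
```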
