From 3b91fa94c283dbc222965018f731b9ee9a831782 Mon Sep 17 00:00:00 2001
From: Asia <2736300+humpydonkey@users.noreply.github.com>
Date: Thu, 11 Apr 2024 09:54:43 -0700
Subject: [PATCH] Support streaming chat logs of an agent (#47)

Add a callback for reporting the chat progress of an agent
---
 vision_agent/agent/agent.py        |   7 ++
 vision_agent/agent/vision_agent.py | 194 +++++++++++++++++------------
 vision_agent/image_utils.py        |   2 +-
 3 files changed, 121 insertions(+), 82 deletions(-)

diff --git a/vision_agent/agent/agent.py b/vision_agent/agent/agent.py
index 5054f170..93b3223d 100644
--- a/vision_agent/agent/agent.py
+++ b/vision_agent/agent/agent.py
@@ -11,3 +11,10 @@ def __call__(
         image: Optional[Union[str, Path]] = None,
     ) -> str:
         pass
+
+    @abstractmethod
+    def log_progress(self, description: str) -> None:
+        """Log the progress of the agent.
+        This is a hook that is intended for reporting the progress of the agent.
+        """
+        pass
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index a072903c..bad0933a 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -244,79 +244,6 @@ def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any:
         return str(e)
 
 
-def retrieval(
-    model: Union[LLM, LMM, Agent],
-    question: str,
-    tools: Dict[int, Any],
-    previous_log: str,
-    reflections: str,
-) -> Tuple[Dict, str]:
-    tool_id = choose_tool(
-        model, question, {k: v["description"] for k, v in tools.items()}, reflections
-    )
-    if tool_id is None:
-        return {}, ""
-
-    tool_instructions = tools[tool_id]
-    tool_usage = tool_instructions["usage"]
-    tool_name = tool_instructions["name"]
-
-    parameters = choose_parameter(
-        model, question, tool_usage, previous_log, reflections
-    )
-    if parameters is None:
-        return {}, ""
-    tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}
-
-    _LOGGER.info(
-        f"""Going to run the following tool(s) in sequence:
-{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
-    )
-
-    def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
-        call_results: List[Any] = []
-        if isinstance(result["parameters"], Dict):
-            call_results.append(
-                function_call(tools[tool_id]["class"], result["parameters"])
-            )
-        elif isinstance(result["parameters"], List):
-            for parameters in result["parameters"]:
-                call_results.append(function_call(tools[tool_id]["class"], parameters))
-        return call_results
-
-    call_results = parse_tool_results(tool_results)
-    tool_results["call_results"] = call_results
-
-    call_results_str = str(call_results)
-    # _LOGGER.info(f"\tCall Results: {call_results_str}")
-    return tool_results, call_results_str
-
-
-def create_tasks(
-    task_model: Union[LLM, LMM], question: str, tools: Dict[int, Any], reflections: str
-) -> List[Dict]:
-    tasks = task_decompose(
-        task_model,
-        question,
-        {k: v["description"] for k, v in tools.items()},
-        reflections,
-    )
-    if tasks is not None:
-        task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
-        task_list = task_topology(task_model, question, task_list)
-        try:
-            task_list = topological_sort(task_list)
-        except Exception:
-            _LOGGER.error(f"Failed topological_sort on: {task_list}")
-    else:
-        task_list = []
-    _LOGGER.info(
-        f"""Planned tasks:
-{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
-    )
-    return task_list
-
-
 def self_reflect(
     reflect_model: Union[LLM, LMM],
     question: str,
@@ -350,7 +277,7 @@ def parse_reflect(reflect: str) -> bool:
 def visualize_result(all_tool_results: List[Dict]) -> List[str]:
     image_to_data: Dict[str, Dict] = {}
     for tool_result in all_tool_results:
-        if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]:
+        if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
             continue
 
         parameters = tool_result["parameters"]
@@ -420,7 +347,18 @@ def __init__(
         reflect_model: Optional[Union[LLM, LMM]] = None,
         max_retries: int = 2,
         verbose: bool = False,
+        report_progress_callback: Optional[Callable[[str], None]] = None,
     ):
+        """VisionAgent constructor.
+
+        Parameters:
+            task_model: the model to use for task decomposition.
+            answer_model: the model to use for reasoning and concluding the answer.
+            reflect_model: the model to use for self-reflection.
+            max_retries: maximum number of retries to attempt to complete the task.
+            verbose: whether to print more logs.
+            report_progress_callback: a callback to report the progress of the agent. This is useful for streaming logs in a web application where multiple VisionAgent instances are running in parallel. This callback ensures that the progress logs are not mixed up.
+        """
         self.task_model = (
             OpenAILLM(json_mode=True, temperature=0.1)
             if task_model is None
@@ -433,8 +371,8 @@ def __init__(
             OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
         )
         self.max_retries = max_retries
-
         self.tools = TOOLS
+        self.report_progress_callback = report_progress_callback
 
         if verbose:
             _LOGGER.setLevel(logging.INFO)
@@ -457,6 +395,11 @@ def __call__(
             input = [{"role": "user", "content": input}]
         return self.chat(input, image=image)
 
+    def log_progress(self, description: str) -> None:
+        _LOGGER.info(description)
+        if self.report_progress_callback:
+            self.report_progress_callback(description)
+
     def chat_with_workflow(
         self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
     ) -> Tuple[str, List[Dict]]:
@@ -469,7 +412,9 @@ def chat_with_workflow(
         all_tool_results: List[Dict] = []
 
         for _ in range(self.max_retries):
-            task_list = create_tasks(self.task_model, question, self.tools, reflections)
+            task_list = self.create_tasks(
+                self.task_model, question, self.tools, reflections
+            )
 
             task_depend = {"Original Quesiton": question}
             previous_log = ""
@@ -481,7 +426,7 @@ def chat_with_workflow(
             for task in task_list:
                 task_str = task["task"]
                 previous_log = str(task_depend)
-                tool_results, call_results = retrieval(
+                tool_results, call_results = self.retrieval(
                     self.task_model,
                     task_str,
                     self.tools,
@@ -495,8 +440,8 @@ def chat_with_workflow(
                 tool_results["answer"] = answer
                 all_tool_results.append(tool_results)
 
-                _LOGGER.info(f"\tCall Result: {call_results}")
-                _LOGGER.info(f"\tAnswer: {answer}")
+                self.log_progress(f"\tCall Result: {call_results}")
+                self.log_progress(f"\tAnswer: {answer}")
                 answers.append({"task": task_str, "answer": answer})
                 task_depend[task["id"]]["answer"] = answer  # type: ignore
                 task_depend[task["id"]]["call_result"] = call_results  # type: ignore
@@ -514,12 +459,15 @@ def chat_with_workflow(
                 final_answer,
                 visualized_images[0] if len(visualized_images) > 0 else image,
             )
-            _LOGGER.info(f"Reflection: {reflection}")
+            self.log_progress(f"Reflection: {reflection}")
             if parse_reflect(reflection):
                 break
             else:
                 reflections += reflection
-
+        # '' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+        self.log_progress(
+            f"The Vision Agent has concluded this chat. {final_answer}"
+        )
         return final_answer, all_tool_results
 
     def chat(
@@ -527,3 +475,87 @@ def chat(
     ) -> str:
         answer, _ = self.chat_with_workflow(chat, image=image)
         return answer
+
+    def retrieval(
+        self,
+        model: Union[LLM, LMM, Agent],
+        question: str,
+        tools: Dict[int, Any],
+        previous_log: str,
+        reflections: str,
+    ) -> Tuple[Dict, str]:
+        tool_id = choose_tool(
+            model,
+            question,
+            {k: v["description"] for k, v in tools.items()},
+            reflections,
+        )
+        if tool_id is None:
+            return {}, ""
+
+        tool_instructions = tools[tool_id]
+        tool_usage = tool_instructions["usage"]
+        tool_name = tool_instructions["name"]
+
+        parameters = choose_parameter(
+            model, question, tool_usage, previous_log, reflections
+        )
+        if parameters is None:
+            return {}, ""
+        tool_results = {
+            "task": question,
+            "tool_name": tool_name,
+            "parameters": parameters,
+        }
+
+        self.log_progress(
+            f"""Going to run the following tool(s) in sequence:
+{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
+        )
+
+        def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
+            call_results: List[Any] = []
+            if isinstance(result["parameters"], Dict):
+                call_results.append(
+                    function_call(tools[tool_id]["class"], result["parameters"])
+                )
+            elif isinstance(result["parameters"], List):
+                for parameters in result["parameters"]:
+                    call_results.append(
+                        function_call(tools[tool_id]["class"], parameters)
+                    )
+            return call_results
+
+        call_results = parse_tool_results(tool_results)
+        tool_results["call_results"] = call_results
+
+        call_results_str = str(call_results)
+        return tool_results, call_results_str
+
+    def create_tasks(
+        self,
+        task_model: Union[LLM, LMM],
+        question: str,
+        tools: Dict[int, Any],
+        reflections: str,
+    ) -> List[Dict]:
+        tasks = task_decompose(
+            task_model,
+            question,
+            {k: v["description"] for k, v in tools.items()},
+            reflections,
+        )
+        if tasks is not None:
+            task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
+            task_list = task_topology(task_model, question, task_list)
+            try:
+                task_list = topological_sort(task_list)
+            except Exception:
+                _LOGGER.error(f"Failed topological_sort on: {task_list}")
+        else:
+            task_list = []
+        self.log_progress(
+            f"""Planned tasks:
+{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
+        )
+        return task_list
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index 849f912f..65ee5b01 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -78,7 +78,7 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
         data = Image.open(data)
     if isinstance(data, Image.Image):
         buffer = BytesIO()
-        data.save(buffer, format="PNG")
+        data.convert("RGB").save(buffer, format="JPEG")
         return base64.b64encode(buffer.getvalue()).decode("utf-8")
     else:
         arr_bytes = data.tobytes()
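
Usage note (not part of the patch): with this change a caller can stream an agent's progress instead of scraping the module logger. A minimal sketch follows; the queue-based handler and the image path are hypothetical, and it assumes VisionAgent is importable from vision_agent.agent as in this repo.

    # Sketch: stream progress messages from one VisionAgent instance.
    # The queue-based handler below is illustrative, not part of this patch.
    from queue import Queue

    from vision_agent.agent import VisionAgent

    log_queue: Queue = Queue()  # drained elsewhere, e.g. by a web handler

    def stream_progress(description: str) -> None:
        # VisionAgent.log_progress() invokes this for every progress message,
        # so each agent instance can feed its own consumer without the
        # messages of parallel agents getting interleaved.
        log_queue.put(description)

    agent = VisionAgent(report_progress_callback=stream_progress)
    answer = agent("What objects are in this image?", image="example.jpg")

Because log_progress() still calls _LOGGER.info() first, the callback is purely additive: existing logging behavior is unchanged when no callback is given.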
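Side note on the image_utils.py hunk (not part of the patch): JPEG cannot encode an alpha channel, so the convert("RGB") call is what keeps RGBA or paletted inputs from raising on save, presumably while also shrinking the base64 payload versus PNG. A quick sketch with a synthetic RGBA image:

    # Sketch: the new JPEG path handles non-RGB images (synthetic input).
    import numpy as np
    from PIL import Image

    from vision_agent.image_utils import convert_to_b64

    rgba = Image.fromarray(np.zeros((4, 4, 4), dtype=np.uint8))  # mode "RGBA"
    b64 = convert_to_b64(rgba)  # plain .save(..., format="JPEG") would raise OSError
    print(b64[:16])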