Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support streaming chat logs of an agent #47

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions vision_agent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ def __call__(
image: Optional[Union[str, Path]] = None,
) -> str:
pass

@abstractmethod
def log_progress(self, description: str) -> None:
    """Report the progress of the agent.

    This is a hook for progress reporting. It is declared abstract, so
    concrete agents are expected to provide their own implementation.

    Parameters:
        description: the progress message to report.
    """
194 changes: 113 additions & 81 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,79 +244,6 @@ def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any:
return str(e)


def retrieval(
    model: Union[LLM, LMM, Agent],
    question: str,
    tools: Dict[int, Any],
    previous_log: str,
    reflections: str,
) -> Tuple[Dict, str]:
    """Choose a tool for ``question``, choose its parameters, and run it.

    Parameters:
        model: passed to ``choose_tool`` and ``choose_parameter``.
        question: the task to solve; stored under the ``"task"`` key of the
            returned record.
        tools: tool registry keyed by tool id. Each entry is read for its
            ``"description"``, ``"usage"``, ``"name"`` and ``"class"`` keys.
        previous_log: forwarded to ``choose_parameter``.
        reflections: forwarded to ``choose_tool`` and ``choose_parameter``.

    Returns:
        A tuple ``(record, call_results_str)`` where ``record`` holds the
        ``task``, ``tool_name``, ``parameters`` and ``call_results`` keys and
        ``call_results_str`` is ``str(call_results)``. When no tool or no
        parameters could be chosen, ``({}, "")`` is returned.
    """
    descriptions = {tool_key: spec["description"] for tool_key, spec in tools.items()}
    tool_id = choose_tool(model, question, descriptions, reflections)
    if tool_id is None:
        return {}, ""

    chosen = tools[tool_id]
    usage = chosen["usage"]
    name = chosen["name"]

    parameters = choose_parameter(model, question, usage, previous_log, reflections)
    if parameters is None:
        return {}, ""

    record = {"task": question, "tool_name": name, "parameters": parameters}

    plan_table = tabulate([record], headers="keys", tablefmt="mixed_grid")
    _LOGGER.info(f"Going to run the following tool(s) in sequence:\n{plan_table}")

    def run_tool(arguments: Any) -> Any:
        # The tool class is looked up only when a call is actually made.
        return function_call(tools[tool_id]["class"], arguments)

    requested = record["parameters"]
    outputs: List[Any]
    if isinstance(requested, Dict):
        # A single parameter set means a single invocation.
        outputs = [run_tool(requested)]
    elif isinstance(requested, List):
        # A list of parameter sets means one invocation per entry, in order.
        outputs = [run_tool(arguments) for arguments in requested]
    else:
        # Any other parameter shape results in no tool invocation.
        outputs = []

    record["call_results"] = outputs
    return record, str(outputs)


def create_tasks(
    task_model: Union[LLM, LMM], question: str, tools: Dict[int, Any], reflections: str
) -> List[Dict]:
    """Decompose ``question`` into a list of planned tasks and log the plan.

    Parameters:
        task_model: passed to ``task_decompose`` and ``task_topology``.
        question: the user question to decompose.
        tools: tool registry keyed by tool id; only each entry's
            ``"description"`` is forwarded to ``task_decompose``.
        reflections: forwarded to ``task_decompose``.

    Returns:
        The planned task list. It is empty when ``task_decompose`` returns
        ``None``. If ``topological_sort`` raises, the error is logged and the
        list produced by ``task_topology`` is returned as is.
    """
    descriptions = {tool_key: spec["description"] for tool_key, spec in tools.items()}
    subtasks = task_decompose(task_model, question, descriptions, reflections)

    planned: List[Dict]
    if subtasks is None:
        planned = []
    else:
        # Task ids are 1-based.
        planned = [
            {"task": text, "id": number}
            for number, text in enumerate(subtasks, start=1)
        ]
        planned = task_topology(task_model, question, planned)
        try:
            planned = topological_sort(planned)
        except Exception:
            # Best effort: keep the current plan when sorting fails.
            _LOGGER.error(f"Failed topological_sort on: {planned}")

    plan_table = tabulate(planned, headers="keys", tablefmt="mixed_grid")
    _LOGGER.info(f"Planned tasks:\n{plan_table}")
    return planned


def self_reflect(
reflect_model: Union[LLM, LMM],
question: str,
Expand Down Expand Up @@ -350,7 +277,7 @@ def parse_reflect(reflect: str) -> bool:
def visualize_result(all_tool_results: List[Dict]) -> List[str]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]:
if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
continue

parameters = tool_result["parameters"]
Expand Down Expand Up @@ -420,7 +347,18 @@ def __init__(
reflect_model: Optional[Union[LLM, LMM]] = None,
max_retries: int = 2,
verbose: bool = False,
report_progress_callback: Optional[Callable[[str], None]] = None,
):
"""VisionAgent constructor.

Parameters
task_model: the model to use for task decomposition.
answer_model: the model to use for reasoning and concluding the answer.
reflect_model: the model to use for self reflection.
max_retries: maximum number of retries to attempt to complete the task.
verbose: whether to print more logs.
report_progress_callback: a callback to report the progress of the agent. This is useful for streaming logs in a web application where multiple VisionAgent instances are running in parallel. This callback ensures that the progress logs of different instances are not mixed up.
"""
self.task_model = (
OpenAILLM(json_mode=True, temperature=0.1)
if task_model is None
Expand All @@ -433,8 +371,8 @@ def __init__(
OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
)
self.max_retries = max_retries

self.tools = TOOLS
self.report_progress_callback = report_progress_callback
if verbose:
_LOGGER.setLevel(logging.INFO)

Expand All @@ -457,6 +395,11 @@ def __call__(
input = [{"role": "user", "content": input}]
return self.chat(input, image=image)

def log_progress(self, description: str) -> None:
    """Log a progress message and forward it to the progress callback.

    The message is always written to the module logger at INFO level. It is
    also passed to ``self.report_progress_callback`` when that is set.

    Parameters:
        description: the progress message to report.
    """
    _LOGGER.info(description)
    if not self.report_progress_callback:
        # No callback registered: logging is the only side effect.
        return
    self.report_progress_callback(description)

def chat_with_workflow(
self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
) -> Tuple[str, List[Dict]]:
Expand All @@ -469,7 +412,9 @@ def chat_with_workflow(
all_tool_results: List[Dict] = []

for _ in range(self.max_retries):
task_list = create_tasks(self.task_model, question, self.tools, reflections)
task_list = self.create_tasks(
self.task_model, question, self.tools, reflections
)

task_depend = {"Original Quesiton": question}
previous_log = ""
Expand All @@ -481,7 +426,7 @@ def chat_with_workflow(
for task in task_list:
task_str = task["task"]
previous_log = str(task_depend)
tool_results, call_results = retrieval(
tool_results, call_results = self.retrieval(
self.task_model,
task_str,
self.tools,
Expand All @@ -495,8 +440,8 @@ def chat_with_workflow(
tool_results["answer"] = answer
all_tool_results.append(tool_results)

_LOGGER.info(f"\tCall Result: {call_results}")
_LOGGER.info(f"\tAnswer: {answer}")
self.log_progress(f"\tCall Result: {call_results}")
self.log_progress(f"\tAnswer: {answer}")
answers.append({"task": task_str, "answer": answer})
task_depend[task["id"]]["answer"] = answer # type: ignore
task_depend[task["id"]]["call_result"] = call_results # type: ignore
Expand All @@ -514,16 +459,103 @@ def chat_with_workflow(
final_answer,
visualized_images[0] if len(visualized_images) > 0 else image,
)
_LOGGER.info(f"Reflection: {reflection}")
self.log_progress(f"Reflection: {reflection}")
if parse_reflect(reflection):
break
else:
reflections += reflection

# The '<ANSWER>...</<ANSWER>' markers in the final message indicate the end of the chat, which is useful for streaming logs.
self.log_progress(
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</<ANSWER>"
)
return final_answer, all_tool_results

def chat(
    self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
) -> str:
    """Run the chat workflow and return only the final answer.

    Parameters:
        chat: the conversation, forwarded to ``chat_with_workflow``.
        image: optional image, forwarded to ``chat_with_workflow``.

    Returns:
        The first element of the pair returned by ``chat_with_workflow``;
        the second element (the tool results) is discarded.
    """
    final_answer, _tool_results = self.chat_with_workflow(chat, image=image)
    return final_answer

def retrieval(
    self,
    model: Union[LLM, LMM, Agent],
    question: str,
    tools: Dict[int, Any],
    previous_log: str,
    reflections: str,
) -> Tuple[Dict, str]:
    """Select a tool and its parameters for ``question`` and execute it.

    Parameters:
        model: passed to ``choose_tool`` and ``choose_parameter``.
        question: the task to solve; stored under the ``"task"`` key of the
            returned record.
        tools: tool registry keyed by tool id. Each entry is read for its
            ``"description"``, ``"usage"``, ``"name"`` and ``"class"`` keys.
        previous_log: forwarded to ``choose_parameter``.
        reflections: forwarded to ``choose_tool`` and ``choose_parameter``.

    Returns:
        A tuple ``(record, call_results_str)``. ``record`` contains ``task``,
        ``tool_name``, ``parameters`` and ``call_results``, and
        ``call_results_str`` is ``str(call_results)``. ``({}, "")`` is
        returned when no tool or no parameters could be chosen.
    """
    tool_descriptions = {key: entry["description"] for key, entry in tools.items()}
    tool_id = choose_tool(model, question, tool_descriptions, reflections)
    if tool_id is None:
        return {}, ""

    instructions = tools[tool_id]
    usage_text = instructions["usage"]
    selected_name = instructions["name"]

    parameters = choose_parameter(
        model, question, usage_text, previous_log, reflections
    )
    if parameters is None:
        return {}, ""

    tool_results = {
        "task": question,
        "tool_name": selected_name,
        "parameters": parameters,
    }

    # The plan goes through the progress hook, not straight to the logger.
    plan_table = tabulate([tool_results], headers="keys", tablefmt="mixed_grid")
    self.log_progress(
        f"Going to run the following tool(s) in sequence:\n{plan_table}"
    )

    # Normalize the parameters into a sequence of argument sets: a dict is one
    # call, a list is one call per element, anything else is no call at all.
    chosen_parameters = tool_results["parameters"]
    if isinstance(chosen_parameters, Dict):
        argument_sets: List[Any] = [chosen_parameters]
    elif isinstance(chosen_parameters, List):
        argument_sets = chosen_parameters
    else:
        argument_sets = []

    call_results: List[Any] = [
        function_call(tools[tool_id]["class"], arguments)
        for arguments in argument_sets
    ]
    tool_results["call_results"] = call_results

    return tool_results, str(call_results)

def create_tasks(
    self,
    task_model: Union[LLM, LMM],
    question: str,
    tools: Dict[int, Any],
    reflections: str,
) -> List[Dict]:
    """Plan the tasks needed to answer ``question`` and report the plan.

    Parameters:
        task_model: passed to ``task_decompose`` and ``task_topology``.
        question: the user question to decompose.
        tools: tool registry keyed by tool id; only each entry's
            ``"description"`` is forwarded to ``task_decompose``.
        reflections: forwarded to ``task_decompose``.

    Returns:
        The planned task list, which is empty when ``task_decompose`` returns
        ``None``. If ``topological_sort`` raises, the failure is logged and
        the list from ``task_topology`` is returned as is.
    """
    tool_descriptions = {key: entry["description"] for key, entry in tools.items()}
    subtasks = task_decompose(task_model, question, tool_descriptions, reflections)

    ordered: List[Dict] = []
    if subtasks is not None:
        # Task ids are 1-based.
        numbered = [
            {"task": text, "id": position}
            for position, text in enumerate(subtasks, start=1)
        ]
        ordered = task_topology(task_model, question, numbered)
        try:
            ordered = topological_sort(ordered)
        except Exception:
            # Best effort: fall back to the current ordering.
            _LOGGER.error(f"Failed topological_sort on: {ordered}")

    plan_table = tabulate(ordered, headers="keys", tablefmt="mixed_grid")
    self.log_progress(f"Planned tasks:\n{plan_table}")
    return ordered
2 changes: 1 addition & 1 deletion vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
data = Image.open(data)
if isinstance(data, Image.Image):
buffer = BytesIO()
data.save(buffer, format="PNG")
data.convert("RGB").save(buffer, format="JPEG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
arr_bytes = data.tobytes()
Expand Down
Loading