From c481f9a41f5fd2c15199f271abc4d11919ead213 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sat, 23 Mar 2024 16:39:49 -0700 Subject: [PATCH] add topo sort --- tests/test_easytool.py | 32 ++++++++++++++++++++++++++++++++ vision_agent/agent/easytool.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/test_easytool.py diff --git a/tests/test_easytool.py b/tests/test_easytool.py new file mode 100644 index 00000000..e2941b71 --- /dev/null +++ b/tests/test_easytool.py @@ -0,0 +1,32 @@ +from vision_agent.agent.easytool import topological_sort + + +def test_basic(): + tasks = [ + {"id": 1, "dep": [-1]}, + {"id": 3, "dep": [2]}, + {"id": 2, "dep": [1]}, + ] + assert topological_sort(tasks) == [tasks[0], tasks[2], tasks[1]] + + +def test_two_start(): + tasks = [ + {"id": 1, "dep": [-1]}, + {"id": 2, "dep": [1]}, + {"id": 3, "dep": [-1]}, + {"id": 4, "dep": [3]}, + {"id": 5, "dep": [2, 4]}, + ] + assert topological_sort(tasks) == [tasks[0], tasks[2], tasks[1], tasks[3], tasks[4]] + + +def test_broken(): + tasks = [ + {"id": 1, "dep": [-1]}, + {"id": 2, "dep": [3]}, + {"id": 3, "dep": [2]}, + {"id": 4, "dep": [3]}, + ] + + assert topological_sort(tasks) == [tasks[0], tasks[1], tasks[2], tasks[3]] diff --git a/vision_agent/agent/easytool.py b/vision_agent/agent/easytool.py index 6a3c6c77..7a7cf7d1 100644 --- a/vision_agent/agent/easytool.py +++ b/vision_agent/agent/easytool.py @@ -49,6 +49,33 @@ def format_tools(tools: Dict[int, Any]) -> str: return tool_str +def topological_sort(tasks: List[Dict]) -> List[Dict]: + in_degree = {task["id"]: 0 for task in tasks} + for task in tasks: + for dep in task["dep"]: + if dep in in_degree: + in_degree[task["id"]] += 1 + + queue = [task for task in tasks if in_degree[task["id"]] == 0] + sorted_order = [] + + while queue: + current = queue.pop(0) + sorted_order.append(current) + + for task in tasks: + if current["id"] in task["dep"]: + in_degree[task["id"]] -= 1 + if in_degree[task["id"]] == 0: + queue.append(task) + + if len(sorted_order) != len(tasks): + completed_ids = set([task["id"] for task in sorted_order]) + remaining_tasks = [task for task in tasks if task["id"] not in completed_ids] + sorted_order.extend(remaining_tasks) + return sorted_order + + def task_decompose( model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any] ) -> Optional[Dict]: @@ -265,6 +292,10 @@ def chat_with_workflow( if tasks is not None: task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)] task_list = task_topology(self.task_model, question, task_list) + try: + task_list = topological_sort(task_list) + except Exception: + _LOGGER.error(f"Failed topological_sort on: {task_list}") else: task_list = [] @@ -274,7 +305,6 @@ def chat_with_workflow( answers = [] for task in task_list: task_depend[task["id"]] = {"task": task["task"], "answer": ""} # type: ignore - # TODO topological sort task_list all_tool_results = [] for task in task_list: task_str = task["task"]