Skip to content

Commit

Permalink
add topo sort
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 23, 2024
1 parent 0747ce9 commit c481f9a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
32 changes: 32 additions & 0 deletions tests/test_easytool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from vision_agent.agent.easytool import topological_sort


def test_basic():
tasks = [
{"id": 1, "dep": [-1]},
{"id": 3, "dep": [2]},
{"id": 2, "dep": [1]},
]
assert topological_sort(tasks) == [tasks[0], tasks[2], tasks[1]]


def test_two_start():
tasks = [
{"id": 1, "dep": [-1]},
{"id": 2, "dep": [1]},
{"id": 3, "dep": [-1]},
{"id": 4, "dep": [3]},
{"id": 5, "dep": [2, 4]},
]
assert topological_sort(tasks) == [tasks[0], tasks[2], tasks[1], tasks[3], tasks[4]]


def test_broken():
tasks = [
{"id": 1, "dep": [-1]},
{"id": 2, "dep": [3]},
{"id": 3, "dep": [2]},
{"id": 4, "dep": [3]},
]

assert topological_sort(tasks) == [tasks[0], tasks[1], tasks[2], tasks[3]]
32 changes: 31 additions & 1 deletion vision_agent/agent/easytool.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,33 @@ def format_tools(tools: Dict[int, Any]) -> str:
return tool_str


def topological_sort(tasks: List[Dict]) -> List[Dict]:
in_degree = {task["id"]: 0 for task in tasks}
for task in tasks:
for dep in task["dep"]:
if dep in in_degree:
in_degree[task["id"]] += 1

queue = [task for task in tasks if in_degree[task["id"]] == 0]
sorted_order = []

while queue:
current = queue.pop(0)
sorted_order.append(current)

for task in tasks:
if current["id"] in task["dep"]:
in_degree[task["id"]] -= 1
if in_degree[task["id"]] == 0:
queue.append(task)

if len(sorted_order) != len(tasks):
completed_ids = set([task["id"] for task in sorted_order])
remaining_tasks = [task for task in tasks if task["id"] not in completed_ids]
sorted_order.extend(remaining_tasks)
return sorted_order


def task_decompose(
model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any]
) -> Optional[Dict]:
Expand Down Expand Up @@ -265,6 +292,10 @@ def chat_with_workflow(
if tasks is not None:
task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
task_list = task_topology(self.task_model, question, task_list)
try:
task_list = topological_sort(task_list)
except Exception:
_LOGGER.error(f"Failed topological_sort on: {task_list}")
else:
task_list = []

Expand All @@ -274,7 +305,6 @@ def chat_with_workflow(
answers = []
for task in task_list:
task_depend[task["id"]] = {"task": task["task"], "answer": ""} # type: ignore
# TODO topological sort task_list
all_tool_results = []
for task in task_list:
task_str = task["task"]
Expand Down

0 comments on commit c481f9a

Please sign in to comment.