From c481f9a41f5fd2c15199f271abc4d11919ead213 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sat, 23 Mar 2024 16:39:49 -0700 Subject: [PATCH 1/2] add topo sort --- tests/test_easytool.py | 32 ++++++++++++++++++++++++++++++++ vision_agent/agent/easytool.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/test_easytool.py diff --git a/tests/test_easytool.py b/tests/test_easytool.py new file mode 100644 index 00000000..e2941b71 --- /dev/null +++ b/tests/test_easytool.py @@ -0,0 +1,32 @@ +from vision_agent.agent.easytool import topological_sort + + +def test_basic(): + tasks = [ + {"id": 1, "dep": [-1]}, + {"id": 3, "dep": [2]}, + {"id": 2, "dep": [1]}, + ] + assert topological_sort(tasks) == [tasks[0], tasks[2], tasks[1]] + + +def test_two_start(): + tasks = [ + {"id": 1, "dep": [-1]}, + {"id": 2, "dep": [1]}, + {"id": 3, "dep": [-1]}, + {"id": 4, "dep": [3]}, + {"id": 5, "dep": [2, 4]}, + ] + assert topological_sort(tasks) == [tasks[0], tasks[2], tasks[1], tasks[3], tasks[4]] + + +def test_broken(): + tasks = [ + {"id": 1, "dep": [-1]}, + {"id": 2, "dep": [3]}, + {"id": 3, "dep": [2]}, + {"id": 4, "dep": [3]}, + ] + + assert topological_sort(tasks) == [tasks[0], tasks[1], tasks[2], tasks[3]] diff --git a/vision_agent/agent/easytool.py b/vision_agent/agent/easytool.py index 6a3c6c77..7a7cf7d1 100644 --- a/vision_agent/agent/easytool.py +++ b/vision_agent/agent/easytool.py @@ -49,6 +49,33 @@ def format_tools(tools: Dict[int, Any]) -> str: return tool_str +def topological_sort(tasks: List[Dict]) -> List[Dict]: + in_degree = {task["id"]: 0 for task in tasks} + for task in tasks: + for dep in task["dep"]: + if dep in in_degree: + in_degree[task["id"]] += 1 + + queue = [task for task in tasks if in_degree[task["id"]] == 0] + sorted_order = [] + + while queue: + current = queue.pop(0) + sorted_order.append(current) + + for task in tasks: + if current["id"] in task["dep"]: + in_degree[task["id"]] -= 1 + if in_degree[task["id"]] == 0: + queue.append(task) + + if len(sorted_order) != len(tasks): + completed_ids = set([task["id"] for task in sorted_order]) + remaining_tasks = [task for task in tasks if task["id"] not in completed_ids] + sorted_order.extend(remaining_tasks) + return sorted_order + + def task_decompose( model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any] ) -> Optional[Dict]: @@ -265,6 +292,10 @@ def chat_with_workflow( if tasks is not None: task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)] task_list = task_topology(self.task_model, question, task_list) + try: + task_list = topological_sort(task_list) + except Exception: + _LOGGER.error(f"Failed topological_sort on: {task_list}") else: task_list = [] @@ -274,7 +305,6 @@ def chat_with_workflow( answers = [] for task in task_list: task_depend[task["id"]] = {"task": task["task"], "answer": ""} # type: ignore - # TODO topological sort task_list all_tool_results = [] for task in task_list: task_str = task["task"] From 3ed5341a94015a84cbb119438e1debcce88a1988 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sat, 23 Mar 2024 16:45:39 -0700 Subject: [PATCH 2/2] update lock file --- poetry.lock | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/poetry.lock b/poetry.lock index d765d6ae..ce0e26e1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -354,13 +354,13 @@ pyflakes = ">=2.5.0,<2.6.0" [[package]] name = "fsspec" -version = "2024.2.0" +version = "2024.3.1" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, - {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, + {file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"}, + {file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"}, ] [package.extras] @@ -406,13 +406,13 @@ dev = ["flake8", "markdown", "twine", "wheel"] [[package]] name = "griffe" -version = "0.42.0" +version = "0.42.1" description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." optional = false python-versions = ">=3.8" files = [ - {file = "griffe-0.42.0-py3-none-any.whl", hash = "sha256:384df6b802a60f70e65fdb7e83f5b27e2da869a12eac85b25b55250012dbc263"}, - {file = "griffe-0.42.0.tar.gz", hash = "sha256:fb83ee602701ffdf99c9a6bf5f0a5a3bd877364b3bffb2c451dc8fbd9645b0cf"}, + {file = "griffe-0.42.1-py3-none-any.whl", hash = "sha256:7e805e35617601355edcac0d3511cedc1ed0cb1f7645e2d336ae4b05bbae7b3b"}, + {file = "griffe-0.42.1.tar.gz", hash = "sha256:57046131384043ed078692b85d86b76568a686266cc036b9b56b704466f803ce"}, ] [package.dependencies] @@ -573,13 +573,13 @@ files = [ [[package]] name = "markdown" -version = "3.5.2" +version = "3.6" description = "Python implementation of John Gruber's Markdown." optional = false python-versions = ">=3.8" files = [ - {file = "Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd"}, - {file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"}, + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, ] [package.extras] @@ -725,13 +725,13 @@ mkdocs = ">=1.1" [[package]] name = "mkdocs-material" -version = "9.5.13" +version = "9.5.15" description = "Documentation that simply works" optional = false python-versions = ">=3.8" files = [ - {file = "mkdocs_material-9.5.13-py3-none-any.whl", hash = "sha256:5cbe17fee4e3b4980c8420a04cc762d8dc052ef1e10532abd4fce88e5ea9ce6a"}, - {file = "mkdocs_material-9.5.13.tar.gz", hash = "sha256:d8e4caae576312a88fd2609b81cf43d233cdbe36860d67a68702b018b425bd87"}, + {file = "mkdocs_material-9.5.15-py3-none-any.whl", hash = "sha256:e5c96dec3d19491de49ca643fc1dbb92b278e43cdb816c775bc47db77d9b62fb"}, + {file = "mkdocs_material-9.5.15.tar.gz", hash = "sha256:39f03cca45e82bf54eb7456b5a18bd252eabfdd67f237a229471484a0a4d4635"}, ] [package.dependencies] @@ -1084,13 +1084,13 @@ files = [ [[package]] name = "openai" -version = "1.13.3" +version = "1.14.2" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.13.3-py3-none-any.whl", hash = "sha256:5769b62abd02f350a8dd1a3a242d8972c947860654466171d60fb0972ae0a41c"}, - {file = "openai-1.13.3.tar.gz", hash = "sha256:ff6c6b3bc7327e715e4b3592a923a5a1c7519ff5dd764a83d69f633d49e77a7b"}, + {file = "openai-1.14.2-py3-none-any.whl", hash = "sha256:a48b3c4d635b603952189ac5a0c0c9b06c025b80eb2900396939f02bb2104ac3"}, + {file = "openai-1.14.2.tar.gz", hash = "sha256:e5642f7c02cf21994b08477d7bb2c1e46d8f335d72c26f0396c5f89b15b5b153"}, ] [package.dependencies] @@ -1971,13 +1971,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", [[package]] name = "sentence-transformers" -version = "2.5.1" +version = "2.6.0" description = "Multilingual text embeddings" optional = false python-versions = ">=3.8.0" files = [ - {file = "sentence-transformers-2.5.1.tar.gz", hash = "sha256:754bf2b2623eb46904fd9c72ff89a0f90200fe141a8d45b03e83bc6d51718153"}, - {file = "sentence_transformers-2.5.1-py3-none-any.whl", hash = "sha256:f12346f7fca06ed1198d24235cb9114a74665506f7c30044e0a6f12de7eeeb77"}, + {file = "sentence-transformers-2.6.0.tar.gz", hash = "sha256:cf519d311ddcc8ff84d78d18fae051985bef348716354e069e3d6d670e4db604"}, + {file = "sentence_transformers-2.6.0-py3-none-any.whl", hash = "sha256:8807db6db0cf0a92f02be799b7c8260029ff91315d18f7b4b51c30b10d6b4fdb"}, ] [package.dependencies] @@ -2044,13 +2044,13 @@ mpmath = ">=0.19" [[package]] name = "threadpoolctl" -version = "3.3.0" +version = "3.4.0" description = "threadpoolctl" optional = false python-versions = ">=3.8" files = [ - {file = "threadpoolctl-3.3.0-py3-none-any.whl", hash = "sha256:6155be1f4a39f31a18ea70f94a77e0ccd57dced08122ea61109e7da89883781e"}, - {file = "threadpoolctl-3.3.0.tar.gz", hash = "sha256:5dac632b4fa2d43f42130267929af3ba01399ef4bd1882918e92dbc30365d30c"}, + {file = "threadpoolctl-3.4.0-py3-none-any.whl", hash = "sha256:8f4c689a65b23e5ed825c8436a92b818aac005e0f3715f6a1664d7c7ee29d262"}, + {file = "threadpoolctl-3.4.0.tar.gz", hash = "sha256:f11b491a03661d6dd7ef692dd422ab34185d982466c49c8f98c8f716b5c93196"}, ] [[package]] @@ -2266,13 +2266,13 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.38.2" +version = "4.39.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.38.2-py3-none-any.whl", hash = "sha256:c4029cb9f01b3dd335e52f364c52d2b37c65b4c78e02e6a08b1919c5c928573e"}, - {file = "transformers-4.38.2.tar.gz", hash = "sha256:c5fc7ad682b8a50a48b2a4c05d4ea2de5567adb1bdd00053619dbe5960857dd5"}, + {file = "transformers-4.39.1-py3-none-any.whl", hash = "sha256:df167e08b27ab254044a38bb7c439461cd3916332205416e9b6b1592b517a1a5"}, + {file = "transformers-4.39.1.tar.gz", hash = "sha256:ab9c1e1912843b9976e6cc62b27cd5434284fc0dab465e1b660333acfa81c6bc"}, ] [package.dependencies]