From a7f0538b2c971c59a6f934b4446401efc290ff42 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Thu, 1 Aug 2024 09:06:00 +0200 Subject: [PATCH 01/11] Fix: handle exception when inspecting the function (#196) --- aiida_workgraph/utils/__init__.py | 52 +++++++++++++++++-------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index 10406c27..a8501459 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -565,31 +565,37 @@ def serialize_function(func: Callable) -> Dict[str, Any]: import textwrap import cloudpickle as pickle - # we need save the source code explicitly, because in the case of jupyter notebook, - # the source code is not saved in the pickle file - source_code = inspect.getsource(func) - # Split the source into lines for processing - source_code_lines = source_code.split("\n") - function_source_code = "\n".join(source_code_lines) - # Find the first line of the actual function definition - for i, line in enumerate(source_code_lines): - if line.strip().startswith("def "): - break - function_source_code_without_decorator = "\n".join(source_code_lines[i:]) - function_source_code_without_decorator = textwrap.dedent( - function_source_code_without_decorator - ) - # we also need to include the necessary imports for the types used in the type hints. try: - required_imports = get_required_imports(func) + # we need save the source code explicitly, because in the case of jupyter notebook, + # the source code is not saved in the pickle file + source_code = inspect.getsource(func) + # Split the source into lines for processing + source_code_lines = source_code.split("\n") + function_source_code = "\n".join(source_code_lines) + # Find the first line of the actual function definition + for i, line in enumerate(source_code_lines): + if line.strip().startswith("def "): + break + function_source_code_without_decorator = "\n".join(source_code_lines[i:]) + function_source_code_without_decorator = textwrap.dedent( + function_source_code_without_decorator + ) + # we also need to include the necessary imports for the types used in the type hints. + try: + required_imports = get_required_imports(func) + except Exception as e: + required_imports = {} + print(f"Failed to get required imports for function {func.__name__}: {e}") + # Generate import statements + import_statements = "\n".join( + f"from {module} import {', '.join(types)}" + for module, types in required_imports.items() + ) except Exception as e: - required_imports = {} - print(f"Failed to get required imports for function {func.__name__}: {e}") - # Generate import statements - import_statements = "\n".join( - f"from {module} import {', '.join(types)}" - for module, types in required_imports.items() - ) + print(f"Failed to serialize function {func.__name__}: {e}") + function_source_code = "" + function_source_code_without_decorator = "" + import_statements = "" return { "executor": pickle.dumps(func), "type": "function", From d41a8ae1aea4195e2ff0092b081a22221768f0d3 Mon Sep 17 00:00:00 2001 From: Ali Khosravi Date: Thu, 8 Aug 2024 15:19:14 +0200 Subject: [PATCH 02/11] Cleaning unused html files that cause problems for case-insensitive filesystems (#201) --- .../html/pythonjob_parent_folder.html | 258 ------------------ .../html/pythonjob_shell_command.html | 258 ------------------ 2 files changed, 516 deletions(-) delete mode 100644 docs/source/built-in/html/pythonjob_parent_folder.html delete mode 100644 docs/source/built-in/html/pythonjob_shell_command.html diff --git a/docs/source/built-in/html/pythonjob_parent_folder.html b/docs/source/built-in/html/pythonjob_parent_folder.html deleted file mode 100644 index bf480f82..00000000 --- a/docs/source/built-in/html/pythonjob_parent_folder.html +++ /dev/null @@ -1,258 +0,0 @@ - - - - - - - Rete.js with React in Vanilla JS - - - - - - - - - - - - - - - - - - - - -
- - - diff --git a/docs/source/built-in/html/pythonjob_shell_command.html b/docs/source/built-in/html/pythonjob_shell_command.html deleted file mode 100644 index e4cab961..00000000 --- a/docs/source/built-in/html/pythonjob_shell_command.html +++ /dev/null @@ -1,258 +0,0 @@ - - - - - - - Rete.js with React in Vanilla JS - - - - - - - - - - - - - - - - - - - - -
- - - From 262b744a5c7d54ae4f775f6cd71cff8593b19d26 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Thu, 8 Aug 2024 19:36:45 +0200 Subject: [PATCH 03/11] Add built-in `While` task (#198) * Add built-in task: **While**. Task allows for the running of tasks many times in a loop. The While task has parameters: - **max_iterations** - **conditions**: a list of variables from the `context`. When any of them is `False`, the while loop exit. - **tasks**: a list of names of the the tasks inside the while loop. * Update the Web app and widget to support visualizing the `While` task as a parent scope task. --- aiida_workgraph/collection.py | 3 + aiida_workgraph/engine/workgraph.py | 344 ++-- aiida_workgraph/task.py | 2 +- aiida_workgraph/tasks/__init__.py | 3 +- aiida_workgraph/tasks/builtin.py | 20 + aiida_workgraph/utils/__init__.py | 18 +- aiida_workgraph/utils/analysis.py | 23 + .../web/frontend/package-lock.json | 383 ++-- aiida_workgraph/web/frontend/package.json | 24 +- aiida_workgraph/web/frontend/src/App.js | 1 - .../web/frontend/src/rete/default.ts | 63 +- aiida_workgraph/widget/js/default_rete.ts | 28 +- aiida_workgraph/widget/package-lock.json | 23 +- aiida_workgraph/widget/package.json | 1 + .../widget/src/widget/html_template.py | 23 + aiida_workgraph/workgraph.py | 6 +- docs/source/howto/html/test_while_task.html | 281 +++ .../howto/html/test_while_workgraph.html | 281 +++ docs/source/howto/while.ipynb | 1824 +++++++++++++---- tests/test_decorator.py | 3 +- tests/test_while.py | 33 + tests/test_workgraph.py | 5 +- 22 files changed, 2580 insertions(+), 812 deletions(-) create mode 100644 docs/source/howto/html/test_while_task.html create mode 100644 docs/source/howto/html/test_while_workgraph.html diff --git a/aiida_workgraph/collection.py b/aiida_workgraph/collection.py index c89d3254..4104617f 100644 --- a/aiida_workgraph/collection.py +++ b/aiida_workgraph/collection.py @@ -38,6 +38,9 @@ def new( # make links between the tasks task.set(links) return task + if isinstance(identifier, str) and identifier.upper() == "WHILE": + task = super().new(identifier, name, uuid, **kwargs) + return task if isinstance(identifier, WorkGraph): identifier = build_task_from_workgraph(identifier) return super().new(identifier, name, uuid, **kwargs) diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index a51b3c60..081eaac1 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -32,6 +32,7 @@ from aiida.engine import run_get_node from aiida_workgraph.utils import create_and_pause_process from aiida_workgraph.task import Task +from aiida_workgraph.utils import get_nested_dict, update_nested_dict if t.TYPE_CHECKING: from aiida.engine.runners import Runner # pylint: disable=unused-import @@ -475,7 +476,8 @@ def get_task(self, name: str): return task def update_task(self, task: Task): - """Update task in the context.""" + """Update task in the context. + This is used in error handlers to update the task parameters.""" self.ctx.tasks[task.name]["properties"] = task.properties_to_dict() self.reset_task(task.name) @@ -517,55 +519,7 @@ def set_task_results(self) -> None: for name, task in self.ctx.tasks.items(): if self.get_task_state_info(name, "action").upper() == "RESET": self.reset_task(task["name"]) - process = self.get_task_state_info(name, "process") - if process: - self.set_task_result(task) - self.set_task_result(task) - - def set_task_result(self, task: t.Dict[str, t.Any]) -> None: - name = task["name"] - # print(f"set task result: {name}") - node = self.get_task_state_info(name, "process") - if isinstance(node, orm.ProcessNode): - # print(f"set task result: {name} process") - state = self.get_task_state_info( - task["name"], "process" - ).process_state.value.upper() - if node.is_finished_ok: - self.set_task_state_info(task["name"], "state", state) - if task["metadata"]["node_type"].upper() == "WORKGRAPH": - # expose the outputs of all the tasks in the workgraph - task["results"] = {} - outgoing = node.base.links.get_outgoing() - for link in outgoing.all(): - if isinstance(link.node, ProcessNode) and getattr( - link.node, "process_state", False - ): - task["results"][link.link_label] = link.node.outputs - else: - task["results"] = node.outputs - # self.ctx.new_data[name] = task["results"] - self.set_task_state_info(task["name"], "state", "FINISHED") - self.task_set_context(name) - self.report(f"Task: {name} finished.") - # all other states are considered as failed - else: - task["results"] = node.outputs - # self.ctx.new_data[name] = task["results"] - self.set_task_state_info(task["name"], "state", "FAILED") - # set child tasks state to SKIPPED - self.set_tasks_state( - self.ctx.connectivity["child_node"][name], "SKIPPED" - ) - self.report(f"Task: {name} failed.") - self.run_error_handlers(name) - elif isinstance(node, orm.Data): - task["results"] = {task["outputs"][0]["name"]: node} - self.set_task_state_info(task["name"], "state", "FINISHED") - self.task_set_context(name) - self.report(f"Task: {name} finished.") - else: - task["results"] = None + self.update_task_state(name) def apply_action(self, msg: dict) -> None: @@ -591,20 +545,24 @@ def apply_task_actions(self, msg: dict) -> None: if action.upper() == "SKIP": pass - def reset_task(self, name: str) -> None: - """Reset task.""" + def reset_task(self, name: str, recursive: bool = True) -> None: + """Reset task state and remove it from the executed task. + If recursive is True, reset its child tasks.""" - self.report(f"Task {name} action: RESET.") self.set_task_state_info(name, "state", "PLANNED") self.set_task_state_info(name, "process", None) self.remove_executed_task(name) - # reset its child tasks - names = self.ctx.connectivity["child_node"][name] - for name in names: - self.set_task_state_info(name, "state", "PLANNED") - self.ctx.tasks[name]["result"] = None - self.set_task_state_info(name, "process", None) - self.remove_executed_task(name) + if recursive: + self.report(f"Task {name} action: RESET.") + # if the task is a while task, reset its child tasks + if self.ctx.tasks[name]["metadata"]["node_type"].upper() == "WHILE": + self.ctx.tasks[name]["execution_count"] = 0 + for child_task in self.ctx.tasks[name]["properties"]["tasks"]["value"]: + self.reset_task(child_task, recursive=False) + # reset its child tasks + names = self.ctx.connectivity["child_node"][name] + for name in names: + self.reset_task(name, recursive=False) def pause_task(self, name: str) -> None: """Pause task.""" @@ -629,7 +587,7 @@ def continue_workgraph(self) -> None: "SKIPPED", ]: continue - ready, output = self.check_parent_state(name) + ready, _ = self.is_task_ready_to_run(name) if ready and self.task_should_run(name): task_to_run.append(name) # @@ -637,20 +595,98 @@ def continue_workgraph(self) -> None: self.run_tasks(task_to_run) def update_task_state(self, name: str) -> None: - """Update task state if task is a Awaitable.""" + """Update task state when the task is finished.""" print("update task state: ", name) task = self.ctx.tasks[name] - if task["metadata"]["node_type"].upper() in [ - "CALCFUNCTION", - "WORKFUNCTION", - "CALCJOB", - "WORKCHAIN", - "GRAPH_BUILDER", - "WORKGRAPH", - "PYTHONJOB", - "SHELLJOB", - ] and self.get_task_state_info(task["name"], "state") in ["CREATED", "RUNNING"]: - self.set_task_result(task) + # print(f"set task result: {name}") + node = self.get_task_state_info(name, "process") + if isinstance(node, orm.ProcessNode): + # print(f"set task result: {name} process") + state = node.process_state.value.upper() + if node.is_finished_ok: + self.set_task_state_info(task["name"], "state", state) + if task["metadata"]["node_type"].upper() == "WORKGRAPH": + # expose the outputs of all the tasks in the workgraph + task["results"] = {} + outgoing = node.base.links.get_outgoing() + for link in outgoing.all(): + if isinstance(link.node, ProcessNode) and getattr( + link.node, "process_state", False + ): + task["results"][link.link_label] = link.node.outputs + else: + task["results"] = node.outputs + # self.ctx.new_data[name] = task["results"] + self.set_task_state_info(task["name"], "state", "FINISHED") + self.task_set_context(name) + self.report(f"Task: {name} finished.") + # all other states are considered as failed + else: + task["results"] = node.outputs + # self.ctx.new_data[name] = task["results"] + self.set_task_state_info(task["name"], "state", "FAILED") + # set child tasks state to SKIPPED + self.set_tasks_state( + self.ctx.connectivity["child_node"][name], "SKIPPED" + ) + self.report(f"Task: {name} failed.") + self.run_error_handlers(name) + elif isinstance(node, orm.Data): + task["results"] = {task["outputs"][0]["name"]: node} + self.set_task_state_info(task["name"], "state", "FINISHED") + self.task_set_context(name) + self.report(f"Task: {name} finished.") + else: + task["results"] = None + + self.update_parent_task_state(name) + + def update_parent_task_state(self, name: str) -> None: + """Update parent task state.""" + parent_task = self.ctx.tasks[name].get("parent_task", None) + if parent_task: + if self.ctx.tasks[parent_task]["metadata"]["node_type"].upper() == "WHILE": + self.update_while_task_state(parent_task) + + def update_while_task_state(self, name: str) -> None: + """Update while task state.""" + finished, _ = self.is_while_task_finished(name) + + if finished: + should_run = self.should_run_while_task(name) + if should_run: + self.ctx.tasks[name]["execution_count"] += 1 + self.reset_task(name) + else: + self.set_task_state_info(name, "state", "FINISHED") + + def should_run_while_task(self, name: str) -> tuple[bool, t.Any]: + """Check if the while task should run.""" + # check the conditions of the while task + task = self.ctx.tasks[name] + not_excess_max_iterations = ( + self.ctx.tasks[name]["execution_count"] + < self.ctx.tasks[name]["properties"]["max_iterations"]["value"] + ) + conditions = [not_excess_max_iterations] + for condition in task["properties"]["conditions"]["value"]: + value = get_nested_dict(self.ctx, condition) + conditions.append(value) + return False not in conditions + + def is_while_task_finished(self, name: str) -> tuple[bool, t.Any]: + """Check if the while task is finished.""" + task = self.ctx.tasks[name] + finished = True + for name in task["properties"]["tasks"]["value"]: + if self.get_task_state_info(name, "state") not in [ + "FINISHED", + "SKIPPED", + "FAILED", + ]: + finished = False + break + return finished, None def run_error_handlers(self, task_name: str) -> None: """Run error handler.""" @@ -838,14 +874,8 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None if task["metadata"]["node_type"].upper() == "NODE": print("task type: node.") results = self.run_executor(executor, [], kwargs, var_args, var_kwargs) - self.set_task_state_info(task["name"], "process", results) - task["results"] = {task["outputs"][0]["name"]: results} - self.ctx.input_tasks[name] = results - self.set_task_state_info(name, "state", "FINISHED") - self.task_set_context(name) - # ValueError: attempted to add an input link after the process node was already stored. - # self.node.base.links.add_incoming(results, "INPUT_WORK", name) - self.report(f"Task: {name} finished.") + self.set_task_state_info(name, "process", results) + self.update_task_state(name) if continue_workgraph: self.continue_workgraph() elif task["metadata"]["node_type"].upper() == "DATA": @@ -853,12 +883,9 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None for key in self.ctx.tasks[name]["metadata"]["args"]: kwargs.pop(key, None) results = create_data_node(executor, args, kwargs) - task["results"] = {task["outputs"][0]["name"]: results} - self.set_task_state_info(task["name"], "process", results) + self.set_task_state_info(name, "process", results) + self.update_task_state(name) self.ctx.new_data[name] = results - self.set_task_state_info(name, "state", "FINISHED") - self.task_set_context(name) - self.report(f"Task: {name} finished.") if continue_workgraph: self.continue_workgraph() elif task["metadata"]["node_type"].upper() in [ @@ -877,19 +904,13 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None executor, **kwargs, **var_kwargs ) process.label = name - # only one output - if isinstance(results, orm.Data): - results = {task["outputs"][0]["name"]: results} - task["results"] = results # print("results: ", results) - self.set_task_state_info(task["name"], "process", process) - self.set_task_state_info(name, "state", "FINISHED") - self.task_set_context(name) - self.report(f"Task: {name} finished.") + self.set_task_state_info(name, "process", process) + self.update_task_state(name) except Exception as e: print(e) self.report(e) - self.set_task_state_info(task["name"], "state", "FAILED") + self.set_task_state_info(name, "state", "FAILED") # set child state to FAILED self.set_tasks_state( self.ctx.connectivity["child_node"][name], "SKIPPED" @@ -920,7 +941,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None process = self.submit(executor, **kwargs) self.set_task_state_info(name, "state", "RUNNING") process.label = name - self.set_task_state_info(task["name"], "process", process) + self.set_task_state_info(name, "process", process) self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]: print("task type: graph_builder.") @@ -931,7 +952,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None wg.save(metadata={"call_link_label": name}) print("submit workgraph: ") process = self.submit(wg.process_inited) - self.set_task_state_info(task["name"], "process", process) + self.set_task_state_info(name, "process", process) self.set_task_state_info(name, "state", "RUNNING") self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["WORKGRAPH"]: @@ -940,7 +961,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None inputs, _ = prepare_for_workgraph_task(task, kwargs) print("submit workgraph: ") process = self.submit(WorkGraphEngine, inputs=inputs) - self.set_task_state_info(task["name"], "process", process) + self.set_task_state_info(name, "process", process) self.set_task_state_info(name, "state", "RUNNING") self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["PYTHONJOB"]: @@ -964,7 +985,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None process = self.submit(PythonJob, **inputs) self.set_task_state_info(name, "state", "RUNNING") process.label = name - self.set_task_state_info(task["name"], "process", process) + self.set_task_state_info(name, "process", process) self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["SHELLJOB"]: from aiida_shell.calculations.shell import ShellJob @@ -986,8 +1007,11 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None process = self.submit(ShellJob, **inputs) self.set_task_state_info(name, "state", "RUNNING") process.label = name - self.set_task_state_info(task["name"], "process", process) + self.set_task_state_info(name, "process", process) self.to_context(**{name: process}) + elif task["metadata"]["node_type"].upper() in ["WHILE"]: + self.set_task_state_info(name, "state", "RUNNING") + self.continue_workgraph() elif task["metadata"]["node_type"].upper() in ["NORMAL"]: print("Task type: Normal.") # normal function does not have a process @@ -1016,6 +1040,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None self.set_task_state_info(name, "state", "FINISHED") self.task_set_context(name) self.report(f"Task: {name} finished.") + self.update_parent_task_state(name) if continue_workgraph: self.continue_workgraph() # print("result from node: ", task["results"]) @@ -1034,7 +1059,6 @@ def get_inputs( t.Dict[str, t.Any], ]: """Get input based on the links.""" - from aiida_workgraph.utils import get_nested_dict args = [] args_dict = {} @@ -1114,7 +1138,6 @@ def get_inputs( def update_context_variable(self, value: t.Any) -> t.Any: # replace context variables - from aiida_workgraph.utils import get_nested_dict """Get value from context.""" if isinstance(value, dict): @@ -1138,70 +1161,79 @@ def task_set_context(self, name: str) -> None: result = self.ctx.tasks[name]["results"][key] update_nested_dict(self.ctx, value, result) - def check_task_state(self, name: str) -> None: - """Check task states. - - - if all input tasks finished, launch task - - if task is a scatter task, check if all scattered tasks finished + def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: + """Check if the task ready to run. + For normal node, we need to check its parent nodes: + - wait tasks + - input tasks + For task inside a while zone, we need to check if the while task is ready. + - while parent task + For while task, we need to check its child tasks, and the conditions. + - all input tasks of the while task are ready + - while conditions """ - # print(f" Check task {name} state: ") - if self.get_task_state_info(name, "state") in ["PLANNED", "WAITING"]: - ready, output = self.check_parent_state(name) - if ready: - # print(f" Task {name} is ready to launch.") - self.ctx.msgs.append(f"task,{name}:action:LAUNCH") # noqa E231 - elif self.get_task_state_info(name, "state") in ["SCATTERED"]: - state, action = self.check_scattered_state(name) - self.ctx.msgs.append(f"task,{name}:state:{state}") # noqa E231 - else: - # print(f" Task {name} is in state {self.ctx.tasks[name]['state']}") - pass - - def check_parent_state(self, name: str) -> t.Tuple[bool, t.Optional[str]]: task = self.ctx.tasks[name] - inputs = task.get("inputs", None) + inputs = task.get("inputs", []) wait_tasks = self.ctx.tasks[name].get("wait", []) - # print(" wait_tasks: ", wait_tasks) - ready = True - if inputs is None and len(wait_tasks) == 0: - return ready - else: - # check the wait task first - for task_name in wait_tasks: - # in case the task is removed - if task_name not in self.ctx.tasks: - continue - if self.get_task_state_info(task_name, "state") not in [ + parent_task = self.ctx.tasks[name].get("parent_task", None) + # wait, inputs, parent_task, child_tasks, conditions + parent_states = [True, True, True, True, True] + # if the task belongs to a while zoone + if parent_task: + state = self.get_task_state_info(parent_task, "state") + if state not in ["RUNNING"]: + parent_states[2] = False + # if the task is a while task + if task["metadata"]["node_type"].upper() == "WHILE": + # check if the all the child tasks are ready + for child_task_name in self.ctx.connectivity["while"][name]["input_tasks"]: + ready, parent_states = self.is_task_ready_to_run(child_task_name) + if not ready: + parent_states[3] = False + break + # check the conditions of the while task + parent_states[4] = self.should_run_while_task(name) + # check the wait task first + for task_name in wait_tasks: + # in case the task is removed + if task_name not in self.ctx.tasks: + continue + if self.get_task_state_info(task_name, "state") not in [ + "FINISHED", + "SKIPPED", + "FAILED", + ]: + parent_states[0] = False + break + for input in inputs: + # print(" input, ", input["from_node"], self.ctx.tasks[input["from_node"]]["state"]) + for link in input["links"]: + if self.get_task_state_info(link["from_node"], "state") not in [ "FINISHED", "SKIPPED", "FAILED", ]: - ready = False - return ready, f"Task {name} wait for {task_name}" - for input in inputs: - # print(" input, ", input["from_node"], self.ctx.tasks[input["from_node"]]["state"]) - for link in input["links"]: - if self.get_task_state_info(link["from_node"], "state") not in [ + parent_states[1] = False + break + # check if the input task belong to a while task, and the while task is ready + parent_task = self.ctx.tasks[link["from_node"]].get("parent_task", None) + # if the task itself does not belong to the while task + if ( + parent_task + and name + not in self.ctx.tasks[parent_task]["properties"]["tasks"]["value"] + ): + state = self.get_task_state_info(parent_task, "state") + if state not in [ "FINISHED", "SKIPPED", "FAILED", ]: - ready = False - return ( - ready, - f"{name}, input: {link['from_node']} is {self.ctx.tasks[link['from_node']]['state']}", - ) - return ready, None - - # def expose_graph_build_outputs(self, name): - # # print("expose_graph_build_outputs") - # outputs = {} - # process = self.ctx.tasks[name]["process"] - # outgoing = process.base.links.get_outgoing() - # for output in self.ctx.tasks[name]["group_outputs"]: - # node = outgoing.get_node_by_label(output[0]) - # outputs[output[2]] = getattr(node.outputs, output[1]) - # return outputs + parent_states[1] = False + break + # print("is task ready to run: ", name, all(parent_states), parent_states) + return all(parent_states), parent_states + def reset(self) -> None: print("Reset") self.ctx._execution_count += 1 @@ -1296,8 +1328,6 @@ def message_receive( def finalize(self) -> t.Optional[ExitCode]: """""" - from aiida_workgraph.utils import get_nested_dict, update_nested_dict - # expose outputs of the workgraph group_outputs = {} print("workgraph outputs: ", self.ctx.workgraph["metadata"]["group_outputs"]) diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py index 5d93c696..e4f3ee7c 100644 --- a/aiida_workgraph/task.py +++ b/aiida_workgraph/task.py @@ -29,7 +29,7 @@ class Task(GraphNode): def __init__( self, context_mapping: Optional[List[Any]] = None, - wait: List[Union[str, GraphNode]] = [], + wait: List[Union[str, GraphNode]] = None, process: Optional[aiida.orm.ProcessNode] = None, pk: Optional[int] = None, **kwargs: Any, diff --git a/aiida_workgraph/tasks/__init__.py b/aiida_workgraph/tasks/__init__.py index 9dca0ab0..ef890f2a 100644 --- a/aiida_workgraph/tasks/__init__.py +++ b/aiida_workgraph/tasks/__init__.py @@ -1,5 +1,5 @@ from node_graph.utils import get_entries -from .builtin import AiiDAGather, AiiDAToCtx, AiiDAFromCtx +from .builtin import AiiDAGather, AiiDAToCtx, AiiDAFromCtx, While from .test import ( AiiDAInt, AiiDAFloat, @@ -20,6 +20,7 @@ ) task_list = [ + While, AiiDAGather, AiiDAToCtx, AiiDAFromCtx, diff --git a/aiida_workgraph/tasks/builtin.py b/aiida_workgraph/tasks/builtin.py index f3c37caf..b4a29d44 100644 --- a/aiida_workgraph/tasks/builtin.py +++ b/aiida_workgraph/tasks/builtin.py @@ -2,6 +2,26 @@ from aiida_workgraph.task import Task +class While(Task): + """While""" + + identifier = "While" + name = "While" + node_type = "WHILE" + catalog = "Control" + kwargs = ["max_iterations", "conditions", "tasks"] + + def create_sockets(self) -> None: + self.inputs.clear() + self.outputs.clear() + inp = self.inputs.new("Any", "_wait") + inp.link_limit = 100000 + self.inputs.new("Int", "max_iterations") + self.inputs.new("Any", "tasks") + self.inputs.new("Any", "conditions") + self.outputs.new("Any", "_wait") + + class AiiDAGather(Task): """AiiDAGather""" diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index a8501459..3ed54bde 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -27,6 +27,8 @@ def get_executor(data: Dict[str, Any]) -> Union[Process, Any]: executor = CalculationFactory(data["name"]) elif type == "DataFactory": executor = DataFactory(data["name"]) + elif data["name"] == "" and data["path"] == "": + executor = None else: module = importlib.import_module("{}".format(data.get("path", ""))) executor = getattr(module, data["name"]) @@ -52,16 +54,20 @@ def create_data_node(executor: orm.Data, args: list, kwargs: dict) -> orm.Node: return data_node -def get_nested_dict(d: Dict, name: str, allow_none: bool = False) -> Any: - """ +def get_nested_dict(d: Dict, name: str, **kwargs) -> Any: + """Get the value from a nested dictionary. + If default is provided, return the default value if the key is not found. + Otherwise, raise ValueError. + For example: + d = {"base": {"pw": {"parameters": 2}}} name = "base.pw.parameters" """ keys = name.split(".") current = d for key in keys: if key not in current: - if allow_none: - return None + if "default" in kwargs: + return kwargs.get("default") else: raise ValueError(f"{name} not exist in {d}") current = current[key] @@ -628,10 +634,14 @@ def workgraph_to_short_json( ] wgdata_short["nodes"][name] = { "label": task["name"], + "node_type": task["metadata"]["node_type"], "inputs": inputs, "outputs": [], "position": task["position"], } + # Add properties to nodes if it is a While task + if task["metadata"]["node_type"].upper() == "WHILE": + wgdata_short["nodes"][name]["properties"] = task["properties"] # Add links to nodes for link in wgdata["links"]: wgdata_short["nodes"][link["to_node"]]["inputs"].append( diff --git a/aiida_workgraph/utils/analysis.py b/aiida_workgraph/utils/analysis.py index 8d5c023d..b1c14521 100644 --- a/aiida_workgraph/utils/analysis.py +++ b/aiida_workgraph/utils/analysis.py @@ -60,6 +60,7 @@ def save(self) -> None: """ self.build_task_link() self.build_connectivity() + self.assign_while_zone() if self.exist_in_db() or self.restart_process is not None: new_tasks, modified_tasks, update_metadata = self.check_diff( self.restart_process @@ -94,6 +95,28 @@ def build_task_link(self) -> None: to_socket["links"].append(link) from_socket["links"].append(link) + def assign_while_zone(self) -> None: + """Assign while zone for each task.""" + self.wgdata["connectivity"]["while"] = {} + # assign parent_task for each task + for name, task in self.wgdata["tasks"].items(): + if task["metadata"]["node_type"].upper() == "WHILE": + input_tasks = [] + for name in task["properties"]["tasks"]["value"]: + self.wgdata["tasks"][name]["parent_task"] = task["name"] + # find all the input tasks which outside the while zone + for input in self.wgdata["tasks"][name]["inputs"]: + for link in input["links"]: + if ( + link["from_node"] + not in task["properties"]["tasks"]["value"] + ): + input_tasks.append(link["from_node"]) + task["execution_count"] = 0 + self.wgdata["connectivity"]["while"][task["name"]] = { + "input_tasks": input_tasks + } + def insert_workgraph_to_db(self) -> None: """Save a new workgraph in the database. diff --git a/aiida_workgraph/web/frontend/package-lock.json b/aiida_workgraph/web/frontend/package-lock.json index 8e02210e..ec8aae70 100644 --- a/aiida_workgraph/web/frontend/package-lock.json +++ b/aiida_workgraph/web/frontend/package-lock.json @@ -21,7 +21,7 @@ "@types/three": "^0.156.0", "antd": "^5.12.1", "d3": "^7.8.5", - "elkjs": "^0.8.2", + "elkjs": "^0.9.2", "fs": "^0.0.1-security", "mathjs": "^12.3.0", "moment": "^2.29.4", @@ -34,18 +34,18 @@ "react-scripts": "5.0.1", "react-syntax-highlighter": "^15.5.0", "react-toastify": "^9.1.3", - "rete": "^2.0.2", - "rete-area-3d-plugin": "^2.0.3", - "rete-area-plugin": "^2.0.1", - "rete-auto-arrange-plugin": "^2.0.0", - "rete-connection-plugin": "^2.0.0", - "rete-connection-reroute-plugin": "^2.0.0", - "rete-context-menu-plugin": "^2.0.0", - "rete-minimap-plugin": "^2.0.1", - "rete-react-plugin": "^2.0.4", - "rete-readonly-plugin": "^2.0.0", - "rete-render-utils": "^2.0.1", - "styled-components": "^5.3.11", + "rete": "2.0.3", + "rete-area-plugin": "2.0.3", + "rete-auto-arrange-plugin": "2.0.1", + "rete-connection-plugin": "2.0.1", + "rete-connection-reroute-plugin": "2.0.0", + "rete-context-menu-plugin": "2.0.0", + "rete-minimap-plugin": "2.0.1", + "rete-react-plugin": "2.0.5", + "rete-readonly-plugin": "2.0.0", + "rete-render-utils": "2.0.2", + "rete-scopes-plugin": "2.1.0", + "styled-components": "6.1.8", "three": "^0.156.1", "typescript": "^4.9.5", "vis-timeline": "^7.7.3", @@ -2437,11 +2437,6 @@ "resolved": "https://registry.npmjs.org/@emotion/memoize/-/memoize-0.8.1.tgz", "integrity": "sha512-W2P2c/VRW1/1tLox0mVUalvnWXxavmv/Oum2aPsRcoDJuob75FC3Y8FbpfLwUegRcxINtGUMPq0tFCvYNTBXNA==" }, - "node_modules/@emotion/stylis": { - "version": "0.8.5", - "resolved": "https://registry.npmjs.org/@emotion/stylis/-/stylis-0.8.5.tgz", - "integrity": "sha512-h6KtPihKFn3T9fuIrwvXXUOwlx3rfUvfZIcP5a6rh8Y7zjE3O06hT5Ss4S/YI1AYhuZ1kjaE/5EaOOI2NqSylQ==" - }, "node_modules/@emotion/unitless": { "version": "0.7.5", "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.7.5.tgz", @@ -4586,6 +4581,11 @@ "csstype": "^3.0.2" } }, + "node_modules/@types/stylis": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@types/stylis/-/stylis-4.2.0.tgz", + "integrity": "sha512-n4sx2bqL0mW1tvDf/loQ+aMX7GQD3lc3fkCMC55VFNDu/vBOabO+LTIeXKM14xK0ppk5TUGcWRjiSpIlUpghKw==" + }, "node_modules/@types/testing-library__jest-dom": { "version": "5.14.9", "resolved": "https://registry.npmjs.org/@types/testing-library__jest-dom/-/testing-library__jest-dom-5.14.9.tgz", @@ -5782,21 +5782,6 @@ "@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0" } }, - "node_modules/babel-plugin-styled-components": { - "version": "2.1.4", - "resolved": "https://registry.npmjs.org/babel-plugin-styled-components/-/babel-plugin-styled-components-2.1.4.tgz", - "integrity": "sha512-Xgp9g+A/cG47sUyRwwYxGM4bR/jDRg5N6it/8+HxCnbT5XNKSKDT9xm4oag/osgqjC2It/vH0yXsomOG6k558g==", - "dependencies": { - "@babel/helper-annotate-as-pure": "^7.22.5", - "@babel/helper-module-imports": "^7.22.5", - "@babel/plugin-syntax-jsx": "^7.22.5", - "lodash": "^4.17.21", - "picomatch": "^2.3.1" - }, - "peerDependencies": { - "styled-components": ">= 2" - } - }, "node_modules/babel-plugin-transform-react-remove-prop-types": { "version": "0.4.24", "resolved": "https://registry.npmjs.org/babel-plugin-transform-react-remove-prop-types/-/babel-plugin-transform-react-remove-prop-types-0.4.24.tgz", @@ -7856,9 +7841,9 @@ } }, "node_modules/elkjs": { - "version": "0.8.2", - "resolved": "https://registry.npmjs.org/elkjs/-/elkjs-0.8.2.tgz", - "integrity": "sha512-L6uRgvZTH+4OF5NE/MBbzQx/WYpru1xCBE9respNj6qznEewGUIfhzmm7horWWxbNO2M0WckQypGctR8lH79xQ==" + "version": "0.9.3", + "resolved": "https://registry.npmjs.org/elkjs/-/elkjs-0.9.3.tgz", + "integrity": "sha512-f/ZeWvW/BCXbhGEf1Ujp29EASo/lk1FDnETgNKwJrsVvGZhUWCZyg3xLJjAsxfOmt8KjswHmI5EwCQcPMpOYhQ==" }, "node_modules/emittery": { "version": "0.8.1", @@ -9800,6 +9785,7 @@ "version": "3.3.2", "resolved": "https://registry.npmjs.org/hoist-non-react-statics/-/hoist-non-react-statics-3.3.2.tgz", "integrity": "sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw==", + "dev": true, "dependencies": { "react-is": "^16.7.0" } @@ -9807,7 +9793,8 @@ "node_modules/hoist-non-react-statics/node_modules/react-is": { "version": "16.13.1", "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", - "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "dev": true }, "node_modules/hoopy": { "version": "0.1.4", @@ -14171,9 +14158,9 @@ } }, "node_modules/postcss": { - "version": "8.4.32", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.32.tgz", - "integrity": "sha512-D/kj5JNu6oo2EIy+XL/26JEDTlIbB8hw85G8StOE6L74RQAVVP5rej6wxCNqyMbR4RkPfqvezVbPw81Ngd6Kcw==", + "version": "8.4.38", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.38.tgz", + "integrity": "sha512-Wglpdk03BSfXkHoQa3b/oulrotAkwrlLDRSOb9D0bN86FdRyE9lppSp33aHNPgBa0JKCoB+drFLZkQoRRYae5A==", "funding": [ { "type": "opencollective", @@ -14191,7 +14178,7 @@ "dependencies": { "nanoid": "^3.3.7", "picocolors": "^1.0.0", - "source-map-js": "^1.0.2" + "source-map-js": "^1.2.0" }, "engines": { "node": "^10 || ^12 || >=14" @@ -16895,32 +16882,18 @@ } }, "node_modules/rete": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/rete/-/rete-2.0.2.tgz", - "integrity": "sha512-VZhyWl0E3dzcRRiN5OIVK4CIVcADZHO4XCFs85fjoi4ZYCPcB3P608wzq5MbdlYOfptPyuvKOrRqgFCtILdKIw==", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/rete/-/rete-2.0.3.tgz", + "integrity": "sha512-/xzcyEBhVXhMZVZHElnYaLKOmTEuwlnul9Wfjvxw5sdl/+6Nqn2nyqIaW4koefrFpIWZy9aitnjnP3zeCMVDuw==", "hasInstallScript": true, "dependencies": { "@babel/runtime": "^7.21.0" } }, - "node_modules/rete-area-3d-plugin": { - "version": "2.0.3", - "resolved": "https://registry.npmjs.org/rete-area-3d-plugin/-/rete-area-3d-plugin-2.0.3.tgz", - "integrity": "sha512-1Pc6jfmtghj0pUKz3TVD3wVOYS8nnqJvaZJ9zGqPYufTsAqdV6idOq2K1AWUayoeqleTCymY2YZ4Nj/0TmCTDA==", - "dependencies": { - "@babel/runtime": "^7.21.0" - }, - "peerDependencies": { - "rete": "^2.0.1", - "rete-area-plugin": "^2.0.0", - "rete-render-utils": "^2.0.0", - "three": ">= 0.152.2 < 0.157.0" - } - }, "node_modules/rete-area-plugin": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/rete-area-plugin/-/rete-area-plugin-2.0.1.tgz", - "integrity": "sha512-H7IGv2Tfm1Tk928Hl6O9pS3JCmKboCFY4xqGm2TCXvzVAlHvmUV7mtkxlT1fCAZQMNin6ktwthCdQ455euHDgQ==", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/rete-area-plugin/-/rete-area-plugin-2.0.3.tgz", + "integrity": "sha512-RWHCoMh0HJ7arnBEaU51j1J4AuK+qWYR7tadVom8uiwkXK7Xv9VeJkKy8xWT+Ckw+g2DlsI4SWrQkMG7wfWtug==", "dependencies": { "@babel/runtime": "^7.21.0" }, @@ -16929,9 +16902,9 @@ } }, "node_modules/rete-auto-arrange-plugin": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/rete-auto-arrange-plugin/-/rete-auto-arrange-plugin-2.0.0.tgz", - "integrity": "sha512-wyrJ+DW94J1E4ceTX6XVHt7lnp0eNXdwHKnuREY8ePl3BO8buLutvVpuMFwJEVW6AbnUVcKY38/huAwp4wDGoQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/rete-auto-arrange-plugin/-/rete-auto-arrange-plugin-2.0.1.tgz", + "integrity": "sha512-vHxsrI+l3wxZzxPxG7hcgUbacXQfEc1ZEE28r08O1kEy0kUyNkJR5OeCiSizZ4VucsDmu21WUtFVa1rl5h+e1A==", "dependencies": { "@babel/runtime": "^7.21.0" }, @@ -16943,9 +16916,9 @@ } }, "node_modules/rete-connection-plugin": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/rete-connection-plugin/-/rete-connection-plugin-2.0.0.tgz", - "integrity": "sha512-8M+UC6gcWwTi0PEICprmCaoGxwEA4x42z0ywx3O5NQSILvkhWcvQXzHcyvwGx/LTsJT/UPzNjCfixu/TbRyTEw==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/rete-connection-plugin/-/rete-connection-plugin-2.0.1.tgz", + "integrity": "sha512-KE1IcjeOQtHgkByODtWS5hgRJDGhR3Z9sZyJAEd7YMgI6o+KUIflcNjbkvhJvPeIAv6WlEAh7ZkwdLhF9bkr4w==", "dependencies": { "@babel/runtime": "^7.21.0" }, @@ -16992,9 +16965,9 @@ } }, "node_modules/rete-react-plugin": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/rete-react-plugin/-/rete-react-plugin-2.0.4.tgz", - "integrity": "sha512-t+rsaZ6wUFVbO1krfzUl8GInHI+V9Zp8q2fLj8NnDLVKRxQzq1W0sfpBvBK262dkt2TEZvHyeX9M+eqTlaowHQ==", + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/rete-react-plugin/-/rete-react-plugin-2.0.5.tgz", + "integrity": "sha512-xoui2+Mv6iqpRTxccAu3MZv3+l5LYk4AmtqGWEqlCIwZjplrsAoVeOLYq235spwf+vd3ujzapnycEzYF9aj3cA==", "dependencies": { "@babel/runtime": "^7.21.0", "usehooks-ts": "^2.9.1" @@ -17022,9 +16995,9 @@ } }, "node_modules/rete-render-utils": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/rete-render-utils/-/rete-render-utils-2.0.1.tgz", - "integrity": "sha512-mzNVADCE1iV0AlkVyz1Pai34GG55VYBIWWOv9MqHUl7jlnpNIIkx+hARIc3wgabcye46IdswQPUApuARhvjbmA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/rete-render-utils/-/rete-render-utils-2.0.2.tgz", + "integrity": "sha512-f4kj+dFL5QrebOkjCdwi8htHteDFbKyqrVdFDToEUvGuGod1sdLeKxOPBOhwyYDB4Zxd3Cq84I93vD2etrTL9g==", "dependencies": { "@babel/runtime": "^7.21.0" }, @@ -17033,6 +17006,18 @@ "rete-area-plugin": "^2.0.0" } }, + "node_modules/rete-scopes-plugin": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/rete-scopes-plugin/-/rete-scopes-plugin-2.1.0.tgz", + "integrity": "sha512-qpbTvpPlKb52vPX56XA41RjK4JARWE07k3RnjvM4sNYfrdEszX8m0ZaTAeWGaslEuVdlY/rAOl6NxCbiU6E0sg==", + "dependencies": { + "@babel/runtime": "^7.21.0" + }, + "peerDependencies": { + "rete": "^2.0.1", + "rete-area-plugin": "^2.0.0" + } + }, "node_modules/retry": { "version": "0.13.1", "resolved": "https://registry.npmjs.org/retry/-/retry-0.13.1.tgz", @@ -17614,9 +17599,9 @@ } }, "node_modules/source-map-js": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz", - "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.0.tgz", + "integrity": "sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==", "engines": { "node": ">=0.10.0" } @@ -18043,23 +18028,22 @@ } }, "node_modules/styled-components": { - "version": "5.3.11", - "resolved": "https://registry.npmjs.org/styled-components/-/styled-components-5.3.11.tgz", - "integrity": "sha512-uuzIIfnVkagcVHv9nE0VPlHPSCmXIUGKfJ42LNjxCCTDTL5sgnJ8Z7GZBq0EnLYGln77tPpEpExt2+qa+cZqSw==", - "dependencies": { - "@babel/helper-module-imports": "^7.0.0", - "@babel/traverse": "^7.4.5", - "@emotion/is-prop-valid": "^1.1.0", - "@emotion/stylis": "^0.8.4", - "@emotion/unitless": "^0.7.4", - "babel-plugin-styled-components": ">= 1.12.0", - "css-to-react-native": "^3.0.0", - "hoist-non-react-statics": "^3.0.0", - "shallowequal": "^1.1.0", - "supports-color": "^5.5.0" + "version": "6.1.8", + "resolved": "https://registry.npmjs.org/styled-components/-/styled-components-6.1.8.tgz", + "integrity": "sha512-PQ6Dn+QxlWyEGCKDS71NGsXoVLKfE1c3vApkvDYS5KAK+V8fNWGhbSUEo9Gg2iaID2tjLXegEW3bZDUGpofRWw==", + "dependencies": { + "@emotion/is-prop-valid": "1.2.1", + "@emotion/unitless": "0.8.0", + "@types/stylis": "4.2.0", + "css-to-react-native": "3.2.0", + "csstype": "3.1.2", + "postcss": "8.4.31", + "shallowequal": "1.1.0", + "stylis": "4.3.1", + "tslib": "2.5.0" }, "engines": { - "node": ">=10" + "node": ">= 16" }, "funding": { "type": "opencollective", @@ -18067,10 +18051,56 @@ }, "peerDependencies": { "react": ">= 16.8.0", - "react-dom": ">= 16.8.0", - "react-is": ">= 16.8.0" + "react-dom": ">= 16.8.0" + } + }, + "node_modules/styled-components/node_modules/@emotion/unitless": { + "version": "0.8.0", + "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.8.0.tgz", + "integrity": "sha512-VINS5vEYAscRl2ZUDiT3uMPlrFQupiKgHz5AA4bCH1miKBg4qtwkim1qPmJj/4WG6TreYMY111rEFsjupcOKHw==" + }, + "node_modules/styled-components/node_modules/csstype": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.2.tgz", + "integrity": "sha512-I7K1Uu0MBPzaFKg4nI5Q7Vs2t+3gWWW648spaF+Rg7pI9ds18Ugn+lvg4SHczUdKlHI5LWBXyqfS8+DufyBsgQ==" + }, + "node_modules/styled-components/node_modules/postcss": { + "version": "8.4.31", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.31.tgz", + "integrity": "sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "dependencies": { + "nanoid": "^3.3.6", + "picocolors": "^1.0.0", + "source-map-js": "^1.0.2" + }, + "engines": { + "node": "^10 || ^12 || >=14" } }, + "node_modules/styled-components/node_modules/stylis": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.1.tgz", + "integrity": "sha512-EQepAV+wMsIaGVGX1RECzgrcqRRU/0sYOHkeLsZ3fzHaHXZy4DaOOX0vOlGQdlsjkh3mFHAIlVimpwAs4dslyQ==" + }, + "node_modules/styled-components/node_modules/tslib": { + "version": "2.5.0", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.5.0.tgz", + "integrity": "sha512-336iVw3rtn2BUK7ORdIAHTyxHGRIHVReokCR3XjbckJMK7ms8FysBfhLR8IXnAgy7T0PTPNBWKiH514FOW/WSg==" + }, "node_modules/stylehacks": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/stylehacks/-/stylehacks-5.1.1.tgz", @@ -18087,9 +18117,9 @@ } }, "node_modules/stylis": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.0.tgz", - "integrity": "sha512-E87pIogpwUsUwXw7dNyU4QDjdgVMy52m+XEOPEKUn161cCzWjjhPSQhByfd1CcNvrOLnXQ6OnnZDwnJrz/Z4YQ==" + "version": "4.3.2", + "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.2.tgz", + "integrity": "sha512-bhtUjWd/z6ltJiQwg0dUfxEJ+W+jdqQd8TbWLWyeIJHlnsqmGLRFFd8e5mA0AZi/zx90smXRlN66YMTcaSFifg==" }, "node_modules/sucrase": { "version": "3.34.0", @@ -21603,11 +21633,6 @@ "resolved": "https://registry.npmjs.org/@emotion/memoize/-/memoize-0.8.1.tgz", "integrity": "sha512-W2P2c/VRW1/1tLox0mVUalvnWXxavmv/Oum2aPsRcoDJuob75FC3Y8FbpfLwUegRcxINtGUMPq0tFCvYNTBXNA==" }, - "@emotion/stylis": { - "version": "0.8.5", - "resolved": "https://registry.npmjs.org/@emotion/stylis/-/stylis-0.8.5.tgz", - "integrity": "sha512-h6KtPihKFn3T9fuIrwvXXUOwlx3rfUvfZIcP5a6rh8Y7zjE3O06hT5Ss4S/YI1AYhuZ1kjaE/5EaOOI2NqSylQ==" - }, "@emotion/unitless": { "version": "0.7.5", "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.7.5.tgz", @@ -23212,6 +23237,11 @@ "csstype": "^3.0.2" } }, + "@types/stylis": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@types/stylis/-/stylis-4.2.0.tgz", + "integrity": "sha512-n4sx2bqL0mW1tvDf/loQ+aMX7GQD3lc3fkCMC55VFNDu/vBOabO+LTIeXKM14xK0ppk5TUGcWRjiSpIlUpghKw==" + }, "@types/testing-library__jest-dom": { "version": "5.14.9", "resolved": "https://registry.npmjs.org/@types/testing-library__jest-dom/-/testing-library__jest-dom-5.14.9.tgz", @@ -24094,18 +24124,6 @@ "@babel/helper-define-polyfill-provider": "^0.4.3" } }, - "babel-plugin-styled-components": { - "version": "2.1.4", - "resolved": "https://registry.npmjs.org/babel-plugin-styled-components/-/babel-plugin-styled-components-2.1.4.tgz", - "integrity": "sha512-Xgp9g+A/cG47sUyRwwYxGM4bR/jDRg5N6it/8+HxCnbT5XNKSKDT9xm4oag/osgqjC2It/vH0yXsomOG6k558g==", - "requires": { - "@babel/helper-annotate-as-pure": "^7.22.5", - "@babel/helper-module-imports": "^7.22.5", - "@babel/plugin-syntax-jsx": "^7.22.5", - "lodash": "^4.17.21", - "picomatch": "^2.3.1" - } - }, "babel-plugin-transform-react-remove-prop-types": { "version": "0.4.24", "resolved": "https://registry.npmjs.org/babel-plugin-transform-react-remove-prop-types/-/babel-plugin-transform-react-remove-prop-types-0.4.24.tgz", @@ -25605,9 +25623,9 @@ } }, "elkjs": { - "version": "0.8.2", - "resolved": "https://registry.npmjs.org/elkjs/-/elkjs-0.8.2.tgz", - "integrity": "sha512-L6uRgvZTH+4OF5NE/MBbzQx/WYpru1xCBE9respNj6qznEewGUIfhzmm7horWWxbNO2M0WckQypGctR8lH79xQ==" + "version": "0.9.3", + "resolved": "https://registry.npmjs.org/elkjs/-/elkjs-0.9.3.tgz", + "integrity": "sha512-f/ZeWvW/BCXbhGEf1Ujp29EASo/lk1FDnETgNKwJrsVvGZhUWCZyg3xLJjAsxfOmt8KjswHmI5EwCQcPMpOYhQ==" }, "emittery": { "version": "0.8.1", @@ -27017,6 +27035,7 @@ "version": "3.3.2", "resolved": "https://registry.npmjs.org/hoist-non-react-statics/-/hoist-non-react-statics-3.3.2.tgz", "integrity": "sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw==", + "dev": true, "requires": { "react-is": "^16.7.0" }, @@ -27024,7 +27043,8 @@ "react-is": { "version": "16.13.1", "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", - "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==" + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "dev": true } } }, @@ -30177,13 +30197,13 @@ } }, "postcss": { - "version": "8.4.32", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.32.tgz", - "integrity": "sha512-D/kj5JNu6oo2EIy+XL/26JEDTlIbB8hw85G8StOE6L74RQAVVP5rej6wxCNqyMbR4RkPfqvezVbPw81Ngd6Kcw==", + "version": "8.4.38", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.38.tgz", + "integrity": "sha512-Wglpdk03BSfXkHoQa3b/oulrotAkwrlLDRSOb9D0bN86FdRyE9lppSp33aHNPgBa0JKCoB+drFLZkQoRRYae5A==", "requires": { "nanoid": "^3.3.7", "picocolors": "^1.0.0", - "source-map-js": "^1.0.2" + "source-map-js": "^1.2.0" } }, "postcss-attribute-case-insensitive": { @@ -31942,41 +31962,33 @@ "integrity": "sha512-/NtpHNDN7jWhAaQ9BvBUYZ6YTXsRBgfqWFWP7BZBaoMJO/I3G5OFzvTuWNlZC3aPjins1F+TNrLKsGbH4rfsRQ==" }, "rete": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/rete/-/rete-2.0.2.tgz", - "integrity": "sha512-VZhyWl0E3dzcRRiN5OIVK4CIVcADZHO4XCFs85fjoi4ZYCPcB3P608wzq5MbdlYOfptPyuvKOrRqgFCtILdKIw==", - "requires": { - "@babel/runtime": "^7.21.0" - } - }, - "rete-area-3d-plugin": { "version": "2.0.3", - "resolved": "https://registry.npmjs.org/rete-area-3d-plugin/-/rete-area-3d-plugin-2.0.3.tgz", - "integrity": "sha512-1Pc6jfmtghj0pUKz3TVD3wVOYS8nnqJvaZJ9zGqPYufTsAqdV6idOq2K1AWUayoeqleTCymY2YZ4Nj/0TmCTDA==", + "resolved": "https://registry.npmjs.org/rete/-/rete-2.0.3.tgz", + "integrity": "sha512-/xzcyEBhVXhMZVZHElnYaLKOmTEuwlnul9Wfjvxw5sdl/+6Nqn2nyqIaW4koefrFpIWZy9aitnjnP3zeCMVDuw==", "requires": { "@babel/runtime": "^7.21.0" } }, "rete-area-plugin": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/rete-area-plugin/-/rete-area-plugin-2.0.1.tgz", - "integrity": "sha512-H7IGv2Tfm1Tk928Hl6O9pS3JCmKboCFY4xqGm2TCXvzVAlHvmUV7mtkxlT1fCAZQMNin6ktwthCdQ455euHDgQ==", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/rete-area-plugin/-/rete-area-plugin-2.0.3.tgz", + "integrity": "sha512-RWHCoMh0HJ7arnBEaU51j1J4AuK+qWYR7tadVom8uiwkXK7Xv9VeJkKy8xWT+Ckw+g2DlsI4SWrQkMG7wfWtug==", "requires": { "@babel/runtime": "^7.21.0" } }, "rete-auto-arrange-plugin": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/rete-auto-arrange-plugin/-/rete-auto-arrange-plugin-2.0.0.tgz", - "integrity": "sha512-wyrJ+DW94J1E4ceTX6XVHt7lnp0eNXdwHKnuREY8ePl3BO8buLutvVpuMFwJEVW6AbnUVcKY38/huAwp4wDGoQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/rete-auto-arrange-plugin/-/rete-auto-arrange-plugin-2.0.1.tgz", + "integrity": "sha512-vHxsrI+l3wxZzxPxG7hcgUbacXQfEc1ZEE28r08O1kEy0kUyNkJR5OeCiSizZ4VucsDmu21WUtFVa1rl5h+e1A==", "requires": { "@babel/runtime": "^7.21.0" } }, "rete-connection-plugin": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/rete-connection-plugin/-/rete-connection-plugin-2.0.0.tgz", - "integrity": "sha512-8M+UC6gcWwTi0PEICprmCaoGxwEA4x42z0ywx3O5NQSILvkhWcvQXzHcyvwGx/LTsJT/UPzNjCfixu/TbRyTEw==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/rete-connection-plugin/-/rete-connection-plugin-2.0.1.tgz", + "integrity": "sha512-KE1IcjeOQtHgkByODtWS5hgRJDGhR3Z9sZyJAEd7YMgI6o+KUIflcNjbkvhJvPeIAv6WlEAh7ZkwdLhF9bkr4w==", "requires": { "@babel/runtime": "^7.21.0" } @@ -32006,9 +32018,9 @@ } }, "rete-react-plugin": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/rete-react-plugin/-/rete-react-plugin-2.0.4.tgz", - "integrity": "sha512-t+rsaZ6wUFVbO1krfzUl8GInHI+V9Zp8q2fLj8NnDLVKRxQzq1W0sfpBvBK262dkt2TEZvHyeX9M+eqTlaowHQ==", + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/rete-react-plugin/-/rete-react-plugin-2.0.5.tgz", + "integrity": "sha512-xoui2+Mv6iqpRTxccAu3MZv3+l5LYk4AmtqGWEqlCIwZjplrsAoVeOLYq235spwf+vd3ujzapnycEzYF9aj3cA==", "requires": { "@babel/runtime": "^7.21.0", "usehooks-ts": "^2.9.1" @@ -32023,9 +32035,17 @@ } }, "rete-render-utils": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/rete-render-utils/-/rete-render-utils-2.0.1.tgz", - "integrity": "sha512-mzNVADCE1iV0AlkVyz1Pai34GG55VYBIWWOv9MqHUl7jlnpNIIkx+hARIc3wgabcye46IdswQPUApuARhvjbmA==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/rete-render-utils/-/rete-render-utils-2.0.2.tgz", + "integrity": "sha512-f4kj+dFL5QrebOkjCdwi8htHteDFbKyqrVdFDToEUvGuGod1sdLeKxOPBOhwyYDB4Zxd3Cq84I93vD2etrTL9g==", + "requires": { + "@babel/runtime": "^7.21.0" + } + }, + "rete-scopes-plugin": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/rete-scopes-plugin/-/rete-scopes-plugin-2.1.0.tgz", + "integrity": "sha512-qpbTvpPlKb52vPX56XA41RjK4JARWE07k3RnjvM4sNYfrdEszX8m0ZaTAeWGaslEuVdlY/rAOl6NxCbiU6E0sg==", "requires": { "@babel/runtime": "^7.21.0" } @@ -32460,9 +32480,9 @@ "integrity": "sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA==" }, "source-map-js": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz", - "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==" + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.0.tgz", + "integrity": "sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==" }, "source-map-loader": { "version": "3.0.2", @@ -32779,20 +32799,51 @@ "requires": {} }, "styled-components": { - "version": "5.3.11", - "resolved": "https://registry.npmjs.org/styled-components/-/styled-components-5.3.11.tgz", - "integrity": "sha512-uuzIIfnVkagcVHv9nE0VPlHPSCmXIUGKfJ42LNjxCCTDTL5sgnJ8Z7GZBq0EnLYGln77tPpEpExt2+qa+cZqSw==", - "requires": { - "@babel/helper-module-imports": "^7.0.0", - "@babel/traverse": "^7.4.5", - "@emotion/is-prop-valid": "^1.1.0", - "@emotion/stylis": "^0.8.4", - "@emotion/unitless": "^0.7.4", - "babel-plugin-styled-components": ">= 1.12.0", - "css-to-react-native": "^3.0.0", - "hoist-non-react-statics": "^3.0.0", - "shallowequal": "^1.1.0", - "supports-color": "^5.5.0" + "version": "6.1.8", + "resolved": "https://registry.npmjs.org/styled-components/-/styled-components-6.1.8.tgz", + "integrity": "sha512-PQ6Dn+QxlWyEGCKDS71NGsXoVLKfE1c3vApkvDYS5KAK+V8fNWGhbSUEo9Gg2iaID2tjLXegEW3bZDUGpofRWw==", + "requires": { + "@emotion/is-prop-valid": "1.2.1", + "@emotion/unitless": "0.8.0", + "@types/stylis": "4.2.0", + "css-to-react-native": "3.2.0", + "csstype": "3.1.2", + "postcss": "8.4.31", + "shallowequal": "1.1.0", + "stylis": "4.3.1", + "tslib": "2.5.0" + }, + "dependencies": { + "@emotion/unitless": { + "version": "0.8.0", + "resolved": "https://registry.npmjs.org/@emotion/unitless/-/unitless-0.8.0.tgz", + "integrity": "sha512-VINS5vEYAscRl2ZUDiT3uMPlrFQupiKgHz5AA4bCH1miKBg4qtwkim1qPmJj/4WG6TreYMY111rEFsjupcOKHw==" + }, + "csstype": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.2.tgz", + "integrity": "sha512-I7K1Uu0MBPzaFKg4nI5Q7Vs2t+3gWWW648spaF+Rg7pI9ds18Ugn+lvg4SHczUdKlHI5LWBXyqfS8+DufyBsgQ==" + }, + "postcss": { + "version": "8.4.31", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.31.tgz", + "integrity": "sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==", + "requires": { + "nanoid": "^3.3.6", + "picocolors": "^1.0.0", + "source-map-js": "^1.0.2" + } + }, + "stylis": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.1.tgz", + "integrity": "sha512-EQepAV+wMsIaGVGX1RECzgrcqRRU/0sYOHkeLsZ3fzHaHXZy4DaOOX0vOlGQdlsjkh3mFHAIlVimpwAs4dslyQ==" + }, + "tslib": { + "version": "2.5.0", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.5.0.tgz", + "integrity": "sha512-336iVw3rtn2BUK7ORdIAHTyxHGRIHVReokCR3XjbckJMK7ms8FysBfhLR8IXnAgy7T0PTPNBWKiH514FOW/WSg==" + } } }, "stylehacks": { @@ -32805,9 +32856,9 @@ } }, "stylis": { - "version": "4.3.0", - "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.0.tgz", - "integrity": "sha512-E87pIogpwUsUwXw7dNyU4QDjdgVMy52m+XEOPEKUn161cCzWjjhPSQhByfd1CcNvrOLnXQ6OnnZDwnJrz/Z4YQ==" + "version": "4.3.2", + "resolved": "https://registry.npmjs.org/stylis/-/stylis-4.3.2.tgz", + "integrity": "sha512-bhtUjWd/z6ltJiQwg0dUfxEJ+W+jdqQd8TbWLWyeIJHlnsqmGLRFFd8e5mA0AZi/zx90smXRlN66YMTcaSFifg==" }, "sucrase": { "version": "3.34.0", diff --git a/aiida_workgraph/web/frontend/package.json b/aiida_workgraph/web/frontend/package.json index 6607055e..31ebcffd 100644 --- a/aiida_workgraph/web/frontend/package.json +++ b/aiida_workgraph/web/frontend/package.json @@ -29,18 +29,18 @@ "react-scripts": "5.0.1", "react-syntax-highlighter": "^15.5.0", "react-toastify": "^9.1.3", - "rete": "^2.0.2", - "rete-area-3d-plugin": "^2.0.3", - "rete-area-plugin": "^2.0.1", - "rete-auto-arrange-plugin": "^2.0.0", - "rete-connection-plugin": "^2.0.0", - "rete-connection-reroute-plugin": "^2.0.0", - "rete-context-menu-plugin": "^2.0.0", - "rete-minimap-plugin": "^2.0.1", - "rete-react-plugin": "^2.0.4", - "rete-readonly-plugin": "^2.0.0", - "rete-render-utils": "^2.0.1", - "styled-components": "^5.3.11", + "rete": "2.0.3", + "rete-area-plugin": "2.0.3", + "rete-auto-arrange-plugin": "2.0.1", + "rete-scopes-plugin": "2.1.0", + "rete-connection-plugin": "2.0.1", + "rete-connection-reroute-plugin": "2.0.0", + "rete-context-menu-plugin": "2.0.0", + "rete-minimap-plugin": "2.0.1", + "rete-react-plugin": "2.0.5", + "rete-readonly-plugin": "2.0.0", + "rete-render-utils": "2.0.2", + "styled-components": "6.1.8", "three": "^0.156.1", "typescript": "^4.9.5", "vis-timeline": "^7.7.3", diff --git a/aiida_workgraph/web/frontend/src/App.js b/aiida_workgraph/web/frontend/src/App.js index afd69471..54170bee 100644 --- a/aiida_workgraph/web/frontend/src/App.js +++ b/aiida_workgraph/web/frontend/src/App.js @@ -1,4 +1,3 @@ -import React from 'react'; import { BrowserRouter as Router, Routes, Route } from 'react-router-dom'; import Home from './components/Home'; import WorkGraphTable from './components/WorkGraphTable'; diff --git a/aiida_workgraph/web/frontend/src/rete/default.ts b/aiida_workgraph/web/frontend/src/rete/default.ts index c86d3c13..412e9649 100644 --- a/aiida_workgraph/web/frontend/src/rete/default.ts +++ b/aiida_workgraph/web/frontend/src/rete/default.ts @@ -5,6 +5,7 @@ import { Presets as ConnectionPresets } from "rete-connection-plugin"; import { ReactPlugin, Presets, ReactArea2D } from "rete-react-plugin"; +import { ScopesPlugin, Presets as ScopesPresets } from "rete-scopes-plugin"; import { MinimapExtra, MinimapPlugin } from "rete-minimap-plugin"; import { ContextMenuPlugin, @@ -47,10 +48,52 @@ interface NodeMap { } +export async function loadJSON(editor: NodeEditor, area: any, workgraphData: any) { + + // Adding nodes based on workgraphData + const nodeMap: NodeMap = {}; // To keep track of created nodes for linking + for (const nodeId in workgraphData.nodes) { + const nodeData = workgraphData.nodes[nodeId]; + const node = createDynamicNode(nodeData); + await editor.addNode(node); + nodeMap[nodeId] = node; // Storing reference to the node + } + // Adding connections based on workgraphData + workgraphData.links.forEach(async (link: LinkData) => { // Specify the type of link here + const fromNode = nodeMap[link.from_node]; + const toNode = nodeMap[link.to_node]; + if (fromNode && toNode) { + await editor.addConnection(new Connection(fromNode, link.from_socket, toNode, link.to_socket)); + } + }); + + // Add while zones + console.log("Adding while zone: "); + for (const nodeId in workgraphData.nodes) { + const nodeData = workgraphData.nodes[nodeId]; + // if node_type is "WHILE", find all + console.log("Node type: ", nodeData['node_type']); + if (nodeData['node_type'] === "WHILE") { + // find the node + const node = nodeMap[nodeData.label]; + const tasks = nodeData['properties']['tasks']['value']; + // find the id of all nodes in the editor that has a label in while_zone + for (const nodeId in tasks) { + const node1 = nodeMap[tasks[nodeId]]; + console.log("Setting parent of node", node1, "to", node); + node1.parent = node.id; + area.update('node', node1.id); + } + area.update('node', node.id); + } + } +} + class Node extends ClassicPreset.Node { width = 180; height = 100; + parent?: string; } class Connection extends ClassicPreset.Connection {} @@ -90,6 +133,7 @@ export async function createEditor(container: HTMLElement, workgraphData: any) { const area = new AreaPlugin(container); const connection = new ConnectionPlugin(); const render = new ReactPlugin(); + const scopes = new ScopesPlugin(); const arrange = new AutoArrangePlugin(); const contextMenu = new ContextMenuPlugin({ items: ContextMenuPresets.classic.setup([ @@ -108,6 +152,7 @@ export async function createEditor(container: HTMLElement, workgraphData: any) { render.addPreset(Presets.minimap.setup({ size: 200 })); connection.addPreset(ConnectionPresets.classic.setup()); + scopes.addPreset(ScopesPresets.classic.setup()); const applier = new ArrangeAppliers.TransitionApplier({ duration: 500, @@ -122,28 +167,14 @@ export async function createEditor(container: HTMLElement, workgraphData: any) { editor.use(area); // area.use(connection); area.use(render); + area.use(scopes); area.use(arrange); area.use(contextMenu); area.use(minimap); AreaExtensions.simpleNodesOrder(area); - // Adding nodes based on workgraphData - const nodeMap: NodeMap = {}; // To keep track of created nodes for linking - for (const nodeId in workgraphData.nodes) { - const nodeData = workgraphData.nodes[nodeId]; - const node = createDynamicNode(nodeData); - await editor.addNode(node); - nodeMap[nodeId] = node; // Storing reference to the node - } - // Adding connections based on workgraphData - workgraphData.links.forEach(async (link: LinkData) => { // Specify the type of link here - const fromNode = nodeMap[link.from_node]; - const toNode = nodeMap[link.to_node]; - if (fromNode && toNode) { - await editor.addConnection(new Connection(fromNode, link.from_socket, toNode, link.to_socket)); - } - }); + await loadJSON(editor, area, workgraphData); async function layout(animate: boolean) { await arrange.layout({ applier: animate ? applier : undefined }); diff --git a/aiida_workgraph/widget/js/default_rete.ts b/aiida_workgraph/widget/js/default_rete.ts index 261efc66..04fe9aac 100644 --- a/aiida_workgraph/widget/js/default_rete.ts +++ b/aiida_workgraph/widget/js/default_rete.ts @@ -5,6 +5,7 @@ import { Presets as ConnectionPresets } from "rete-connection-plugin"; import { ReactPlugin, Presets, ReactArea2D } from "rete-react-plugin"; +import { ScopesPlugin, Presets as ScopesPresets } from "rete-scopes-plugin"; import { MinimapExtra, MinimapPlugin } from "rete-minimap-plugin"; import { ContextMenuPlugin, @@ -90,6 +91,23 @@ export async function loadJSON(editor, area, layout, workgraphData) { workgraphData.links.forEach(async (link: LinkData) => { // Specify the type of link here await addLink(editor, area, layout, link); }); + + // Add while zones + console.log("Adding while zone: "); + for (const nodeId in workgraphData.nodes) { + const nodeData = workgraphData.nodes[nodeId]; + // if node_type is "WHILE", find all + if (nodeData['node_type'] === "WHILE") { + // find the node + const node = editor.nodeMap[nodeData.label]; + const tasks = nodeData['properties']['tasks']['value']; + // find the id of all nodes in the editor that has a label in while_zone + for (const nodeId in tasks) { + const node1 = editor.nodeMap[tasks[nodeId]]; + node1.parent = node.id; + } + } + } } export async function addNode(editor, area, nodeData) { @@ -161,6 +179,7 @@ export async function createEditor(container: HTMLElement, settings: any) { const area = new AreaPlugin(container); const connection = new ConnectionPlugin(); const render = new ReactPlugin(); + const scopes = new ScopesPlugin(); const arrange = new AutoArrangePlugin(); const contextMenu = new ContextMenuPlugin({ items: ContextMenuPresets.classic.setup([ @@ -170,15 +189,17 @@ export async function createEditor(container: HTMLElement, settings: any) { boundViewport: true }); - AreaExtensions.selectableNodes(area, AreaExtensions.selector(), { - accumulating: AreaExtensions.accumulateOnCtrl() - }); + const selector = AreaExtensions.selector(); + const accumulating = AreaExtensions.accumulateOnCtrl(); + + AreaExtensions.selectableNodes(area, selector, { accumulating }); render.addPreset(Presets.classic.setup()); render.addPreset(Presets.contextMenu.setup()); render.addPreset(Presets.minimap.setup({ size: 200 })); connection.addPreset(ConnectionPresets.classic.setup()); + scopes.addPreset(ScopesPresets.classic.setup()); const applier = new ArrangeAppliers.TransitionApplier({ duration: 500, @@ -193,6 +214,7 @@ export async function createEditor(container: HTMLElement, settings: any) { editor.use(area); // area.use(connection); area.use(render); + area.use(scopes); area.use(arrange); area.use(contextMenu); if (settings.minimap) { diff --git a/aiida_workgraph/widget/package-lock.json b/aiida_workgraph/widget/package-lock.json index 282b32d8..95ff827b 100644 --- a/aiida_workgraph/widget/package-lock.json +++ b/aiida_workgraph/widget/package-lock.json @@ -19,7 +19,8 @@ "rete-minimap-plugin": "^2.0.1", "rete-react-plugin": "^2.0.4", "rete-readonly-plugin": "^2.0.0", - "rete-render-utils": "^2.0.1" + "rete-render-utils": "^2.0.1", + "rete-scopes-plugin": "2.1.0" }, "devDependencies": { "@types/react": "^18.2.61", @@ -2162,6 +2163,18 @@ "rete-area-plugin": "^2.0.0" } }, + "node_modules/rete-scopes-plugin": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/rete-scopes-plugin/-/rete-scopes-plugin-2.1.0.tgz", + "integrity": "sha512-qpbTvpPlKb52vPX56XA41RjK4JARWE07k3RnjvM4sNYfrdEszX8m0ZaTAeWGaslEuVdlY/rAOl6NxCbiU6E0sg==", + "dependencies": { + "@babel/runtime": "^7.21.0" + }, + "peerDependencies": { + "rete": "^2.0.1", + "rete-area-plugin": "^2.0.0" + } + }, "node_modules/scheduler": { "version": "0.23.0", "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.23.0.tgz", @@ -3928,6 +3941,14 @@ "@babel/runtime": "^7.21.0" } }, + "rete-scopes-plugin": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/rete-scopes-plugin/-/rete-scopes-plugin-2.1.0.tgz", + "integrity": "sha512-qpbTvpPlKb52vPX56XA41RjK4JARWE07k3RnjvM4sNYfrdEszX8m0ZaTAeWGaslEuVdlY/rAOl6NxCbiU6E0sg==", + "requires": { + "@babel/runtime": "^7.21.0" + } + }, "scheduler": { "version": "0.23.0", "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.23.0.tgz", diff --git a/aiida_workgraph/widget/package.json b/aiida_workgraph/widget/package.json index e51a6aae..18e9a5c0 100644 --- a/aiida_workgraph/widget/package.json +++ b/aiida_workgraph/widget/package.json @@ -13,6 +13,7 @@ "rete-area-3d-plugin": "^2.0.3", "rete-area-plugin": "^2.0.1", "rete-auto-arrange-plugin": "^2.0.0", + "rete-scopes-plugin": "2.1.0", "rete-connection-plugin": "^2.0.0", "rete-connection-reroute-plugin": "^2.0.0", "rete-context-menu-plugin": "^2.0.0", diff --git a/aiida_workgraph/widget/src/widget/html_template.py b/aiida_workgraph/widget/src/widget/html_template.py index 28524101..baa8e62f 100644 --- a/aiida_workgraph/widget/src/widget/html_template.py +++ b/aiida_workgraph/widget/src/widget/html_template.py @@ -21,6 +21,7 @@ + + + +
+ + + diff --git a/docs/source/howto/html/test_while_workgraph.html b/docs/source/howto/html/test_while_workgraph.html new file mode 100644 index 00000000..0080b8b2 --- /dev/null +++ b/docs/source/howto/html/test_while_workgraph.html @@ -0,0 +1,281 @@ + + + + + + + Rete.js with React in Vanilla JS + + + + + + + + + + + + + + + + + + + + + +
+ + + diff --git a/docs/source/howto/while.ipynb b/docs/source/howto/while.ipynb index b388ec66..de58f4c4 100644 --- a/docs/source/howto/while.ipynb +++ b/docs/source/howto/while.ipynb @@ -16,96 +16,813 @@ "## Introduction\n", "With the while loop we can execute a set of tasks as long as the condition is true. In this tutorial, you will learn how to use `while` loop in WorkGraph.\n", "\n", - "Load the AiiDA profile." + "There are two ways to implement a `while` loop in WorkGraph:\n", + "\n", + "- Using `While` Task\n", + "- Using `While` WorkGraph\n", + "\n", + "## Using `While` Task\n", + "\n", + "The `While` task allows running tasks many times in a loop. One can add a `While` task to the WorkGraph and specify the conditions and tasks to run in the loop.\n", + "\n", + "```python\n", + "wg.add_task(\"While\",\n", + " max_iterations=100,\n", + " conditions=[\"should_run\"],\n", + " tasks=[\"add2\", \"multiply1\", \"compare1\"])\n", + "```\n", + "\n", + "Parameters:\n", + "\n", + "- **max_iterations**\n", + "- **conditions**: a list of variables from the `context`. When any of them is `False`, the while loop exit.\n", + "- **tasks**: a list of names of the the tasks inside the while loop.\n", + "\n", + "\n", + "### Example\n", + "\n", + "Suppose we want to calculate the following workflow, the tasks for each step are shown:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "8f5e7642", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result: uuid: 516fe11f-8060-4773-925e-b3f5ba22f290 (pk: 97696) value: 63\n" + ] + } + ], + "source": [ + "from aiida.engine import calcfunction\n", + "from aiida import load_profile\n", + "\n", + "load_profile()\n", + "\n", + "@calcfunction\n", + "def compare(x, y):\n", + " return x < y\n", + "\n", + "@calcfunction\n", + "def add(x, y):\n", + " return x + y\n", + "\n", + "@calcfunction\n", + "def multiply(x, y):\n", + " return x*y\n", + "\n", + "#-------------------------------------------------------------------\n", + "# start while block\n", + "n = add(1, 1) # task add1\n", + "# start while loop\n", + "while compare(n, 50): # task compare\n", + " n = add(n, 1) # task add2\n", + " n = multiply(n, 2) # task multiply1\n", + "# end while block\n", + "z = add(n, 1) # task add3\n", + "#-------------------------------------------------------------------\n", + "\n", + "print(\"Result: \", z)" + ] + }, + { + "cell_type": "markdown", + "id": "65f4c44d", + "metadata": {}, + "source": [ + "### Create the workflow\n", + "Now, let'use create the workflow using the `While` task." ] }, { "cell_type": "code", - "execution_count": 1, - "id": "c6b83fb5", + "execution_count": 27, + "id": "8ee799d2-0b5b-4609-957f-6b3f2cd451f0", "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "\n", + " \n", + " " + ], "text/plain": [ - "Profile" + "" ] }, - "execution_count": 1, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "%load_ext aiida\n", - "from aiida import load_profile\n", - "load_profile()" + "from aiida_workgraph import task, WorkGraph\n", + "\n", + "wg = WorkGraph(\"test_while_task\")\n", + "# set a context variable before running.\n", + "wg.context = {\"should_run\": True}\n", + "add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n", + "add1.set_context({\"result\": \"n\"})\n", + "#---------------------------------------------------------------------\n", + "# Create the tasks in the while loop.\n", + "add2 = wg.add_task(add, name=\"add2\", x=\"{{n}}\", y=1)\n", + "add2.wait.append(\"add1\")\n", + "multiply1 = wg.add_task(multiply, name=\"multiply1\",\n", + " x=add2.outputs[\"result\"],\n", + " y=2)\n", + "# update the context variable\n", + "multiply1.set_context({\"result\": \"n\"})\n", + "compare1 = wg.add_task(compare, name=\"compare1\", x=multiply1.outputs[\"result\"], y=50)\n", + "# Save the `result` of compare1 task as context.should_run, and used as condition\n", + "compare1.set_context({\"result\": \"should_run\"})\n", + "# Create the while tasks\n", + "while1 = wg.add_task(\"While\", max_iterations=100,\n", + " conditions=[\"should_run\"],\n", + " tasks=[\"add2\", \"multiply1\", \"compare1\"])\n", + "#---------------------------------------------------------------------\n", + "add3 = wg.add_task(add, name=\"add3\", x=1, y=1)\n", + "wg.add_link(multiply1.outputs[\"result\"], add3.inputs[\"x\"])\n", + "wg.to_html()\n", + "# comment out the following line to visualize the graph in the notebook\n", + "# wg" ] }, { "cell_type": "markdown", - "id": "30719f9a", + "id": "6bccc088", "metadata": {}, "source": [ - "## First workflow: while\n", - "Suppose we want to calculate:\n", - "```python\n", - "# start while block\n", - "n=1\n", - "while n < 100:\n", - " n = n*2\n", - " n = n + 3\n", - "# end while block\n", - "z = n+1\n", - "```" + "In the GUI, **While** task is shown as a **While Zone** with all its child tasks inside the Zone. \n", + "The while zone does not have data input and output sockets. Tasks outside the while zone can link to the tasks inside the zone directly." ] }, { "cell_type": "markdown", - "id": "0f46d277", + "id": "d25beb02-ee82-4a27-ae48-edc5c147904c", "metadata": {}, "source": [ - "### Create task\n", - "We first create the tasks to do the calculation." + "### Submit the WorkGraph and check the results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9ebf35aa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WorkGraph process created, PK: 97508\n", + "State of WorkGraph: FINISHED\n", + "Result of add1 : uuid: db4ff8aa-39bc-4f94-ba1b-04fabf9a93a5 (pk: 97551) value: 63\n" + ] + } + ], + "source": [ + "wg.submit(wait=True)\n", + "print(\"State of WorkGraph: {}\".format(wg.state))\n", + "print('Result of add1 : {}'.format(add3.outputs[\"result\"].value))" + ] + }, + { + "cell_type": "markdown", + "id": "125ac629", + "metadata": {}, + "source": [ + "Generate node graph from the AiiDA process,and we can see that when `compare1` node outputs `False`, the workgraph stops." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0060e380", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "N97508\n", + "\n", + "WorkGraph<test_while> (97508)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97511\n", + "\n", + "add (97511)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97511\n", + "\n", + "\n", + "CALL_CALC\n", + "add1\n", + "\n", + "\n", + "\n", + "N97514\n", + "\n", + "add (97514)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97514\n", + "\n", + "\n", + "CALL_CALC\n", + "add2\n", + "\n", + "\n", + "\n", + "N97517\n", + "\n", + "multiply (97517)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97517\n", + "\n", + "\n", + "CALL_CALC\n", + "multiply1\n", + "\n", + "\n", + "\n", + "N97520\n", + "\n", + "compare (97520)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97520\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", + "\n", + "\n", + "\n", + "N97523\n", + "\n", + "add (97523)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97523\n", + "\n", + "\n", + "CALL_CALC\n", + "add2\n", + "\n", + "\n", + "\n", + "N97526\n", + "\n", + "multiply (97526)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97526\n", + "\n", + "\n", + "CALL_CALC\n", + "multiply1\n", + "\n", + "\n", + "\n", + "N97529\n", + "\n", + "compare (97529)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97529\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", + "\n", + "\n", + "\n", + "N97532\n", + "\n", + "add (97532)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97532\n", + "\n", + "\n", + "CALL_CALC\n", + "add2\n", + "\n", + "\n", + "\n", + "N97535\n", + "\n", + "multiply (97535)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97535\n", + "\n", + "\n", + "CALL_CALC\n", + "multiply1\n", + "\n", + "\n", + "\n", + "N97538\n", + "\n", + "compare (97538)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97538\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", + "\n", + "\n", + "\n", + "N97541\n", + "\n", + "add (97541)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97541\n", + "\n", + "\n", + "CALL_CALC\n", + "add2\n", + "\n", + "\n", + "\n", + "N97544\n", + "\n", + "multiply (97544)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97544\n", + "\n", + "\n", + "CALL_CALC\n", + "multiply1\n", + "\n", + "\n", + "\n", + "N97547\n", + "\n", + "compare (97547)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97547\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", + "\n", + "\n", + "\n", + "N97550\n", + "\n", + "add (97550)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97508->N97550\n", + "\n", + "\n", + "CALL_CALC\n", + "add3\n", + "\n", + "\n", + "\n", + "N97552\n", + "\n", + "Int (97552)\n", + "value: 0\n", + "\n", + "\n", + "\n", + "N97508->N97552\n", + "\n", + "\n", + "RETURN\n", + "execution_count\n", + "\n", + "\n", + "\n", + "N97512\n", + "\n", + "Int (97512)\n", + "value: 2\n", + "\n", + "\n", + "\n", + "N97511->N97512\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97512->N97514\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97515\n", + "\n", + "Int (97515)\n", + "value: 3\n", + "\n", + "\n", + "\n", + "N97514->N97515\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97515->N97517\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97518\n", + "\n", + "Int (97518)\n", + "value: 6\n", + "\n", + "\n", + "\n", + "N97517->N97518\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97518->N97520\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97518->N97523\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97521\n", + "\n", + "Bool (97521)\n", + "True\n", + "\n", + "\n", + "\n", + "N97520->N97521\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97524\n", + "\n", + "Int (97524)\n", + "value: 7\n", + "\n", + "\n", + "\n", + "N97523->N97524\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97524->N97526\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97527\n", + "\n", + "Int (97527)\n", + "value: 14\n", + "\n", + "\n", + "\n", + "N97526->N97527\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97527->N97529\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97527->N97532\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97530\n", + "\n", + "Bool (97530)\n", + "True\n", + "\n", + "\n", + "\n", + "N97529->N97530\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97533\n", + "\n", + "Int (97533)\n", + "value: 15\n", + "\n", + "\n", + "\n", + "N97532->N97533\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97533->N97535\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97536\n", + "\n", + "Int (97536)\n", + "value: 30\n", + "\n", + "\n", + "\n", + "N97535->N97536\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97536->N97538\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97536->N97541\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97539\n", + "\n", + "Bool (97539)\n", + "True\n", + "\n", + "\n", + "\n", + "N97538->N97539\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97542\n", + "\n", + "Int (97542)\n", + "value: 31\n", + "\n", + "\n", + "\n", + "N97541->N97542\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97542->N97544\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97545\n", + "\n", + "Int (97545)\n", + "value: 62\n", + "\n", + "\n", + "\n", + "N97544->N97545\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97545->N97547\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97545->N97550\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97548\n", + "\n", + "Bool (97548)\n", + "False\n", + "\n", + "\n", + "\n", + "N97547->N97548\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97551\n", + "\n", + "Int (97551)\n", + "value: 63\n", + "\n", + "\n", + "\n", + "N97550->N97551\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from aiida_workgraph.utils import generate_node_graph\n", + "generate_node_graph(wg.pk)" ] }, { "cell_type": "markdown", - "id": "9e6360d8", + "id": "ff7e78ce", "metadata": {}, - "source": [] + "source": [ + "## `While` WorkGraph\n", + "Create a WorkGraph will repeat itself based on the conditions." + ] }, { "cell_type": "code", - "execution_count": 2, - "id": "11e3bca1-dda6-44e9-9585-54feeda7e7db", + "execution_count": 29, + "id": "4d0598d3", "metadata": {}, "outputs": [], "source": [ - "from aiida_workgraph import task, WorkGraph\n", - "from aiida.engine import calcfunction\n", - "from aiida.orm import Int\n", - "\n", - "# we need a compare task for `n<100`,\n", - "# it's a normal function instead of a calcfunction\n", - "@task()\n", - "def compare(x, y):\n", - " return x < y\n", - "\n", - "# define multiply task for n*2\n", - "@calcfunction\n", - "def multiply(x, y):\n", - " return x*y\n", - "\n", - "# define add task for n+3\n", - "@calcfunction\n", - "def add(x, y):\n", - " return x + y\n", - "\n", - "# Create a WorkGraph will repeat itself based on the conditions\n", - "# then we output the result of from the context (context)\n", + "# Output the result of from the context (context)\n", "@task.graph_builder(outputs = [{\"name\": \"result\", \"from\": \"context.n\"}])\n", - "def add_multiply_while(n, limit):\n", + "def add_multiply_while(n, limit=50):\n", " wg = WorkGraph()\n", " # tell the engine that this is a `while` workgraph\n", " wg.workgraph_type = \"WHILE\"\n", @@ -113,19 +830,18 @@ " wg.conditions = [\"compare1.result\"]\n", " # set a context variable before running.\n", " wg.context = {\"n\": n}\n", - " wg.add_task(compare, name=\"compare1\", x=\"{{n}}\", y=Int(limit))\n", - " multiply1 = wg.add_task(multiply, name=\"multiply1\", x=\"{{ n }}\", y=Int(2))\n", - " add1 = wg.add_task(add, name=\"add1\", y=3)\n", + " wg.add_task(compare, name=\"compare1\", x=\"{{n}}\", y=limit)\n", + " add1 = wg.add_task(add, name=\"add1\", x=\"{{ n }}\", y=1)\n", + " multiply1 = wg.add_task(multiply, name=\"multiply1\", x=add1.outputs[\"result\"],\n", + " y=2)\n", " # update the context variable\n", - " add1.set_context({\"result\": \"n\"})\n", - " wg.add_link(multiply1.outputs[\"result\"], add1.inputs[\"x\"])\n", - " # don't forget to return the workgraph\n", + " multiply1.set_context({\"result\": \"n\"})\n", " return wg" ] }, { "cell_type": "markdown", - "id": "65f4c44d", + "id": "75a9f021", "metadata": {}, "source": [ "### Create the workflow\n", @@ -134,57 +850,77 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "8ee799d2-0b5b-4609-957f-6b3f2cd451f0", + "execution_count": 30, + "id": "336e5ab7", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "WorkGraph process created, PK: 4132\n" - ] + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "wg = WorkGraph(\"test_while\")\n", - "while1 = wg.add_task(add_multiply_while, n=Int(1), limit=50)\n", - "add1 = wg.add_task(add, y=Int(1))\n", - "wg.add_link(while1.outputs[\"result\"], add1.inputs[\"x\"])\n", - "wg.submit(wait=True)" + "wg = WorkGraph(\"test_while_workgraph\")\n", + "add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n", + "while1 = wg.add_task(add_multiply_while, n=add1.outputs[\"result\"],\n", + " limit=50)\n", + "add2 = wg.add_task(add, name=\"add2\", y=1)\n", + "wg.add_link(while1.outputs[\"result\"], add2.inputs[\"x\"])\n", + "wg.to_html()\n", + "# wg" ] }, { "cell_type": "markdown", - "id": "d25beb02-ee82-4a27-ae48-edc5c147904c", + "id": "a247c1b7", "metadata": {}, "source": [ - "### Check status and results\n" + "### Submit the WorkGraph and check the results" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "9ebf35aa", + "execution_count": 31, + "id": "17f6d666", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "WorkGraph process created, PK: 97697\n", "State of WorkGraph: FINISHED\n", - "Result of add1 : uuid: 49ff90cd-af97-4b80-8e25-f05c621bbd2b (pk: 4162) value: 62\n" + "Result of add2 : uuid: eeec1b01-aa2c-423f-8295-3d3f0d3e4245 (pk: 97747) value: 63\n" ] } ], "source": [ + "wg.submit(wait=True)\n", "print(\"State of WorkGraph: {}\".format(wg.state))\n", - "print('Result of add1 : {}'.format(add1.outputs[\"result\"].value))" + "print('Result of add2 : {}'.format(wg.tasks[\"add2\"].outputs[\"result\"].value))" ] }, { "cell_type": "markdown", - "id": "125ac629", + "id": "f9349800", "metadata": {}, "source": [ "Generate node graph from the AiiDA process,and we can see that when `compare1` node outputs `False`, the workgraph stops." @@ -192,8 +928,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "0060e380", + "execution_count": 32, + "id": "bc6297de", "metadata": {}, "outputs": [ { @@ -205,453 +941,649 @@ "\n", "\n", - "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "N9523\n", - "\n", - "WorkGraph: test_while (9523)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97697\n", + "\n", + "WorkGraph<test_while_workgraph> (97697)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526\n", - "\n", - "WorkGraph: add_multiply_while1 (9526)\n", - "State: finished\n", - "Exit Code: 0\n", + "\n", + "\n", + "N97700\n", + "\n", + "add (97700)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9523->N9526\n", - "\n", - "\n", - "CALL_WORK\n", - "add_multiply_while1\n", + "\n", + "\n", + "N97697->N97700\n", + "\n", + "\n", + "CALL_CALC\n", + "add1\n", "\n", - "\n", - "\n", - "N9552\n", - "\n", - "add (9552)\n", - "State: finished\n", - "Exit Code: 0\n", + "\n", + "\n", + "N97702\n", + "\n", + "WorkGraph<add_multiply_while2> (97702)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9523->N9552\n", - "\n", - "\n", - "CALL_CALC\n", - "add2\n", + "\n", + "\n", + "N97697->N97702\n", + "\n", + "\n", + "CALL_WORK\n", + "add_multiply_while2\n", "\n", - "\n", - "\n", - "N9554\n", - "\n", - "Int (9554)\n", - "value: 0\n", + "\n", + "\n", + "N97746\n", + "\n", + "add (97746)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9523->N9554\n", - "\n", - "\n", - "RETURN\n", - "execution_count\n", + "\n", + "\n", + "N97697->N97746\n", + "\n", + "\n", + "CALL_CALC\n", + "add2\n", "\n", - "\n", - "\n", - "N9521\n", - "\n", - "Int (9521)\n", - "value: 1\n", + "\n", + "\n", + "N97748\n", + "\n", + "Int (97748)\n", + "value: 0\n", "\n", - "\n", - "\n", - "N9521->N9523\n", - "\n", - "\n", - "INPUT_WORK\n", - "wg__nodes__add_multiply_while1__properties__n__value\n", + "\n", + "\n", + "N97697->N97748\n", + "\n", + "\n", + "RETURN\n", + "execution_count\n", "\n", - "\n", + "\n", "\n", - "N9522\n", - "\n", - "Int (9522)\n", - "value: 1\n", + "N97701\n", + "\n", + "Int (97701)\n", + "value: 2\n", "\n", - "\n", - "\n", - "N9522->N9523\n", - "\n", - "\n", - "INPUT_WORK\n", - "wg__nodes__add2__properties__y__value\n", + "\n", + "\n", + "N97700->N97701\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97701->N97702\n", + "\n", + "\n", + "INPUT_WORK\n", + "wg__context__n\n", "\n", - "\n", + "\n", "\n", - "N9527\n", - "\n", - "multiply (9527)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97705\n", + "\n", + "compare (97705)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526->N9527\n", - "\n", - "\n", - "CALL_CALC\n", - "multiply1\n", + "\n", + "\n", + "N97702->N97705\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", "\n", - "\n", + "\n", "\n", - "N9530\n", - "\n", - "add (9530)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97709\n", + "\n", + "add (97709)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526->N9530\n", - "\n", - "\n", - "CALL_CALC\n", - "add1\n", + "\n", + "\n", + "N97702->N97709\n", + "\n", + "\n", + "CALL_CALC\n", + "add1\n", "\n", - "\n", + "\n", "\n", - "N9533\n", - "\n", - "multiply (9533)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97712\n", + "\n", + "multiply (97712)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526->N9533\n", - "\n", - "\n", - "CALL_CALC\n", - "multiply1\n", + "\n", + "\n", + "N97702->N97712\n", + "\n", + "\n", + "CALL_CALC\n", + "multiply1\n", "\n", - "\n", + "\n", "\n", - "N9536\n", - "\n", - "add (9536)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97715\n", + "\n", + "compare (97715)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526->N9536\n", - "\n", - "\n", - "CALL_CALC\n", - "add1\n", + "\n", + "\n", + "N97702->N97715\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", "\n", - "\n", + "\n", "\n", - "N9539\n", - "\n", - "multiply (9539)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97718\n", + "\n", + "add (97718)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526->N9539\n", - "\n", - "\n", - "CALL_CALC\n", - "multiply1\n", + "\n", + "\n", + "N97702->N97718\n", + "\n", + "\n", + "CALL_CALC\n", + "add1\n", "\n", - "\n", + "\n", "\n", - "N9542\n", - "\n", - "add (9542)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97721\n", + "\n", + "multiply (97721)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526->N9542\n", - "\n", - "\n", - "CALL_CALC\n", - "add1\n", + "\n", + "\n", + "N97702->N97721\n", + "\n", + "\n", + "CALL_CALC\n", + "multiply1\n", "\n", - "\n", + "\n", "\n", - "N9545\n", - "\n", - "multiply (9545)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97724\n", + "\n", + "compare (97724)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526->N9545\n", - "\n", - "\n", - "CALL_CALC\n", - "multiply1\n", + "\n", + "\n", + "N97702->N97724\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", "\n", - "\n", + "\n", "\n", - "N9548\n", - "\n", - "add (9548)\n", - "State: finished\n", - "Exit Code: 0\n", + "N97727\n", + "\n", + "add (97727)\n", + "State: finished\n", + "Exit Code: 0\n", "\n", - "\n", - "\n", - "N9526->N9548\n", - "\n", - "\n", - "CALL_CALC\n", - "add1\n", + "\n", + "\n", + "N97702->N97727\n", + "\n", + "\n", + "CALL_CALC\n", + "add1\n", "\n", - "\n", - "\n", - "N9549\n", - "\n", - "Int (9549)\n", - "value: 61\n", + "\n", + "\n", + "N97730\n", + "\n", + "multiply (97730)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97702->N97730\n", + "\n", + "\n", + "CALL_CALC\n", + "multiply1\n", "\n", - "\n", + "\n", + "\n", + "N97733\n", + "\n", + "compare (97733)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97702->N97733\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", + "\n", + "\n", + "\n", + "N97736\n", + "\n", + "add (97736)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97702->N97736\n", + "\n", + "\n", + "CALL_CALC\n", + "add1\n", + "\n", + "\n", + "\n", + "N97739\n", + "\n", + "multiply (97739)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", + "\n", + "N97702->N97739\n", + "\n", + "\n", + "CALL_CALC\n", + "multiply1\n", + "\n", + "\n", + "\n", + "N97740\n", + "\n", + "Int (97740)\n", + "value: 62\n", + "\n", + "\n", + "\n", + "N97702->N97740\n", + "\n", + "\n", + "RETURN\n", + "result\n", + "\n", + "\n", + "\n", + "N97742\n", + "\n", + "compare (97742)\n", + "State: finished\n", + "Exit Code: 0\n", + "\n", + "\n", "\n", - "N9526->N9549\n", - "\n", - "\n", - "RETURN\n", - "group_outputs__result\n", + "N97702->N97742\n", + "\n", + "\n", + "CALL_CALC\n", + "compare1\n", "\n", - "\n", - "\n", - "N9550\n", - "\n", - "Int (9550)\n", - "value: 4\n", + "\n", + "\n", + "N97744\n", + "\n", + "Int (97744)\n", + "value: 4\n", "\n", - "\n", - "\n", - "N9526->N9550\n", - "\n", - "\n", - "RETURN\n", - "execution_count\n", + "\n", + "\n", + "N97702->N97744\n", + "\n", + "\n", + "RETURN\n", + "execution_count\n", "\n", - "\n", + "\n", "\n", - "N9528\n", - "\n", - "Int (9528)\n", - "value: 2\n", - "\n", - "\n", - "\n", - "N9527->N9528\n", - "\n", - "\n", - "CREATE\n", - "result\n", + "N97706\n", + "\n", + "Bool (97706)\n", + "True\n", "\n", - "\n", - "\n", - "N9528->N9530\n", - "\n", - "\n", - "INPUT_CALC\n", - "x\n", + "\n", + "\n", + "N97705->N97706\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", + "\n", "\n", - "N9531\n", - "\n", - "Int (9531)\n", - "value: 5\n", + "N97710\n", + "\n", + "Int (97710)\n", + "value: 3\n", "\n", - "\n", - "\n", - "N9530->N9531\n", - "\n", - "\n", - "CREATE\n", - "result\n", + "\n", + "\n", + "N97709->N97710\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", - "\n", - "N9531->N9533\n", - "\n", - "\n", - "INPUT_CALC\n", - "x\n", + "\n", + "\n", + "N97710->N97712\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", "\n", - "\n", + "\n", "\n", - "N9534\n", - "\n", - "Int (9534)\n", - "value: 10\n", + "N97713\n", + "\n", + "Int (97713)\n", + "value: 6\n", "\n", - "\n", - "\n", - "N9533->N9534\n", - "\n", - "\n", - "CREATE\n", - "result\n", - "\n", - "\n", + "\n", "\n", - "N9534->N9536\n", - "\n", - "\n", - "INPUT_CALC\n", - "x\n", + "N97712->N97713\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", - "\n", - "N9537\n", - "\n", - "Int (9537)\n", - "value: 13\n", + "\n", + "\n", + "N97713->N97715\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", "\n", - "\n", - "\n", - "N9536->N9537\n", - "\n", - "\n", - "CREATE\n", - "result\n", + "\n", + "\n", + "N97713->N97718\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97716\n", + "\n", + "Bool (97716)\n", + "True\n", "\n", - "\n", + "\n", "\n", - "N9537->N9539\n", - "\n", - "\n", - "INPUT_CALC\n", - "x\n", + "N97715->N97716\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", + "\n", "\n", - "N9540\n", - "\n", - "Int (9540)\n", - "value: 26\n", + "N97719\n", + "\n", + "Int (97719)\n", + "value: 7\n", "\n", - "\n", - "\n", - "N9539->N9540\n", - "\n", - "\n", - "CREATE\n", - "result\n", + "\n", + "\n", + "N97718->N97719\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", - "\n", - "N9540->N9542\n", - "\n", - "\n", - "INPUT_CALC\n", - "x\n", + "\n", + "\n", + "N97719->N97721\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", "\n", - "\n", + "\n", "\n", - "N9543\n", - "\n", - "Int (9543)\n", - "value: 29\n", + "N97722\n", + "\n", + "Int (97722)\n", + "value: 14\n", "\n", - "\n", - "\n", - "N9542->N9543\n", - "\n", - "\n", - "CREATE\n", - "result\n", + "\n", + "\n", + "N97721->N97722\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", - "\n", - "N9543->N9545\n", - "\n", - "\n", - "INPUT_CALC\n", - "x\n", + "\n", + "\n", + "N97722->N97724\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", "\n", - "\n", + "\n", + "\n", + "N97722->N97727\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", "\n", - "N9546\n", - "\n", - "Int (9546)\n", - "value: 58\n", + "N97725\n", + "\n", + "Bool (97725)\n", + "True\n", "\n", - "\n", - "\n", - "N9545->N9546\n", - "\n", - "\n", - "CREATE\n", - "result\n", + "\n", + "\n", + "N97724->N97725\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", - "\n", - "N9546->N9548\n", - "\n", - "\n", - "INPUT_CALC\n", - "x\n", + "\n", + "\n", + "N97728\n", + "\n", + "Int (97728)\n", + "value: 15\n", + "\n", + "\n", + "\n", + "N97727->N97728\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97728->N97730\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97731\n", + "\n", + "Int (97731)\n", + "value: 30\n", + "\n", + "\n", + "\n", + "N97730->N97731\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97731->N97733\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97731->N97736\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97734\n", + "\n", + "Bool (97734)\n", + "True\n", + "\n", + "\n", + "\n", + "N97733->N97734\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", + "\n", + "\n", + "N97737\n", + "\n", + "Int (97737)\n", + "value: 31\n", + "\n", + "\n", "\n", - "N9548->N9549\n", - "\n", - "\n", - "CREATE\n", - "result\n", + "N97736->N97737\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", - "\n", - "N9549->N9552\n", - "\n", - "\n", - "INPUT_CALC\n", - "x\n", + "\n", + "\n", + "N97737->N97739\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", "\n", - "\n", - "\n", - "N9553\n", - "\n", - "Int (9553)\n", - "value: 62\n", + "\n", + "\n", + "N97739->N97740\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", - "\n", - "\n", - "N9552->N9553\n", - "\n", - "\n", - "CREATE\n", - "result\n", + "\n", + "\n", + "N97740->N97742\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97740->N97746\n", + "\n", + "\n", + "INPUT_CALC\n", + "x\n", + "\n", + "\n", + "\n", + "N97743\n", + "\n", + "Bool (97743)\n", + "False\n", + "\n", + "\n", + "\n", + "N97742->N97743\n", + "\n", + "\n", + "CREATE\n", + "result\n", + "\n", + "\n", + "\n", + "N97747\n", + "\n", + "Int (97747)\n", + "value: 63\n", + "\n", + "\n", + "\n", + "N97746->N97747\n", + "\n", + "\n", + "CREATE\n", + "result\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -699,7 +1631,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.0" + "version": "3.11.0" }, "vscode": { "interpreter": { diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 664a73d8..4908b66e 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -56,10 +56,11 @@ def test_decorator_workfunction(decorated_add_multiply: Callable) -> None: assert wg.tasks["add_multiply1"].outputs["result"].value == 20 +@pytest.mark.usefixtures("started_daemon_client") def test_decorator_graph_builder(decorated_add_multiply_group: Callable) -> None: """Test graph build.""" wg = WorkGraph("test_graph_builder") - add1 = wg.add_task("AiiDAAdd", "add1", x=2, y=3, t=10) + add1 = wg.add_task("AiiDAAdd", "add1", x=2, y=3) add_multiply1 = wg.add_task(decorated_add_multiply_group, "add_multiply1", y=3, z=4) sum_diff1 = wg.add_task("AiiDASumDiff", "sum_diff1") wg.add_link(add1.outputs[0], add_multiply1.inputs["x"]) diff --git a/tests/test_while.py b/tests/test_while.py index e8afc8e2..20655b01 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -3,6 +3,39 @@ from aiida import orm +@pytest.mark.usefixtures("started_daemon_client") +def test_while_task(decorated_add, decorated_multiply, decorated_compare): + wg = WorkGraph("test_while_task") + # set a context variable before running. + wg.context = {"should_run": True} + add1 = wg.add_task(decorated_add, name="add1", x=1, y=1) + add1.set_context({"result": "n"}) + # --------------------------------------------------------------------- + add2 = wg.add_task(decorated_add, name="add2", x="{{n}}", y=1) + add2.wait.append("add1") + multiply1 = wg.add_task( + decorated_multiply, name="multiply1", x=add2.outputs["result"], y=2 + ) + # update the context variable + multiply1.set_context({"result": "n"}) + compare1 = wg.add_task( + decorated_compare, name="compare1", x=multiply1.outputs["result"], y=50 + ) + compare1.set_context({"result": "should_run"}) + wg.add_task( + "While", + max_iterations=100, + conditions=["should_run"], + tasks=["add2", "multiply1", "compare1"], + ) + # the `result` of compare1 taskis used as condition + # --------------------------------------------------------------------- + add3 = wg.add_task(decorated_add, name="add3", x=1, y=1) + wg.add_link(multiply1.outputs["result"], add3.inputs["x"]) + wg.submit(wait=True, timeout=100) + assert wg.tasks["add3"].outputs["result"].value == 63 + + @pytest.mark.usefixtures("started_daemon_client") def test_while(decorated_add, decorated_multiply, decorated_compare): # Create a WorkGraph will repeat itself based on the conditions diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index e5218447..679d3fbf 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -136,13 +136,14 @@ def test_pause_task_after_submit(wg_calcjob): assert wg.tasks["add2"].outputs["sum"].value == 9 +@pytest.mark.usefixtures("started_daemon_client") def test_workgraph_group_outputs(decorated_add): wg = WorkGraph("test_workgraph_group_outputs") wg.add_task(decorated_add, "add1", x=2, y=3) wg.group_outputs = [ {"name": "sum", "from": "add1.result"}, - {"name": "add1", "from": "add1"}, + # {"name": "add1", "from": "add1"}, ] wg.submit(wait=True) assert wg.process.outputs.sum.value == 5 - assert wg.process.outputs.add1.result.value == 5 + # assert wg.process.outputs.add1.result.value == 5 From b4e443f1c95d213b1f19d4ed15ef07f5726278a4 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Thu, 8 Aug 2024 23:41:46 +0200 Subject: [PATCH 04/11] Speed up unit test and skip unstable test (#205) * Allows the WorkGraph to wait for the particular states of a task. * Remove unused tasks and duplicated tests. * Replace the calcjob task with the calcfunction task. * Skip test for palying paused task, because it is unstable. --- .github/workflows/ci.yaml | 2 +- aiida_workgraph/engine/workgraph.py | 2 +- aiida_workgraph/workgraph.py | 26 +++++++----- tests/conftest.py | 36 +++++------------ tests/test_calcjob.py | 3 -- tests/test_engine.py | 23 ++++++----- tests/test_shell.py | 26 +----------- tests/test_tasks.py | 8 ++-- tests/test_while.py | 53 ++++++------------------ tests/test_workgraph.py | 63 ++++++++++++++++------------- 10 files changed, 95 insertions(+), 147 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cedc7ecd..b5deec9a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -99,7 +99,7 @@ jobs: env: AIIDA_WARN_v3: 1 run: | - pytest -v tests --cov + pytest -v tests --cov --durations=0 - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index 081eaac1..063dc5d7 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -838,7 +838,7 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None "PYTHONJOB", "SHELLJOB", ]: - if len(self._awaitables) > self.ctx.max_number_awaitables: + if len(self._awaitables) >= self.ctx.max_number_awaitables: print( MAX_NUMBER_AWAITABLES_MSG.format( self.ctx.max_number_awaitables, name diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index ca3bf8ce..71da6496 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -210,28 +210,36 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]: return wgdata - def wait(self, timeout: int = 50) -> None: + def wait(self, timeout: int = 50, tasks: dict = None) -> None: """ Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout. - Args: timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50. """ - - start = time.time() - self.update() - while self.state not in ( + terminating_states = ( "KILLED", "PAUSED", "FINISHED", "FAILED", "CANCELLED", "EXCEPTED", - ): - time.sleep(0.5) + ) + start = time.time() + self.update() + finished = False + while not finished: self.update() + if tasks is not None: + states = [] + for name, value in tasks.items(): + flag = self.tasks[name].state in value + states.append(flag) + finished = all(states) + else: + finished = self.state in terminating_states + time.sleep(0.5) if time.time() - start > timeout: - return + break def update(self) -> None: """ diff --git a/tests/conftest.py b/tests/conftest.py index 7c5f2927..a985fcbf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest from aiida_workgraph import task, WorkGraph from aiida.engine import calcfunction, workfunction -from aiida.orm import Float, Int, StructureData +from aiida.orm import Int, StructureData from aiida.calculations.arithmetic.add import ArithmeticAddCalculation from typing import Callable, Any, Union import time @@ -61,13 +61,9 @@ def wg_calcfunction() -> WorkGraph: """A workgraph with calcfunction.""" wg = WorkGraph(name="test_debug_math") - float1 = wg.add_task("AiiDANode", "float1", pk=Float(3.0).store().pk) - sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2) + sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2, y=3) sumdiff2 = wg.add_task("AiiDASumDiff", "sumdiff2", x=4) - sumdiff3 = wg.add_task("AiiDASumDiff", "sumdiff3", x=6) - wg.add_link(float1.outputs[0], sumdiff1.inputs[1]) wg.add_link(sumdiff1.outputs[0], sumdiff2.inputs[1]) - wg.add_link(sumdiff2.outputs[0], sumdiff3.inputs[1]) return wg @@ -78,17 +74,9 @@ def wg_calcjob(add_code) -> WorkGraph: print("add_code", add_code) wg = WorkGraph(name="test_debug_math") - int1 = wg.add_task("AiiDANode", "int1", pk=Int(3).store().pk) - code1 = wg.add_task("AiiDACode", "code1", pk=add_code.pk) - add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=Int(2).store()) - add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=Int(4).store()) - add3 = wg.add_task(ArithmeticAddCalculation, "add3", x=Int(4).store()) - wg.add_link(code1.outputs[0], add1.inputs["code"]) - wg.add_link(int1.outputs[0], add1.inputs["y"]) - wg.add_link(code1.outputs[0], add2.inputs["code"]) + add1 = wg.add_task(ArithmeticAddCalculation, "add1", x=2, y=3, code=add_code) + add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=4, code=add_code) wg.add_link(add1.outputs["sum"], add2.inputs["y"]) - wg.add_link(code1.outputs[0], add3.inputs["code"]) - wg.add_link(add2.outputs["sum"], add3.inputs["y"]) return wg @@ -245,17 +233,13 @@ def wg_structure_si() -> WorkGraph: def wg_engine(decorated_add, add_code) -> WorkGraph: """Use to test the engine.""" code = add_code - x = Int(2) wg = WorkGraph(name="test_run_order") - add0 = wg.add_task(ArithmeticAddCalculation, "add0", x=x, y=Int(0), code=code) - add0.set({"metadata.options.sleep": 15}) - add1 = wg.add_task(decorated_add, "add1", x=x, y=Int(1), t=Int(1)) - add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=x, y=Int(2), code=code) - add2.set({"metadata.options.sleep": 1}) - add3 = wg.add_task(decorated_add, "add3", x=x, y=Int(3), t=Int(1)) - add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=x, y=Int(4), code=code) - add4.set({"metadata.options.sleep": 1}) - add5 = wg.add_task(decorated_add, "add5", x=x, y=Int(5), t=Int(1)) + add0 = wg.add_task(ArithmeticAddCalculation, "add0", x=2, y=0, code=code) + add1 = wg.add_task(decorated_add, "add1", x=2, y=1) + add2 = wg.add_task(ArithmeticAddCalculation, "add2", x=2, y=2, code=code) + add3 = wg.add_task(decorated_add, "add3", x=2, y=3) + add4 = wg.add_task(ArithmeticAddCalculation, "add4", x=2, y=4, code=code) + add5 = wg.add_task(decorated_add, "add5", x=2, y=5) wg.add_link(add0.outputs["sum"], add2.inputs["x"]) wg.add_link(add1.outputs[0], add3.inputs["x"]) wg.add_link(add3.outputs[0], add4.inputs["x"]) diff --git a/tests/test_calcjob.py b/tests/test_calcjob.py index 0a4adb95..74768f5f 100644 --- a/tests/test_calcjob.py +++ b/tests/test_calcjob.py @@ -1,6 +1,5 @@ import pytest from aiida_workgraph import WorkGraph -import os @pytest.mark.usefixtures("started_daemon_client") @@ -9,6 +8,4 @@ def test_submit(wg_calcjob: WorkGraph) -> None: wg = wg_calcjob wg.name = "test_submit_calcjob" wg.submit(wait=True) - os.system("verdi process list -a") - os.system(f"verdi process report {wg.pk}") assert wg.tasks["add2"].outputs["sum"].value == 9 diff --git a/tests/test_engine.py b/tests/test_engine.py index 50840421..2b322d7b 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,15 +1,19 @@ import time import pytest from aiida_workgraph import WorkGraph +from aiida.cmdline.utils.common import get_workchain_report @pytest.mark.usefixtures("started_daemon_client") -def test_run_order(wg_engine: WorkGraph) -> None: +def test_run_order(decorated_add) -> None: """Test the order. Tasks should run in parallel and only depend on the input tasks.""" - wg = wg_engine + wg = WorkGraph(name="test_run_order") + wg.add_task(decorated_add, "add0", x=2, y=0) + wg.add_task(decorated_add, "add1", x=2, y=1) wg.submit(wait=True) - wg.tasks["add2"].ctime < wg.tasks["add4"].ctime + report = get_workchain_report(wg.process, "REPORT") + assert "tasks ready to run: add0,add1" in report @pytest.mark.skip(reason="The test is not stable.") @@ -28,23 +32,20 @@ def test_reset_node(wg_engine: WorkGraph) -> None: assert len(wg.process.base.extras.get("_workgraph_queue")) == 1 -@pytest.mark.usefixtures("started_daemon_client") def test_max_number_jobs(add_code) -> None: from aiida_workgraph import WorkGraph from aiida.orm import Int from aiida.calculations.arithmetic.add import ArithmeticAddCalculation wg = WorkGraph("test_max_number_jobs") - N = 9 + N = 3 # Create N nodes for i in range(N): - temp = wg.add_task( + wg.add_task( ArithmeticAddCalculation, name=f"add{i}", x=Int(1), y=Int(1), code=add_code ) - # Set a sleep option for each job (e.g., 2 seconds per job) - temp.set({"metadata.options.sleep": 1}) - # Set the maximum number of running jobs inside the WorkGraph - wg.max_number_jobs = 3 + wg.max_number_jobs = 2 wg.submit(wait=True, timeout=100) - wg.tasks["add1"].ctime < wg.tasks["add8"].ctime + report = get_workchain_report(wg.process, "REPORT") + assert "tasks ready to run: add2" in report diff --git a/tests/test_shell.py b/tests/test_shell.py index 1f3cd431..7cfb758a 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -98,28 +98,6 @@ def parser(self, dirpath): {"identifier": "Any", "name": "result"} ], # add a "result" output socket from the parser ) - # echo result + y expression - job3 = wg.add_task( - "ShellJob", - name="job3", - command="echo", - arguments=["{result}", "*", "{z}"], - nodes={"result": job2.outputs["result"], "z": Int(4)}, - ) - # bc command to calculate the expression - job4 = wg.add_task( - "ShellJob", - name="job4", - command="bc", - arguments=["{expression}"], - nodes={"expression": job3.outputs["stdout"]}, - parser=PickledData(parser), - parser_outputs=[ - {"identifier": "Any", "name": "result"} - ], # add a "result" output socket from the parser - ) - # there is a bug in aiida-shell, the following line will raise an error - # https://github.com/sphuber/aiida-shell/issues/91 - # wg.submit(wait=True, timeout=200) + wg.run() - assert job4.outputs["result"].value.value == 20 + assert job2.outputs["result"].value.value == 5 diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 43e4be42..7402f3d6 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -10,8 +10,8 @@ def test_build_task_from_workgraph(wg_calcfunction, decorated_add): wg_task = wg.add_task(wg_calcfunction, name="wg_calcfunction") wg.add_task(decorated_add, name="add2", y=3) wg.add_link(add1_task.outputs["result"], wg_task.inputs["sumdiff1.x"]) - wg.add_link(wg_task.outputs["sumdiff3.sum"], wg.tasks["add2"].inputs["x"]) - assert len(wg_task.inputs) == 15 - assert len(wg_task.outputs) == 13 + wg.add_link(wg_task.outputs["sumdiff2.sum"], wg.tasks["add2"].inputs["x"]) + assert len(wg_task.inputs) == 7 + assert len(wg_task.outputs) == 8 wg.submit(wait=True) - assert wg.tasks["add2"].outputs["result"].value.value == 20 + assert wg.tasks["add2"].outputs["result"].value.value == 14 diff --git a/tests/test_while.py b/tests/test_while.py index 20655b01..66c776d6 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -19,7 +19,7 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare): # update the context variable multiply1.set_context({"result": "n"}) compare1 = wg.add_task( - decorated_compare, name="compare1", x=multiply1.outputs["result"], y=50 + decorated_compare, name="compare1", x=multiply1.outputs["result"], y=30 ) compare1.set_context({"result": "should_run"}) wg.add_task( @@ -33,18 +33,17 @@ def test_while_task(decorated_add, decorated_multiply, decorated_compare): add3 = wg.add_task(decorated_add, name="add3", x=1, y=1) wg.add_link(multiply1.outputs["result"], add3.inputs["x"]) wg.submit(wait=True, timeout=100) - assert wg.tasks["add3"].outputs["result"].value == 63 + assert wg.tasks["add3"].outputs["result"].value == 31 -@pytest.mark.usefixtures("started_daemon_client") -def test_while(decorated_add, decorated_multiply, decorated_compare): +def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare): # Create a WorkGraph will repeat itself based on the conditions wg = WorkGraph("while_workgraph") wg.workgraph_type = "WHILE" wg.conditions = ["compare1.result"] wg.context = {"n": 1} wg.max_iteration = 10 - wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=50) + wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=20) multiply1 = wg.add_task( decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2) ) @@ -52,48 +51,20 @@ def test_while(decorated_add, decorated_multiply, decorated_compare): add1.set_context({"result": "n"}) wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) wg.submit(wait=True, timeout=100) - assert wg.execution_count == 4 - assert wg.tasks["add1"].outputs["result"].value == 61 + assert wg.execution_count == 3 + assert wg.tasks["add1"].outputs["result"].value == 29 +@pytest.mark.usefixtures("started_daemon_client") def test_while_graph_builder(decorated_add, decorated_multiply, decorated_compare): - # Create a WorkGraph will repeat itself based on the conditions - @task.graph_builder(outputs=[{"name": "result", "from": "context.n"}]) - def my_while(n=0, limit=100): - wg = WorkGraph("while_workgraph") - wg.workgraph_type = "WHILE" - wg.conditions = ["compare1.result"] - wg.context = {"n": n} - wg.max_iteration = 10 - wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit)) - multiply1 = wg.add_task( - decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2) - ) - add1 = wg.add_task(decorated_add, name="add1", y=3) - add1.set_context({"result": "n"}) - wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) - return wg - - # ----------------------------------------- - wg = WorkGraph("while") - add1 = wg.add_task(decorated_add, name="add1", x=orm.Int(25), y=orm.Int(25)) - my_while1 = wg.add_task(my_while, n=orm.Int(1)) - add2 = wg.add_task(decorated_add, name="add2", y=orm.Int(2)) - wg.add_link(add1.outputs["result"], my_while1.inputs["limit"]) - wg.add_link(my_while1.outputs["result"], add2.inputs["x"]) - wg.submit(wait=True, timeout=100) - assert add2.outputs["result"].value == 63 - assert my_while1.node.outputs.execution_count == 4 - assert my_while1.outputs["result"].value == 61 - + """Test the while WorkGraph in graph builder. + Also test the max_iteration parameter.""" -def test_while_max_iteration(decorated_add, decorated_multiply, decorated_compare): - # Create a WorkGraph will repeat itself based on the conditions @task.graph_builder(outputs=[{"name": "result", "from": "context.n"}]) def my_while(n=0, limit=100): wg = WorkGraph("while_workgraph") wg.workgraph_type = "WHILE" - wg.max_iteration = 3 + wg.max_iteration = 2 wg.conditions = ["compare1.result"] wg.context = {"n": n} wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=orm.Int(limit)) @@ -113,5 +84,5 @@ def my_while(n=0, limit=100): wg.add_link(add1.outputs["result"], my_while1.inputs["limit"]) wg.add_link(my_while1.outputs["result"], add2.inputs["x"]) wg.submit(wait=True, timeout=100) - assert add2.outputs["result"].value < 63 - assert my_while1.node.outputs.execution_count == 3 + assert add2.outputs["result"].value < 31 + assert my_while1.node.outputs.execution_count == 2 diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index 679d3fbf..2d262810 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -5,17 +5,17 @@ from aiida.calculations.arithmetic.add import ArithmeticAddCalculation -def test_to_dict(wg_calcjob): +def test_to_dict(wg_calcfunction): """Export NodeGraph to dict.""" - wg = wg_calcjob + wg = wg_calcfunction wgdata = wg.to_dict() assert len(wgdata["tasks"]) == len(wg.tasks) assert len(wgdata["links"]) == len(wg.links) -def test_from_dict(wg_calcjob): +def test_from_dict(wg_calcfunction): """Export NodeGraph to dict.""" - wg = wg_calcjob + wg = wg_calcfunction wgdata = wg.to_dict() wg1 = WorkGraph.from_dict(wgdata) assert len(wg.tasks) == len(wg1.tasks) @@ -32,9 +32,9 @@ def test_add_task(): assert len(wg.links) == 1 -def test_save_load(wg_calcjob): +def test_save_load(wg_calcfunction): """Save the workgraph""" - wg = wg_calcjob + wg = wg_calcfunction wg.name = "test_save_load" wg.save() assert wg.process.process_state.value.upper() == "CREATED" @@ -73,23 +73,24 @@ def test_reset_message(wg_calcjob): assert "Task add2 action: RESET." in report -def test_restart(wg_calcjob): +def test_restart(wg_calcfunction): """Restart from a finished workgraph. Load the workgraph, modify the task, and restart the workgraph. Only the modified node and its child tasks will be rerun.""" - wg = wg_calcjob + wg = wg_calcfunction + wg.add_task("AiiDASumDiff", "sumdiff3", x=4, y=wg.tasks["sumdiff2"].outputs["sum"]) wg.name = "test_restart_0" wg.submit(wait=True) wg1 = WorkGraph.load(wg.process.pk) wg1.restart() wg1.name = "test_restart_1" - wg1.tasks["add2"].set({"x": orm.Int(10).store()}) + wg1.tasks["sumdiff2"].set({"x": orm.Int(10).store()}) # wg1.save() wg1.submit(wait=True) - assert wg1.tasks["add1"].node.pk == wg.tasks["add1"].pk - assert wg1.tasks["add2"].node.pk != wg.tasks["add2"].pk - assert wg1.tasks["add3"].node.pk != wg.tasks["add3"].pk - assert wg1.tasks["add3"].node.outputs.sum == 19 + assert wg1.tasks["sumdiff1"].node.pk == wg.tasks["sumdiff1"].pk + assert wg1.tasks["sumdiff2"].node.pk != wg.tasks["sumdiff2"].pk + assert wg1.tasks["sumdiff3"].node.pk != wg.tasks["sumdiff3"].pk + assert wg1.tasks["sumdiff3"].node.outputs.sum == 19 def test_extend_workgraph(decorated_add_multiply_group): @@ -105,38 +106,46 @@ def test_extend_workgraph(decorated_add_multiply_group): assert wg.tasks["group_multiply1"].node.outputs.result == 45 +@pytest.mark.usefixtures("started_daemon_client") def test_pause_task_before_submit(wg_calcjob): wg = wg_calcjob wg.name = "test_pause_task" wg.pause_tasks(["add2"]) wg.submit() - time.sleep(20) - wg.update() + wg.wait(tasks={"add1": ["FINISHED"]}, timeout=20) + assert wg.tasks["add1"].node.process_state.value.upper() == "FINISHED" + # wait for the workgraph to launch add2 + wg.wait(tasks={"add2": ["CREATED"]}, timeout=20) assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED" assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph" - wg.play_tasks(["add2"]) - wg.wait() - assert wg.tasks["add2"].outputs["sum"].value == 9 + # I disabled the following lines because the test is not stable + # Seems the daemon is not responding to the play signal + # This should be a problem of AiiDA test fixtures + # wg.play_tasks(["add2"]) + # wg.wait(tasks={"add2": ["FINISHED"]}) + # assert wg.tasks["add2"].outputs["sum"].value == 9 def test_pause_task_after_submit(wg_calcjob): wg = wg_calcjob + wg.tasks["add1"].set({"metadata.options.sleep": 3}) wg.name = "test_pause_task" wg.submit() - # wait for the daemon to start the workgraph - time.sleep(3) - # wg.run() + # wait for the workgraph to launch add1 + wg.wait(tasks={"add1": ["CREATED", "WAITING", "RUNNING", "FINISHED"]}, timeout=20) wg.pause_tasks(["add2"]) - time.sleep(20) - wg.update() + wg.wait(tasks={"add1": ["FINISHED"]}, timeout=20) + # wait for the workgraph to launch add2 + wg.wait(tasks={"add2": ["CREATED"]}, timeout=20) assert wg.tasks["add2"].node.process_state.value.upper() == "CREATED" assert wg.tasks["add2"].node.process_status == "Paused through WorkGraph" - wg.play_tasks(["add2"]) - wg.wait() - assert wg.tasks["add2"].outputs["sum"].value == 9 + # I disabled the following lines because the test is not stable + # Seems the daemon is not responding to the play signal + # wg.play_tasks(["add2"]) + # wg.wait(tasks={"add2": ["FINISHED"]}) + # assert wg.tasks["add2"].outputs["sum"].value == 9 -@pytest.mark.usefixtures("started_daemon_client") def test_workgraph_group_outputs(decorated_add): wg = WorkGraph("test_workgraph_group_outputs") wg.add_task(decorated_add, "add1", x=2, y=3) From 76047ee4c77b594193c0891dc4fd68c3012bf1f8 Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Fri, 9 Aug 2024 10:27:48 +0200 Subject: [PATCH 05/11] Allow usage of decorators without parentheses (#199) The decorators that could be used as `@task()` can now be used as a nonfunctional `@task`. Same for the `task.calcfunction` and others. --- aiida_workgraph/decorator.py | 42 +++++++++- docs/source/concept/task.ipynb | 4 +- tests/test_decorator.py | 140 +++++++++++++++++++++++++++++++-- 3 files changed, 175 insertions(+), 11 deletions(-) diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index dab8ebd8..2d2b56c7 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -416,6 +416,38 @@ def build_task_from_workgraph(wg: any) -> Task: return task +def nonfunctional_usage(callable: Callable): + """ + This is a decorator for a decorator factory (a function that returns a decorator). + It allows the usage of the decorator factory in a nonfunctional way. So a decorator + factory that has been decorated by this decorator that could only be used befor like + this + + .. code-block:: python + + @decorator_factory() + def foo(): + pass + + can now be also used like this + + .. code-block:: python + + @decorator_factory + def foo(): + pass + + """ + + def decorator_task_wrapper(*args, **kwargs): + if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0: + return callable()(args[0]) + else: + return callable(*args, **kwargs) + + return decorator_task_wrapper + + def generate_tdata( func: Callable, identifier: str, @@ -462,6 +494,7 @@ class TaskDecoratorCollection: # decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor @staticmethod + @nonfunctional_usage def decorator_task( identifier: Optional[str] = None, task_type: str = "Normal", @@ -511,6 +544,7 @@ def decorator(func): # decorator with arguments indentifier, args, kwargs, properties, inputs, outputs, executor @staticmethod + @nonfunctional_usage def decorator_graph_builder( identifier: Optional[str] = None, properties: Optional[List[Tuple[str, str]]] = None, @@ -561,6 +595,7 @@ def decorator(func): return decorator @staticmethod + @nonfunctional_usage def calcfunction(**kwargs: Any) -> Callable: def decorator(func): # First, apply the calcfunction decorator @@ -579,6 +614,7 @@ def decorator(func): return decorator @staticmethod + @nonfunctional_usage def workfunction(**kwargs: Any) -> Callable: def decorator(func): # First, apply the workfunction decorator @@ -597,6 +633,7 @@ def decorator(func): return decorator @staticmethod + @nonfunctional_usage def pythonjob(**kwargs: Any) -> Callable: def decorator(func): # first create a task from the function @@ -622,7 +659,10 @@ def decorator(func): def __call__(self, *args, **kwargs): # This allows using '@task' to directly apply the decorator_task functionality - return self.decorator_task(*args, **kwargs) + if len(args) == 1 and isinstance(args[0], Callable) and len(kwargs) == 0: + return self.decorator_task()(args[0]) + else: + return self.decorator_task(*args, **kwargs) task = TaskDecoratorCollection() diff --git a/docs/source/concept/task.ipynb b/docs/source/concept/task.ipynb index 3ea57cbc..9c3e40c3 100644 --- a/docs/source/concept/task.ipynb +++ b/docs/source/concept/task.ipynb @@ -45,12 +45,12 @@ "from aiida import orm\n", "\n", "# define add task\n", - "@task()\n", + "@task # this is equivalent to passing no arguments @task()\n", "def add(x, y):\n", " return x + y\n", "\n", "# define multiply calcfunction task\n", - "@task.calcfunction()\n", + "@task.calcfunction # this is equivalent to passing no arguments @task.calculation()\n", "def multiply(x, y):\n", " return orm.Float(x + y)\n", "\n", diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 4908b66e..83dba3af 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,25 +1,115 @@ import pytest from aiida_workgraph import WorkGraph from typing import Callable +from aiida_workgraph import task -def test_args() -> None: - from aiida_workgraph import task +@pytest.fixture(params=["decorator_factory", "decorator"]) +def task_calcfunction(request): + if request.param == "decorator_factory": - @task.calcfunction() - def test(a, b=1, **c): - print(a, b, c) + @task.calcfunction() + def test(a, b=1, **c): + print(a, b, c) + elif request.param == "decorator": + + @task.calcfunction + def test(a, b=1, **c): + print(a, b, c) + + else: + raise ValueError(f"{request.param} not supported.") + return test + + +def test_decorators_calcfunction_args(task_calcfunction) -> None: + metadata_kwargs = set( + [ + f"metadata.{key}" + for key in task_calcfunction.process_class.spec() + .inputs.ports["metadata"] + .ports.keys() + ] + ) + kwargs = set(task_calcfunction.process_class.spec().inputs.ports.keys()).union( + metadata_kwargs + ) + kwargs.remove("a") + # + n = task_calcfunction.task() + assert n.args == ["a"] + assert set(n.kwargs) == set(kwargs) + assert n.var_args is None + assert n.var_kwargs == "c" + assert n.outputs.keys() == ["result", "_outputs", "_wait"] + + +@pytest.fixture(params=["decorator_factory", "decorator"]) +def task_function(request): + if request.param == "decorator_factory": + + @task() + def test(a, b=1, **c): + print(a, b, c) + + elif request.param == "decorator": + + @task + def test(a, b=1, **c): + print(a, b, c) + + else: + raise ValueError(f"{request.param} not supported.") + return test + + +def test_decorators_task_args(task_function): + + tdata = task_function.tdata + assert tdata["args"] == ["a"] + assert tdata["kwargs"] == ["b"] + assert tdata["var_args"] is None + assert tdata["var_kwargs"] == "c" + assert set([output["name"] for output in tdata["outputs"]]) == set( + ["result", "_outputs", "_wait"] + ) + + +@pytest.fixture(params=["decorator_factory", "decorator"]) +def task_workfunction(request): + if request.param == "decorator_factory": + + @task.workfunction() + def test(a, b=1, **c): + print(a, b, c) + + elif request.param == "decorator": + + @task.workfunction + def test(a, b=1, **c): + print(a, b, c) + + else: + raise ValueError(f"{request.param} not supported.") + return test + + +def test_decorators_workfunction_args(task_workfunction) -> None: metadata_kwargs = set( [ f"metadata.{key}" - for key in test.process_class.spec().inputs.ports["metadata"].ports.keys() + for key in task_workfunction.process_class.spec() + .inputs.ports["metadata"] + .ports.keys() ] ) - kwargs = set(test.process_class.spec().inputs.ports.keys()).union(metadata_kwargs) + kwargs = set(task_workfunction.process_class.spec().inputs.ports.keys()).union( + metadata_kwargs + ) kwargs.remove("a") # - n = test.task() + n = task_workfunction.task() assert n.args == ["a"] assert set(n.kwargs) == set(kwargs) assert n.var_args is None @@ -27,6 +117,40 @@ def test(a, b=1, **c): assert n.outputs.keys() == ["result", "_outputs", "_wait"] +@pytest.fixture(params=["decorator_factory", "decorator"]) +def task_graph_builder(request): + if request.param == "decorator_factory": + + @task.graph_builder() + def add_multiply_group(a, b=1, **c): + wg = WorkGraph("add_multiply_group") + print(a, b, c) + return wg + + elif request.param == "decorator": + + @task.graph_builder + def add_multiply_group(a, b=1, **c): + wg = WorkGraph("add_multiply_group") + print(a, b, c) + return wg + + else: + raise ValueError(f"{request.param} not supported.") + + return add_multiply_group + + +def test_decorators_graph_builder_args(task_graph_builder) -> None: + assert task_graph_builder.identifier == "add_multiply_group" + n = task_graph_builder.task() + assert n.args == ["a"] + assert n.kwargs == ["b"] + assert n.var_args is None + assert n.var_kwargs == "c" + assert set(n.outputs.keys()) == set(["_outputs", "_wait"]) + + def test_inputs_outputs_workchain() -> None: from aiida.workflows.arithmetic.multiply_add import MultiplyAddWorkChain From b531057a18e393b10899f4a2988e0c871d376a5e Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Fri, 9 Aug 2024 22:22:54 +0200 Subject: [PATCH 06/11] Register an entry point for each class (Task, Socket, Property) (#209) Register an entry point for each class, including Task, Socket and Property instead of a list of classes. Suggested naming rule for entry points: - Format: {package_name}.{short_name_for_class}. If the plugin name starts with 'aiida', such as 'aiida-xxx', you may omit 'aiida' for brevity, e.g., 'xxx.add'. Aim to maintain consistency with the name used when registering with the AiiDA registry. - The name is case-insensitive, thus `xxx.add` is the same as `xxx.Add`. - Naming Style: Use snake case. For instance, use 'xxx.test_add' for the class 'TestAddTask'. Typically, the suffix 'task' is unnecessary, thus avoid using 'xxx.test_add_task'. --- aiida_workgraph/calculations/python_parser.py | 7 +- aiida_workgraph/collection.py | 3 +- aiida_workgraph/decorator.py | 91 ++++--- aiida_workgraph/executors/qe.py | 2 +- aiida_workgraph/properties/__init__.py | 5 +- .../properties/{built_in.py => builtins.py} | 85 ++---- aiida_workgraph/sockets/__init__.py | 5 +- .../sockets/{built_in.py => builtins.py} | 60 +---- aiida_workgraph/tasks/__init__.py | 46 +--- aiida_workgraph/tasks/builtin.py | 90 ------- aiida_workgraph/tasks/builtins.py | 243 ++++++++++++++++++ aiida_workgraph/tasks/qe.py | 75 ------ aiida_workgraph/tasks/test.py | 235 +++-------------- docs/source/built-in/pythonjob.ipynb | 14 +- docs/source/development/custom_task.rst | 12 + docs/source/development/index.rst | 20 +- docs/source/development/test.rst | 18 ++ pyproject.toml | 41 ++- tests/conftest.py | 33 +-- tests/datas/test_calcfunction.yaml | 10 +- tests/datas/test_calcjob.yaml | 12 +- tests/test_build_task.py | 4 +- tests/test_decorator.py | 4 +- tests/test_link.py | 8 +- tests/test_python.py | 15 +- tests/test_shell.py | 2 +- tests/test_socket.py | 3 +- tests/test_workchain.py | 2 +- tests/test_workgraph.py | 9 +- 29 files changed, 511 insertions(+), 643 deletions(-) rename aiida_workgraph/properties/{built_in.py => builtins.py} (84%) rename aiida_workgraph/sockets/{built_in.py => builtins.py} (72%) delete mode 100644 aiida_workgraph/tasks/builtin.py create mode 100644 aiida_workgraph/tasks/builtins.py delete mode 100644 aiida_workgraph/tasks/qe.py create mode 100644 docs/source/development/custom_task.rst create mode 100644 docs/source/development/test.rst diff --git a/aiida_workgraph/calculations/python_parser.py b/aiida_workgraph/calculations/python_parser.py index 86b60792..6c02c4c3 100644 --- a/aiida_workgraph/calculations/python_parser.py +++ b/aiida_workgraph/calculations/python_parser.py @@ -11,7 +11,7 @@ def parse(self, **kwargs): The outputs could be a namespce, e.g., outputs=[ - {"identifier": "Namespace", "name": "add_multiply"}, + {"identifier": "workgraph.namespace", "name": "add_multiply"}, {"name": "add_multiply.add"}, {"name": "add_multiply.multiply"}, {"name": "minus"}, @@ -100,7 +100,7 @@ def serialize_output(self, result, output): """Serialize outputs.""" name = output["name"] - if output["identifier"].upper() == "NAMESPACE": + if output["identifier"].upper() == "WORKGRAPH.NAMESPACE": if isinstance(result, dict): serialized_result = {} for key, value in result.items(): @@ -108,7 +108,8 @@ def serialize_output(self, result, output): full_name_output = self.find_output(full_name) if ( full_name_output - and full_name_output["identifier"].upper() == "NAMESPACE" + and full_name_output["identifier"].upper() + == "WORKGRAPH.NAMESPACE" ): serialized_result[key] = self.serialize_output( value, full_name_output diff --git a/aiida_workgraph/collection.py b/aiida_workgraph/collection.py index 4104617f..6e826bf0 100644 --- a/aiida_workgraph/collection.py +++ b/aiida_workgraph/collection.py @@ -39,7 +39,7 @@ def new( task.set(links) return task if isinstance(identifier, str) and identifier.upper() == "WHILE": - task = super().new(identifier, name, uuid, **kwargs) + task = super().new("workgraph.while", name, uuid, **kwargs) return task if isinstance(identifier, WorkGraph): identifier = build_task_from_workgraph(identifier) @@ -89,6 +89,7 @@ def new( # build the socket on the fly if the identifier is a callable if callable(identifier): + print("identifier is callable", identifier) identifier = build_socket_from_AiiDA(identifier) # Call the original new method return super().new(identifier, name, **kwargs) diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index 2d2b56c7..95a21655 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -15,11 +15,11 @@ WorkChain: "WORKCHAIN", } -aiida_socket_maping = { - orm.Int: "AiiDAInt", - orm.Float: "AiiDAFloat", - orm.Str: "AiiDAString", - orm.Bool: "AiiDABool", +aiida_socket_mapping = { + orm.Int: "workgraph.aiida_int", + orm.Float: "workgraph.aiida_float", + orm.Str: "workgraph.aiida_str", + orm.Bool: "workgraph.aiida_bool", } @@ -53,9 +53,9 @@ def add_input_recursive( if port_name not in input_names: inputs.append( { - "identifier": "Namespace", + "identifier": "workgraph.namespace", "name": port_name, - "property": {"identifier": "Any", "default": {}}, + "property": {"identifier": "workgraph.any", "default": {}}, } ) if required: @@ -71,11 +71,13 @@ def add_input_recursive( # port.valid_type can be a single type or a tuple of types, # we only support single type for now if isinstance(port.valid_type, tuple) and len(port.valid_type) > 1: - socket_type = "Any" + socket_type = "workgraph.any" if isinstance(port.valid_type, tuple) and len(port.valid_type) == 1: - socket_type = aiida_socket_maping.get(port.valid_type[0], "Any") + socket_type = aiida_socket_mapping.get( + port.valid_type[0], "workgraph.any" + ) else: - socket_type = aiida_socket_maping.get(port.valid_type, "Any") + socket_type = aiida_socket_mapping.get(port.valid_type, "workgraph.any") inputs.append({"identifier": socket_type, "name": port_name}) if required: args.append(port_name) @@ -102,12 +104,12 @@ def add_output_recursive( # so if you change the value of one port, the value of all the ports of other tasks will be changed # consider to use None as default value if port_name not in output_names: - outputs.append({"identifier": "Namespace", "name": port_name}) + outputs.append({"identifier": "workgraph.namespace", "name": port_name}) for value in port.values(): add_output_recursive(outputs, value, prefix=port_name, required=required) else: if port_name not in output_names: - outputs.append({"identifier": "Any", "name": port_name}) + outputs.append({"identifier": "workgraph.any", "name": port_name}) return outputs @@ -219,9 +221,9 @@ def build_task_from_AiiDA( tdata["var_kwargs"] = name inputs.append( { - "identifier": "Any", + "identifier": "workgraph.any", "name": name, - "property": {"identifier": "Any", "default": {}}, + "property": {"identifier": "workgraph.any", "default": {}}, } ) # TODO In order to reload the WorkGraph from process, "is_pickle" should be True @@ -234,15 +236,19 @@ def build_task_from_AiiDA( "is_pickle": True, } if tdata["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]: - outputs = [{"identifier": "Any", "name": "result"}] if not outputs else outputs + outputs = ( + [{"identifier": "workgraph.any", "name": "result"}] + if not outputs + else outputs + ) # get the source code of the function tdata["executor"] = serialize_function(executor) # tdata["executor"]["type"] = tdata["task_type"] # print("kwargs: ", kwargs) # add built-in sockets - outputs.append({"identifier": "Any", "name": "_outputs"}) - outputs.append({"identifier": "Any", "name": "_wait"}) - inputs.append({"identifier": "Any", "name": "_wait", "link_limit": 1e6}) + outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) + outputs.append({"identifier": "workgraph.any", "name": "_wait"}) + inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6}) tdata["node_class"] = Task tdata["args"] = args tdata["kwargs"] = kwargs @@ -273,10 +279,10 @@ def build_pythonjob_task(func: Callable) -> Task: inputs = tdata["inputs"] inputs.extend( [ - {"identifier": "String", "name": "computer"}, - {"identifier": "String", "name": "code_label"}, - {"identifier": "String", "name": "code_path"}, - {"identifier": "String", "name": "prepend_text"}, + {"identifier": "workgraph.string", "name": "computer"}, + {"identifier": "workgraph.string", "name": "code_label"}, + {"identifier": "workgraph.string", "name": "code_path"}, + {"identifier": "workgraph.string", "name": "prepend_text"}, ] ) outputs = tdata["outputs"] @@ -319,7 +325,7 @@ def build_shelljob_task( nodes = {} if nodes is None else nodes keys = list(nodes.keys()) for key in keys: - inputs.append({"identifier": "Any", "name": f"nodes.{key}"}) + inputs.append({"identifier": "workgraph.any", "name": f"nodes.{key}"}) # input is a output of another task, we make a link if isinstance(nodes[key], NodeSocket): links[f"nodes.{key}"] = nodes[key] @@ -332,14 +338,14 @@ def build_shelljob_task( # Extend the outputs tdata["outputs"].extend( [ - {"identifier": "Any", "name": "stdout"}, - {"identifier": "Any", "name": "stderr"}, + {"identifier": "workgraph.any", "name": "stdout"}, + {"identifier": "workgraph.any", "name": "stderr"}, ] ) outputs = [] if outputs is None else outputs parser_outputs = [] if parser_outputs is None else parser_outputs outputs = [ - {"identifier": "Any", "name": ShellParser.format_link_label(output)} + {"identifier": "workgraph.any", "name": ShellParser.format_link_label(output)} for output in outputs ] outputs.extend(parser_outputs) @@ -351,8 +357,8 @@ def build_shelljob_task( tdata["identifier"] = "ShellJob" tdata["inputs"].extend( [ - {"identifier": "Any", "name": "command"}, - {"identifier": "Any", "name": "resolve_command"}, + {"identifier": "workgraph.any", "name": "command"}, + {"identifier": "workgraph.any", "name": "resolve_command"}, ] ) tdata["kwargs"].extend(["command", "resolve_command"]) @@ -374,17 +380,21 @@ def build_task_from_workgraph(wg: any) -> Task: # add all the inputs/outputs from the tasks in the workgraph for task in wg.tasks: # inputs - inputs.append({"identifier": "Any", "name": f"{task.name}"}) + inputs.append({"identifier": "workgraph.any", "name": f"{task.name}"}) for socket in task.inputs: if socket.name == "_wait": continue - inputs.append({"identifier": "Any", "name": f"{task.name}.{socket.name}"}) + inputs.append( + {"identifier": "workgraph.any", "name": f"{task.name}.{socket.name}"} + ) # outputs - outputs.append({"identifier": "Any", "name": f"{task.name}"}) + outputs.append({"identifier": "workgraph.any", "name": f"{task.name}"}) for socket in task.outputs: if socket.name in ["_wait", "_outputs"]: continue - outputs.append({"identifier": "Any", "name": f"{task.name}.{socket.name}"}) + outputs.append( + {"identifier": "workgraph.any", "name": f"{task.name}.{socket.name}"} + ) group_outputs.append( { "name": f"{task.name}.{socket.name}", @@ -393,9 +403,9 @@ def build_task_from_workgraph(wg: any) -> Task: ) kwargs = [input["name"] for input in inputs] # add built-in sockets - outputs.append({"identifier": "Any", "name": "_outputs"}) - outputs.append({"identifier": "Any", "name": "_wait"}) - inputs.append({"identifier": "Any", "name": "_wait", "link_limit": 1e6}) + outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) + outputs.append({"identifier": "workgraph.any", "name": "_wait"}) + inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6}) tdata["node_class"] = Task tdata["kwargs"] = kwargs tdata["inputs"] = inputs @@ -467,9 +477,9 @@ def generate_tdata( ) task_outputs = outputs # add built-in sockets - _inputs.append({"identifier": "Any", "name": "_wait", "link_limit": 1e6}) - task_outputs.append({"identifier": "Any", "name": "_wait"}) - task_outputs.append({"identifier": "Any", "name": "_outputs"}) + _inputs.append({"identifier": "workgraph.any", "name": "_wait", "link_limit": 1e6}) + task_outputs.append({"identifier": "workgraph.any", "name": "_wait"}) + task_outputs.append({"identifier": "workgraph.any", "name": "_outputs"}) tdata = { "node_class": Task, "identifier": identifier, @@ -529,7 +539,7 @@ def decorator(func): func, identifier, inputs or [], - outputs or [{"identifier": "Any", "name": "result"}], + outputs or [{"identifier": "workgraph.any", "name": "result"}], properties or [], catalog, task_type, @@ -572,7 +582,8 @@ def decorator(func): func.identifier = identifier task_outputs = [ - {"identifier": "Any", "name": output["name"]} for output in outputs + {"identifier": "workgraph.any", "name": output["name"]} + for output in outputs ] # print(task_inputs, task_outputs) # diff --git a/aiida_workgraph/executors/qe.py b/aiida_workgraph/executors/qe.py index 275a916e..e803ba17 100644 --- a/aiida_workgraph/executors/qe.py +++ b/aiida_workgraph/executors/qe.py @@ -5,7 +5,7 @@ @task( inputs=[ - {"identifier": "String", "name": "pseudo_family"}, + {"identifier": "workgraph.string", "name": "pseudo_family"}, {"identifier": StructureData, "name": "structure"}, ], outputs=[{"identifier": UpfData, "name": "Pseudo"}], diff --git a/aiida_workgraph/properties/__init__.py b/aiida_workgraph/properties/__init__.py index 895fa725..199f9f06 100644 --- a/aiida_workgraph/properties/__init__.py +++ b/aiida_workgraph/properties/__init__.py @@ -1,3 +1,6 @@ from node_graph.utils import get_entries -property_pool = get_entries(entry_point_name="aiida_workgraph.property") +property_pool = { + **get_entries(entry_point_name="node_graph.property"), + **get_entries(entry_point_name="aiida_workgraph.property"), +} diff --git a/aiida_workgraph/properties/built_in.py b/aiida_workgraph/properties/builtins.py similarity index 84% rename from aiida_workgraph/properties/built_in.py rename to aiida_workgraph/properties/builtins.py index 7306fc84..12f8111c 100644 --- a/aiida_workgraph/properties/built_in.py +++ b/aiida_workgraph/properties/builtins.py @@ -1,38 +1,14 @@ from typing import Dict, List, Union, Callable from node_graph.property import NodeProperty -from node_graph.serializer import SerializeJson, SerializePickle -from node_graph.properties.builtin import ( - VectorProperty, - BaseDictProperty, - BaseListProperty, - IntProperty, - BoolProperty, - FloatProperty, - StringProperty, -) +from node_graph.serializer import SerializeJson +from node_graph.properties.builtins import PropertyVector, PropertyAny from aiida import orm -class AnyProperty(NodeProperty, SerializePickle): - """A new class for Any type.""" - - identifier: str = "Any" - data_type = "Any" - - def __init__( - self, - name: str, - description: str = "", - default: Union[int, str, None] = None, - update: Callable = None, - ) -> None: - super().__init__(name, description, default, update) - - -class AiiDAIntProperty(NodeProperty, SerializeJson): +class PropertyAiiDAInt(NodeProperty, SerializeJson): """A new class for integer type.""" - identifier: str = "AiiDAInt" + identifier: str = "workgraph.aiida_int" data_type = "Int" def __init__( @@ -68,10 +44,10 @@ def set_value(self, value: Union[int, orm.Int, str]) -> None: raise Exception("{} is not a integer.".format(value)) -class AiiDAFloatProperty(NodeProperty, SerializeJson): +class PropertyAiiDAFloat(NodeProperty, SerializeJson): """A new class for float type.""" - identifier: str = "AiiDAFloat" + identifier: str = "workgraph.aiida_float" data_type = "Float" def __init__( @@ -107,10 +83,10 @@ def set_value(self, value: Union[float, orm.Float, int, orm.Int, str]) -> None: raise Exception("{} is not a float.".format(value)) -class AiiDABoolProperty(NodeProperty, SerializeJson): +class PropertyAiiDABool(NodeProperty, SerializeJson): """A new class for bool type.""" - identifier: str = "AiiDABool" + identifier: str = "workgraph.aiida_bool" data_type = "Bool" def __init__( @@ -146,10 +122,10 @@ def set_value(self, value: Union[bool, orm.Bool, int, str]) -> None: raise Exception("{} is not a bool.".format(value)) -class AiiDAStringProperty(NodeProperty, SerializeJson): +class PropertyAiiDAString(NodeProperty, SerializeJson): """A new class for string type.""" - identifier: str = "AiiDAString" + identifier: str = "workgraph.aiida_string" data_type = "String" def __init__( @@ -184,10 +160,10 @@ def set_value(self, value: Union[str, orm.Str]) -> None: raise Exception("{} is not a string.".format(value)) -class AiiDADictProperty(NodeProperty, SerializeJson): +class PropertyAiiDADict(NodeProperty, SerializeJson): """A new class for Dict type.""" - identifier: str = "AiiDADict" + identifier: str = "workgraph.aiida_dict" data_type = "Dict" def __init__( @@ -222,10 +198,10 @@ def set_value(self, value: Union[Dict, orm.Dict, str]) -> None: raise Exception("{} is not a dict.".format(value)) -class AiiDAIntVectorProperty(VectorProperty): +class PropertyAiiDAIntVector(PropertyVector): """A new class for integer vector type.""" - identifier: str = "AiiDAIntVector" + identifier: str = "workgraph.aiida_int_vector" data_type = "AiiDAIntVector" def __init__( @@ -260,10 +236,10 @@ def set_value(self, value: List[int]) -> None: ) -class AiiDAFloatVectorProperty(VectorProperty): +class PropertyAiiDAFloatVector(PropertyVector): """A new class for float vector type.""" - identifier: str = "AiiDAFloatVector" + identifier: str = "workgraph.aiida_float_vector" data_type = "AiiDAFloatVector" def __init__( @@ -304,10 +280,10 @@ def get_metadata(self): # Vector -class BoolVectorProperty(VectorProperty): +class PropertyBoolVector(PropertyVector): """A new class for bool vector type.""" - identifier: str = "BoolVector" + identifier: str = "workgraph.bool_vector" data_type = "BoolVector" def __init__( @@ -340,19 +316,14 @@ def set_value(self, value: List[Union[bool, int]]) -> None: ) -property_list = [ - IntProperty, - FloatProperty, - BoolProperty, - StringProperty, - AnyProperty, - BaseDictProperty, - BaseListProperty, - AiiDAIntProperty, - AiiDAFloatProperty, - AiiDAStringProperty, - AiiDABoolProperty, - AiiDADictProperty, - AiiDAIntVectorProperty, - AiiDAFloatVectorProperty, +__all__ = [ + PropertyAny, + PropertyAiiDAInt, + PropertyAiiDAFloat, + PropertyAiiDABool, + PropertyAiiDAString, + PropertyAiiDADict, + PropertyAiiDAIntVector, + PropertyAiiDAFloatVector, + PropertyBoolVector, ] diff --git a/aiida_workgraph/sockets/__init__.py b/aiida_workgraph/sockets/__init__.py index 4a661495..82b2a9b8 100644 --- a/aiida_workgraph/sockets/__init__.py +++ b/aiida_workgraph/sockets/__init__.py @@ -1,3 +1,6 @@ from node_graph.utils import get_entries -socket_pool = get_entries(entry_point_name="aiida_workgraph.socket") +socket_pool = { + **get_entries(entry_point_name="node_graph.socket"), + **get_entries(entry_point_name="aiida_workgraph.socket"), +} diff --git a/aiida_workgraph/sockets/built_in.py b/aiida_workgraph/sockets/builtins.py similarity index 72% rename from aiida_workgraph/sockets/built_in.py rename to aiida_workgraph/sockets/builtins.py index 2303fe1e..2401cf6e 100644 --- a/aiida_workgraph/sockets/built_in.py +++ b/aiida_workgraph/sockets/builtins.py @@ -1,38 +1,24 @@ from typing import Optional, Any from aiida_workgraph.socket import TaskSocket from node_graph.serializer import SerializeJson, SerializePickle -from node_graph.sockets.builtin import ( - SocketInt, - SocketFloat, - SocketString, - SocketBool, - SocketBaseDict, - SocketBaseList, -) class SocketAny(TaskSocket, SerializePickle): - """Socket for any time.""" + """Any socket.""" - identifier: str = "Any" + identifier: str = "workgraph.any" def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any + self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs ) -> None: super().__init__(name, node, type, index, uuid=uuid) - self.add_property("Any", name, **kwargs) + self.add_property("workgraph.any", name, **kwargs) class SocketNamespace(TaskSocket, SerializePickle): """Namespace socket.""" - identifier: str = "Namespace" + identifier: str = "workgraph.namespace" def __init__( self, @@ -44,13 +30,13 @@ def __init__( **kwargs: Any ) -> None: super().__init__(name, node, type, index, uuid=uuid) - self.add_property("Any", name, **kwargs) + self.add_property("workgraph.any", name, **kwargs) class SocketAiiDAFloat(TaskSocket, SerializeJson): """AiiDAFloat socket.""" - identifier: str = "AiiDAFloat" + identifier: str = "workgraph.aiida_float" def __init__( self, @@ -62,13 +48,13 @@ def __init__( **kwargs: Any ) -> None: super().__init__(name, node, type, index, uuid=uuid) - self.add_property("AiiDAFloat", name, **kwargs) + self.add_property("workgraph.aiida_float", name, **kwargs) class SocketAiiDAInt(TaskSocket, SerializeJson): """AiiDAInt socket.""" - identifier: str = "AiiDAInt" + identifier: str = "workgraph.aiida_int" def __init__( self, @@ -80,13 +66,13 @@ def __init__( **kwargs: Any ) -> None: super().__init__(name, node, type, index, uuid=uuid) - self.add_property("AiiDAInt", name, **kwargs) + self.add_property("workgraph.aiida_int", name, **kwargs) class SocketAiiDAString(TaskSocket, SerializeJson): """AiiDAString socket.""" - identifier: str = "AiiDAString" + identifier: str = "workgraph.aiida_string" def __init__( self, @@ -104,7 +90,7 @@ def __init__( class SocketAiiDABool(TaskSocket, SerializeJson): """AiiDABool socket.""" - identifier: str = "AiiDABool" + identifier: str = "workgraph.aiida_bool" def __init__( self, @@ -122,7 +108,7 @@ def __init__( class SocketAiiDAIntVector(TaskSocket, SerializeJson): """Socket with a AiiDAIntVector property.""" - identifier: str = "AiiDAIntVector" + identifier: str = "workgraph.aiida_int_vector" def __init__( self, @@ -140,7 +126,7 @@ def __init__( class SocketAiiDAFloatVector(TaskSocket, SerializeJson): """Socket with a FloatVector property.""" - identifier: str = "FloatVector" + identifier: str = "workgraph.aiida_float_vector" def __init__( self, @@ -153,21 +139,3 @@ def __init__( ) -> None: super().__init__(name, node, type, index, uuid=uuid) self.add_property("FloatVector", name, **kwargs) - - -socket_list = [ - SocketAny, - SocketNamespace, - SocketInt, - SocketFloat, - SocketString, - SocketBool, - SocketBaseDict, - SocketBaseList, - SocketAiiDAInt, - SocketAiiDAFloat, - SocketAiiDAString, - SocketAiiDABool, - SocketAiiDAIntVector, - SocketAiiDAFloatVector, -] diff --git a/aiida_workgraph/tasks/__init__.py b/aiida_workgraph/tasks/__init__.py index ef890f2a..9230d873 100644 --- a/aiida_workgraph/tasks/__init__.py +++ b/aiida_workgraph/tasks/__init__.py @@ -1,45 +1,7 @@ from node_graph.utils import get_entries -from .builtin import AiiDAGather, AiiDAToCtx, AiiDAFromCtx, While -from .test import ( - AiiDAInt, - AiiDAFloat, - AiiDAString, - AiiDAList, - AiiDADict, - AiiDANode, - AiiDACode, - AiiDAAdd, - AiiDAGreater, - AiiDASumDiff, - AiiDAArithmeticMultiplyAdd, -) -from .qe import ( - AiiDAKpoint, - AiiDAPWPseudo, - AiiDAStructure, -) - -task_list = [ - While, - AiiDAGather, - AiiDAToCtx, - AiiDAFromCtx, - AiiDAInt, - AiiDAFloat, - AiiDAString, - AiiDAList, - AiiDADict, - AiiDANode, - AiiDACode, - AiiDAAdd, - AiiDAGreater, - AiiDASumDiff, - AiiDAArithmeticMultiplyAdd, - AiiDAKpoint, - AiiDAPWPseudo, - AiiDAStructure, -] - # should after task_list, otherwise circular import -task_pool = get_entries(entry_point_name="aiida_workgraph.task") +task_pool = { + **get_entries(entry_point_name="aiida_workgraph.task"), + **get_entries(entry_point_name="node_graph.task"), +} diff --git a/aiida_workgraph/tasks/builtin.py b/aiida_workgraph/tasks/builtin.py deleted file mode 100644 index b4a29d44..00000000 --- a/aiida_workgraph/tasks/builtin.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Dict -from aiida_workgraph.task import Task - - -class While(Task): - """While""" - - identifier = "While" - name = "While" - node_type = "WHILE" - catalog = "Control" - kwargs = ["max_iterations", "conditions", "tasks"] - - def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - inp = self.inputs.new("Any", "_wait") - inp.link_limit = 100000 - self.inputs.new("Int", "max_iterations") - self.inputs.new("Any", "tasks") - self.inputs.new("Any", "conditions") - self.outputs.new("Any", "_wait") - - -class AiiDAGather(Task): - """AiiDAGather""" - - identifier = "AiiDAGather" - name = "AiiDAGather" - node_type = "WORKCHAIN" - catalog = "AiiDA" - kwargs = ["datas"] - - def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - inp = self.inputs.new("Any", "datas") - inp.link_limit = 100000 - self.outputs.new("Any", "result") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida_workgraph.executors.builtin", - "name": "GatherWorkChain", - } - - -class AiiDAToCtx(Task): - """AiiDAToCtx""" - - identifier = "ToCtx" - name = "ToCtx" - node_type = "Control" - catalog = "AiiDA" - args = ["key", "value"] - - def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("Any", "key") - self.inputs.new("Any", "value") - self.outputs.new("Any", "result") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "builtins", - "name": "setattr", - } - - -class AiiDAFromCtx(Task): - """AiiDAFromCtx""" - - identifier = "FromCtx" - name = "FromCtx" - node_type = "Control" - catalog = "AiiDA" - args = ["key"] - - def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("Any", "key") - self.outputs.new("Any", "result") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "builtins", - "name": "getattr", - } diff --git a/aiida_workgraph/tasks/builtins.py b/aiida_workgraph/tasks/builtins.py new file mode 100644 index 00000000..86bab6f6 --- /dev/null +++ b/aiida_workgraph/tasks/builtins.py @@ -0,0 +1,243 @@ +from typing import Dict +from aiida_workgraph.task import Task + + +class While(Task): + """While""" + + identifier = "workgraph.while" + name = "While" + node_type = "WHILE" + catalog = "Control" + kwargs = ["max_iterations", "conditions", "tasks"] + + def create_sockets(self) -> None: + self.inputs.clear() + self.outputs.clear() + inp = self.inputs.new("workgraph.any", "_wait") + inp.link_limit = 100000 + self.inputs.new("node_graph.int", "max_iterations") + self.inputs.new("workgraph.any", "tasks") + self.inputs.new("workgraph.any", "conditions") + self.outputs.new("workgraph.any", "_wait") + + +class Gather(Task): + """Gather""" + + identifier = "workgraph.aiida_gather" + name = "Gather" + node_type = "WORKCHAIN" + catalog = "Control" + kwargs = ["datas"] + + def create_sockets(self) -> None: + self.inputs.clear() + self.outputs.clear() + inp = self.inputs.new("workgraph.any", "datas") + inp.link_limit = 100000 + self.outputs.new("workgraph.any", "result") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "aiida_workgraph.executors.builtin", + "name": "GatherWorkChain", + } + + +class ToCtx(Task): + """ToCtx""" + + identifier = "workgraph.to_ctx" + name = "ToCtx" + node_type = "Control" + catalog = "Control" + args = ["key", "value"] + + def create_sockets(self) -> None: + self.inputs.clear() + self.outputs.clear() + self.inputs.new("workgraph.any", "key") + self.inputs.new("workgraph.any", "value") + self.outputs.new("workgraph.any", "result") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "builtins", + "name": "setattr", + } + + +class FromCtx(Task): + """FromCtx""" + + identifier = "workgraph.from_ctx" + name = "FromCtx" + node_type = "Control" + catalog = "Control" + args = ["key"] + + def create_sockets(self) -> None: + self.inputs.clear() + self.outputs.clear() + self.inputs.new("workgraph.any", "key") + self.outputs.new("workgraph.any", "result") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "builtins", + "name": "getattr", + } + + +class AiiDAInt(Task): + identifier = "workgraph.aiida_int" + name = "AiiDAInt" + node_type = "data" + catalog = "Test" + + args = ["value"] + + def create_sockets(self) -> None: + inp = self.inputs.new("workgraph.any", "value", default=0.0) + inp.add_property("workgraph.aiida_int", default=1.0) + self.outputs.new("workgraph.aiida_int", "result") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "aiida.orm", + "name": "Int", + } + + +class AiiDAFloat(Task): + identifier = "workgraph.aiida_float" + name = "AiiDAFloat" + node_type = "data" + catalog = "Test" + + args = ["value"] + + def create_sockets(self) -> None: + self.inputs.new("workgraph.aiida_float", "value", default=0.0) + self.outputs.new("workgraph.aiida_float", "result") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "aiida.orm", + "name": "Float", + } + + +class AiiDAString(Task): + identifier = "workgraph.aiida_string" + name = "AiiDAString" + node_type = "data" + catalog = "Test" + + args = ["value"] + + def create_sockets(self) -> None: + self.inputs.new("AiiDAString", "value", default="") + self.outputs.new("AiiDAString", "result") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "aiida.orm", + "name": "Str", + } + + +class AiiDAList(Task): + identifier = "workgraph.aiida_list" + name = "AiiDAList" + node_type = "data" + catalog = "Test" + + args = ["value"] + + def create_properties(self) -> None: + self.properties.new("BaseList", "value", default=[]) + + def create_sockets(self) -> None: + self.outputs.new("workgraph.any", "Parameters") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "aiida.orm", + "name": "List", + } + + +class AiiDADict(Task): + identifier = "workgraph.aiida_dict" + name = "AiiDADict" + node_type = "data" + catalog = "Test" + + args = ["value"] + + def create_properties(self) -> None: + self.properties.new("BaseDict", "value", default={}) + + def create_sockets(self) -> None: + self.outputs.new("workgraph.any", "Parameters") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "aiida.orm", + "name": "Dict", + } + + +class AiiDANode(Task): + """AiiDANode""" + + identifier = "workgraph.aiida_node" + name = "AiiDANode" + node_type = "node" + catalog = "Test" + kwargs = ["identifier", "pk", "uuid", "label"] + + def create_properties(self) -> None: + pass + + def create_sockets(self) -> None: + self.inputs.clear() + self.outputs.clear() + self.inputs.new("workgraph.any", "identifier") + self.inputs.new("workgraph.any", "pk") + self.inputs.new("workgraph.any", "uuid") + self.inputs.new("workgraph.any", "label") + self.outputs.new("workgraph.any", "node") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "aiida.orm", + "name": "load_node", + } + + +class AiiDACode(Task): + """AiiDACode""" + + identifier = "workgraph.aiida_code" + name = "AiiDACode" + node_type = "node" + catalog = "Test" + kwargs = ["identifier", "pk", "uuid", "label"] + + def create_sockets(self) -> None: + self.inputs.clear() + self.outputs.clear() + self.inputs.new("workgraph.any", "identifier") + self.inputs.new("workgraph.any", "pk") + self.inputs.new("workgraph.any", "uuid") + self.inputs.new("workgraph.any", "label") + self.outputs.new("workgraph.any", "Code") + + def get_executor(self) -> Dict[str, str]: + return { + "path": "aiida.orm", + "name": "load_code", + } diff --git a/aiida_workgraph/tasks/qe.py b/aiida_workgraph/tasks/qe.py deleted file mode 100644 index 8abc36c1..00000000 --- a/aiida_workgraph/tasks/qe.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Dict -from aiida_workgraph.task import Task - - -class AiiDAKpoint(Task): - identifier = "AiiDAKpoint" - name = "AiiDAKpoint" - node_type = "data" - catalog = "Test" - - kwargs = ["mesh", "offset"] - - def create_properties(self) -> None: - self.properties.new("AiiDAIntVector", "mesh", default=[1, 1, 1], size=3) - self.properties.new("AiiDAIntVector", "offset", default=[0, 0, 0], size=3) - - def create_sockets(self) -> None: - self.outputs.new("Any", "Kpoint") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "KpointsData", - } - - -class AiiDAStructure(Task): - identifier = "AiiDAStructure" - name = "AiiDAStructure" - node_type = "data" - catalog = "Test" - - kwargs = ["cell", "kinds", "pbc1", "pbc2", "pbc3", "sites"] - - def create_properties(self) -> None: - self.properties.new("BaseList", "cell", default=[]) - self.properties.new("BaseList", "kinds", default=[]) - self.properties.new("Bool", "pbc1", default=True) - self.properties.new("Bool", "pbc2", default=True) - self.properties.new("Bool", "pbc3", default=True) - self.properties.new("BaseList", "sites", default=[]) - - def create_sockets(self) -> None: - self.outputs.new("Any", "Structure") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "StructureData", - } - - -class AiiDAPWPseudo(Task): - identifier = "AiiDAPWPseudo" - name = "AiiDAPWPseudo" - node_type = "Normal" - catalog = "Test" - - args = ["psuedo_familay", "structure"] - - def create_properties(self) -> None: - self.properties.new( - "AiiDAString", "psuedo_familay", default="SSSP/1.3/PBEsol/efficiency" - ) - - def create_sockets(self) -> None: - self.inputs.new("Any", "structure") - self.outputs.new("Any", "Pseudo") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida_workgraph.executors.qe", - "name": "get_pseudo_from_structure", - "type": "function", - } diff --git a/aiida_workgraph/tasks/test.py b/aiida_workgraph/tasks/test.py index 7ea58597..fde54802 100644 --- a/aiida_workgraph/tasks/test.py +++ b/aiida_workgraph/tasks/test.py @@ -2,175 +2,10 @@ from aiida_workgraph.task import Task -class AiiDAInt(Task): - identifier = "AiiDAInt" - name = "AiiDAInt" - node_type = "data" - catalog = "Test" - - args = ["value"] - kwargs = ["t"] - - def create_properties(self) -> None: - self.properties.new("AiiDAFloat", "t", default=1.0) - - def create_sockets(self) -> None: - inp = self.inputs.new("Any", "value", default=0.0) - inp.add_property("AiiDAInt", default=1.0) - self.outputs.new("AiiDAInt", "result") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "Int", - } - - -class AiiDAFloat(Task): - identifier = "AiiDAFloat" - name = "AiiDAFloat" - node_type = "data" - catalog = "Test" - - args = ["value"] - kwargs = ["t"] - - def create_properties(self) -> None: - self.properties.new("AiiDAFloat", "t", default=1.0) - - def create_sockets(self) -> None: - self.inputs.new("AiiDAFloat", "value", default=0.0) - self.outputs.new("AiiDAFloat", "result") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "Float", - } - - -class AiiDAString(Task): - identifier = "AiiDAString" - name = "AiiDAString" - node_type = "data" - catalog = "Test" - - args = ["value"] - kwargs = ["t"] - - def create_properties(self) -> None: - self.properties.new("AiiDAFloat", "t", default=1.0) - - def create_sockets(self) -> None: - self.inputs.new("AiiDAString", "value", default="") - self.outputs.new("AiiDAString", "result") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "Str", - } - - -class AiiDAList(Task): - identifier = "AiiDAList" - name = "AiiDAList" - node_type = "data" - catalog = "Test" - - args = ["value"] - - def create_properties(self) -> None: - self.properties.new("BaseList", "value", default=[]) - - def create_sockets(self) -> None: - self.outputs.new("Any", "Parameters") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "List", - } - - -class AiiDADict(Task): - identifier = "AiiDADict" - name = "AiiDADict" - node_type = "data" - catalog = "Test" - - args = ["value"] - - def create_properties(self) -> None: - self.properties.new("BaseDict", "value", default={}) - - def create_sockets(self) -> None: - self.outputs.new("Any", "Parameters") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "Dict", - } - - -class AiiDANode(Task): - """AiiDANode""" - - identifier = "AiiDANode" - name = "AiiDANode" - node_type = "node" - catalog = "Test" - kwargs = ["identifier", "pk", "uuid", "label"] - - def create_properties(self) -> None: - pass - - def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("Any", "identifier") - self.inputs.new("Any", "pk") - self.inputs.new("Any", "uuid") - self.inputs.new("Any", "label") - self.outputs.new("Any", "node") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "load_node", - } - - -class AiiDACode(Task): - """AiiDACode""" - - identifier = "AiiDACode" - name = "AiiDACode" - node_type = "node" - catalog = "Test" - kwargs = ["identifier", "pk", "uuid", "label"] - - def create_sockets(self) -> None: - self.inputs.clear() - self.outputs.clear() - self.inputs.new("Any", "identifier") - self.inputs.new("Any", "pk") - self.inputs.new("Any", "uuid") - self.inputs.new("Any", "label") - self.outputs.new("Any", "Code") - - def get_executor(self) -> Dict[str, str]: - return { - "path": "aiida.orm", - "name": "load_code", - } - - -class AiiDAAdd(Task): +class TestAdd(Task): - identifier: str = "AiiDAAdd" - name = "AiiDAAdd" + identifier: str = "workgraph.test_add" + name = "TestAAdd" node_type = "CALCFUNCTION" catalog = "Test" @@ -178,16 +13,16 @@ class AiiDAAdd(Task): kwargs = ["t"] def create_properties(self) -> None: - self.properties.new("AiiDAFloat", "t", default=1.0) + self.properties.new("workgraph.aiida_float", "t", default=1.0) def create_sockets(self) -> None: self.inputs.clear() self.outputs.clear() - inp = self.inputs.new("AiiDAFloat", "x") - inp.add_property("AiiDAFloat", "x", default=0.0) - inp = self.inputs.new("AiiDAFloat", "y") - inp.add_property("AiiDAFloat", "y", default=0.0) - self.outputs.new("AiiDAFloat", "sum") + inp = self.inputs.new("workgraph.aiida_float", "x") + inp.add_property("workgraph.aiida_float", "x", default=0.0) + inp = self.inputs.new("workgraph.aiida_float", "y") + inp.add_property("workgraph.aiida_float", "y", default=0.0) + self.outputs.new("workgraph.aiida_float", "sum") def get_executor(self) -> Dict[str, str]: return { @@ -196,10 +31,10 @@ def get_executor(self) -> Dict[str, str]: } -class AiiDAGreater(Task): +class TestGreater(Task): - identifier: str = "AiiDAGreater" - name = "AiiDAGreater" + identifier: str = "workgraph.test_greater" + name = "TestGreater" node_type = "CALCFUNCTION" catalog = "Test" kwargs = ["x", "y"] @@ -210,8 +45,8 @@ def create_properties(self) -> None: def create_sockets(self) -> None: self.inputs.clear() self.outputs.clear() - self.inputs.new("AiiDAFloat", "x") - self.inputs.new("AiiDAFloat", "y") + self.inputs.new("workgraph.aiida_float", "x") + self.inputs.new("workgraph.aiida_float", "y") self.outputs.new("AiiDABool", "result") def get_executor(self) -> Dict[str, str]: @@ -221,10 +56,10 @@ def get_executor(self) -> Dict[str, str]: } -class AiiDASumDiff(Task): +class TestSumDiff(Task): - identifier: str = "AiiDASumDiff" - name = "AiiDASumDiff" + identifier: str = "workgraph.test_sum_diff" + name = "TestSumDiff" node_type = "CALCFUNCTION" catalog = "Test" @@ -232,17 +67,17 @@ class AiiDASumDiff(Task): kwargs = ["t"] def create_properties(self) -> None: - self.properties.new("AiiDAFloat", "t", default=1.0) + self.properties.new("workgraph.aiida_float", "t", default=1.0) def create_sockets(self) -> None: self.inputs.clear() self.outputs.clear() - inp = self.inputs.new("AiiDAFloat", "x") - inp.add_property("AiiDAFloat", "x", default=0.0) - inp = self.inputs.new("AiiDAFloat", "y") - inp.add_property("AiiDAFloat", "y", default=0.0) - self.outputs.new("AiiDAFloat", "sum") - self.outputs.new("AiiDAFloat", "diff") + inp = self.inputs.new("workgraph.aiida_float", "x") + inp.add_property("workgraph.aiida_float", "x", default=0.0) + inp = self.inputs.new("workgraph.aiida_float", "y") + inp.add_property("workgraph.aiida_float", "y", default=0.0) + self.outputs.new("workgraph.aiida_float", "sum") + self.outputs.new("workgraph.aiida_float", "diff") def get_executor(self) -> Dict[str, str]: return { @@ -251,10 +86,10 @@ def get_executor(self) -> Dict[str, str]: } -class AiiDAArithmeticMultiplyAdd(Task): +class TestArithmeticMultiplyAdd(Task): - identifier: str = "AiiDAArithmeticMultiplyAdd" - name = "AiiDAArithmeticMultiplyAdd" + identifier: str = "workgraph.test_arithmetic_multiply_add" + name = "TestArithmeticMultiplyAdd" node_type = "WORKCHAIN" catalog = "Test" kwargs = ["code", "x", "y", "z"] @@ -265,14 +100,14 @@ def create_properties(self) -> None: def create_sockets(self) -> None: self.inputs.clear() self.outputs.clear() - self.inputs.new("Any", "code") - inp = self.inputs.new("AiiDAInt", "x") - inp.add_property("AiiDAInt", "x", default=0.0) - inp = self.inputs.new("AiiDAInt", "y") - inp.add_property("AiiDAInt", "y", default=0.0) - inp = self.inputs.new("AiiDAInt", "z") - inp.add_property("AiiDAInt", "z", default=0.0) - self.outputs.new("AiiDAInt", "result") + self.inputs.new("workgraph.any", "code") + inp = self.inputs.new("workgraph.aiida_int", "x") + inp.add_property("workgraph.aiida_int", "x", default=0.0) + inp = self.inputs.new("workgraph.aiida_int", "y") + inp.add_property("workgraph.aiida_int", "y", default=0.0) + inp = self.inputs.new("workgraph.aiida_int", "z") + inp.add_property("workgraph.aiida_int", "z", default=0.0) + self.outputs.new("workgraph.aiida_int", "result") def get_executor(self) -> Dict[str, str]: return { diff --git a/docs/source/built-in/pythonjob.ipynb b/docs/source/built-in/pythonjob.ipynb index 05b186d5..662a85c7 100644 --- a/docs/source/built-in/pythonjob.ipynb +++ b/docs/source/built-in/pythonjob.ipynb @@ -1511,22 +1511,22 @@ "\n", "### Defining Namespace Outputs\n", "\n", - "To declare a namespace output, set the `identifier` to `Namespace` in the `outputs` parameter of the `@task` decorator. For example:\n", + "To declare a namespace output, set the `identifier` to `workgraph.namespace` in the `outputs` parameter of the `@task` decorator. For example:\n", "\n", "```python\n", - "@task(outputs=[{\"name\": \"structures\", \"identifier\": \"Namespace\"}])\n", + "@task(outputs=[{\"name\": \"structures\", \"identifier\": \"workgraph.namespace\"}])\n", "def generate_surface_slabs():\n", " # Function logic to generate surface slabs\n", " return {\"slab1\": slab_data1, \"slab2\": slab_data2}\n", "```\n", "\n", "\n", - "One can also define nested namespace outputs by specifying the `identifier` as `Namespace` for sub-dictionaries within the namespace output. For example, here we define `add_multiply.add` as a nested namespace output:\n", + "One can also define nested namespace outputs by specifying the `identifier` as `workgraph.namespace` for sub-dictionaries within the namespace output. For example, here we define `add_multiply.add` as a nested namespace output:\n", "\n", "```python\n", "@task(\n", - " outputs=[{\"name\": \"add_multiply\", \"identifier\": \"Namespace\"},\n", - " {\"name\": \"add_multiply.add\", \"identifier\": \"Namespace\"},\n", + " outputs=[{\"name\": \"add_multiply\", \"identifier\": \"workgraph.namespace\"},\n", + " {\"name\": \"add_multiply.add\", \"identifier\": \"workgraph.namespace\"},\n", " {\"name\": \"minus\"},\n", " ]\n", ")\n", @@ -1548,7 +1548,7 @@ "```python\n", "@task(\n", " outputs=[\n", - " {\"identifier\": \"Namespace\", \"name\": \"add_multiply\"},\n", + " {\"identifier\": \"workgraph.namespace\", \"name\": \"add_multiply\"},\n", " {\"name\": \"add_multiply.add\"},\n", " {\"name\": \"add_multiply.multiply\"},\n", " {\"name\": \"minus\"},\n", @@ -1613,7 +1613,7 @@ "\n", "load_profile()\n", "\n", - "@task(outputs=[{\"name\": \"scaled_atoms\", \"identifier\": \"Namespace\"},\n", + "@task(outputs=[{\"name\": \"scaled_atoms\", \"identifier\": \"workgraph.namespace\"},\n", " {\"name\": \"volumes\"}]\n", ")\n", "def generate_scaled_atoms(atoms: Atoms, scales: list) -> dict:\n", diff --git a/docs/source/development/custom_task.rst b/docs/source/development/custom_task.rst new file mode 100644 index 00000000..f1068f34 --- /dev/null +++ b/docs/source/development/custom_task.rst @@ -0,0 +1,12 @@ +================== +Custom Task +================== + + +Suggested naming rule for entry points: + +- Format: {package_name}.{short_name_for_class}. If the plugin name starts with 'aiida', such as 'aiida-xxx', you may omit 'aiida' for brevity, e.g., 'xxx.add'. Aim to maintain consistency with the name used when registering with the AiiDA registry. +- The name is case-insensitive, thus `xxx.add` is the same as `xxx.Add`. +- Naming Style: Use snake case. For instance, use 'xxx.test_add' for the class 'TestAddTask'. Typically, the suffix 'task' is unnecessary, thus avoid using 'xxx.test_add_task'. + +The rule is used for Task, Socket and Property. diff --git a/docs/source/development/index.rst b/docs/source/development/index.rst index 9bba3495..51ec2787 100644 --- a/docs/source/development/index.rst +++ b/docs/source/development/index.rst @@ -4,28 +4,12 @@ Development This section contains information for developers. -Pre-commit and Tests ---------------------- -To contribute to this repository, please enable pre-commit to ensure that the code in commits conforms to the standards. - -.. code-block:: console - - $ pip install -e .[tests, pre-commit] - $ pre-commit install - -Widget ----------------- -See the `README.md `_. - -Web app ----------------- -See the `README.md `_. - - .. toctree:: :maxdepth: 1 :caption: Other Contents: + custom_task data_serialization + test python_task diff --git a/docs/source/development/test.rst b/docs/source/development/test.rst new file mode 100644 index 00000000..593a860b --- /dev/null +++ b/docs/source/development/test.rst @@ -0,0 +1,18 @@ +====================== +Pre-commit and Tests +====================== + +To contribute to this repository, please enable pre-commit to ensure that the code in commits conforms to the standards. + +.. code-block:: console + + $ pip install -e .[tests, pre-commit] + $ pre-commit install + +Widget +---------------- +See the `README.md `_. + +Web app +---------------- +See the `README.md `_. diff --git a/pyproject.toml b/pyproject.toml index c71565e0..fc9d672c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "numpy~=1.21", "scipy", "ase", - "node-graph>=0.0.11", + "node-graph>=0.0.12", "aiida-core>=2.3", "cloudpickle", "aiida-shell", @@ -94,13 +94,46 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" "process.workflow.workgraph" = "aiida_workgraph.orm.workgraph:WorkGraphNode" [project.entry-points."aiida_workgraph.task"] -"aiida" = "aiida_workgraph.tasks:task_list" +"workgraph.while" = "aiida_workgraph.tasks.builtins:While" +"workgraph.gather" = "aiida_workgraph.tasks.builtins:Gather" +"workgraph.to_ctx" = "aiida_workgraph.tasks.builtins:ToCtx" +"workgraph.from_ctx" = "aiida_workgraph.tasks.builtins:FromCtx" +"workgraph.aiida_int" = "aiida_workgraph.tasks.builtins:AiiDAInt" +"workgraph.aiida_float" = "aiida_workgraph.tasks.builtins:AiiDAFloat" +"workgraph.aiida_string" = "aiida_workgraph.tasks.builtins:AiiDAString" +"workgraph.aiida_list" = "aiida_workgraph.tasks.builtins:AiiDAList" +"workgraph.aiida_dict" = "aiida_workgraph.tasks.builtins:AiiDADict" +"workgraph.aiida_node" = "aiida_workgraph.tasks.builtins:AiiDANode" +"workgraph.aiida_code" = "aiida_workgraph.tasks.builtins:AiiDACode" +"workgraph.test_add" = "aiida_workgraph.tasks.test:TestAdd" +"workgraph.test_greater" = "aiida_workgraph.tasks.test:TestGreater" +"workgraph.test_sum_diff" = "aiida_workgraph.tasks.test:TestSumDiff" +"workgraph.test_arithmetic_multiply_add" = "aiida_workgraph.tasks.test:TestArithmeticMultiplyAdd" [project.entry-points."aiida_workgraph.property"] -"aiida" = "aiida_workgraph.properties.built_in:property_list" +"workgraph.any" = "aiida_workgraph.properties.builtins:PropertyAny" +"workgraph.aiida_int" = "aiida_workgraph.properties.builtins:PropertyAiiDAInt" +"workgraph.aiida_float" = "aiida_workgraph.properties.builtins:PropertyAiiDAFloat" +"workgraph.aiida_string" = "aiida_workgraph.properties.builtins:PropertyAiiDAString" +"workgraph.aiida_bool" = "aiida_workgraph.properties.builtins:PropertyAiiDABool" +"workgraph.aiida_int_vector" = "aiida_workgraph.properties.builtins:PropertyAiiDAIntVector" +"workgraph.aiida_float_vector" = "aiida_workgraph.properties.builtins:PropertyAiiDAFloatVector" +"workgraph.aiida_aiida_dict" = "aiida_workgraph.properties.builtins:PropertyAiiDADict" [project.entry-points."aiida_workgraph.socket"] -"aiida" = "aiida_workgraph.sockets.built_in:socket_list" +"workgraph.any" = "aiida_workgraph.sockets.builtins:SocketAny" +"workgraph.namespace" = "aiida_workgraph.sockets.builtins:SocketNamespace" +"workgraph.int" = "node_graph.sockets.builtins:SocketInt" +"workgraph.float" = "node_graph.sockets.builtins:SocketFloat" +"workgraph.string" = "node_graph.sockets.builtins:SocketString" +"workgraph.bool" = "node_graph.sockets.builtins:SocketBool" +"workgraph.aiida_int" = "aiida_workgraph.sockets.builtins:SocketAiiDAInt" +"workgraph.aiida_float" = "aiida_workgraph.sockets.builtins:SocketAiiDAFloat" +"workgraph.aiida_string" = "aiida_workgraph.sockets.builtins:SocketAiiDAString" +"workgraph.aiida_bool" = "aiida_workgraph.sockets.builtins:SocketAiiDABool" +"workgraph.aiida_int_vector" = "aiida_workgraph.sockets.builtins:SocketAiiDAIntVector" +"workgraph.aiida_float_vector" = "aiida_workgraph.sockets.builtins:SocketAiiDAFloatVector" + [tool.flit.sdist] diff --git a/tests/conftest.py b/tests/conftest.py index a985fcbf..016c9df9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -61,8 +61,8 @@ def wg_calcfunction() -> WorkGraph: """A workgraph with calcfunction.""" wg = WorkGraph(name="test_debug_math") - sumdiff1 = wg.add_task("AiiDASumDiff", "sumdiff1", x=2, y=3) - sumdiff2 = wg.add_task("AiiDASumDiff", "sumdiff2", x=4) + sumdiff1 = wg.add_task("workgraph.test_sum_diff", "sumdiff1", x=2, y=3) + sumdiff2 = wg.add_task("workgraph.test_sum_diff", "sumdiff2", x=4) wg.add_link(sumdiff1.outputs[0], sumdiff2.inputs[1]) return wg @@ -85,14 +85,14 @@ def wg_workchain(add_code) -> WorkGraph: """A workgraph with workchain.""" wg = WorkGraph(name="test_debug_math") - int1 = wg.add_task("AiiDANode", "int1", pk=Int(2).store().pk) - int2 = wg.add_task("AiiDANode", "int2", pk=Int(3).store().pk) - code1 = wg.add_task("AiiDACode", "code1", pk=add_code.pk) + int1 = wg.add_task("workgraph.aiida_node", "int1", pk=Int(2).store().pk) + int2 = wg.add_task("workgraph.aiida_node", "int2", pk=Int(3).store().pk) + code1 = wg.add_task("workgraph.aiida_code", "code1", pk=add_code.pk) multiply_add1 = wg.add_task( - "AiiDAArithmeticMultiplyAdd", "multiply_add1", x=Int(4).store() + "workgraph.test_arithmetic_multiply_add", "multiply_add1", x=Int(4).store() ) multiply_add2 = wg.add_task( - "AiiDAArithmeticMultiplyAdd", + "workgraph.test_arithmetic_multiply_add", "multiply_add2", x=Int(2).store(), y=Int(3).store(), @@ -210,25 +210,6 @@ def structure_si() -> StructureData: return structure_si -@pytest.fixture -def wg_structure_si() -> WorkGraph: - wg = WorkGraph(name="test_structure") - structure1 = wg.add_task("AiiDAStructure", "structure1") - data = { - "cell": [[0.0, 2.715, 2.715], [2.715, 0.0, 2.715], [2.715, 2.715, 0.0]], - "kinds": [{"mass": 28.085, "name": "Si", "symbols": ["Si"], "weights": [1.0]}], - "pbc1": True, - "pbc2": True, - "pbc3": True, - "sites": [ - {"kind_name": "Si", "position": [0.0, 0.0, 0.0]}, - {"kind_name": "Si", "position": [1.3575, 1.3575, 1.3575]}, - ], - } - structure1.set(data) - return wg - - @pytest.fixture def wg_engine(decorated_add, add_code) -> WorkGraph: """Use to test the engine.""" diff --git a/tests/datas/test_calcfunction.yaml b/tests/datas/test_calcfunction.yaml index 989571d2..0723f109 100644 --- a/tests/datas/test_calcfunction.yaml +++ b/tests/datas/test_calcfunction.yaml @@ -1,15 +1,15 @@ name: test_calcfunction description: 'This is a test to run workgraph using yaml file.' metadata: - version: scinode@0.1.0 - platform: scinode + version: workgraph@0.1.0 + platform: aiida-workgraph worker_name: localhost tasks: - - identifier: AiiDAFloat + - identifier: workgraph.aiida_float name: float1 properties: value: 3.0 - - identifier: AiiDASumDiff + - identifier: workgraph.test_sum_diff name: sumdiff1 properties: x: 2.0 @@ -17,7 +17,7 @@ tasks: - to_socket: y from_node: float1 from_socket: 0 - - identifier: AiiDASumDiff + - identifier: workgraph.test_sum_diff name: sumdiff2 properties: x: 4.0 diff --git a/tests/datas/test_calcjob.yaml b/tests/datas/test_calcjob.yaml index deaccb11..3096ce51 100644 --- a/tests/datas/test_calcjob.yaml +++ b/tests/datas/test_calcjob.yaml @@ -1,19 +1,19 @@ name: test_calcjob description: 'This is a test to run workgraph using yaml file.' metadata: - version: scinode@0.1.0 - platform: scinode + version: workgraph@0.1.0 + platform: aiida-workgraph worker_name: localhost tasks: - - identifier: AiiDAInt + - identifier: workgraph.aiida_int name: int1 properties: value: 3 - - identifier: AiiDACode + - identifier: workgraph.aiida_code name: code1 properties: value: "add@localhost" - - identifier: AiiDAArithmeticAdd + - identifier: workgraph.test_arithmetic_add name: add1 properties: x: 2 @@ -24,7 +24,7 @@ tasks: - to_socket: y from_node: int1 from_socket: 0 - - identifier: AiiDAArithmeticAdd + - identifier: workgraph.test_arithmetic_add name: add2 properties: x: 4 diff --git a/tests/test_build_task.py b/tests/test_build_task.py index aed8fe72..1f0297d9 100644 --- a/tests/test_build_task.py +++ b/tests/test_build_task.py @@ -52,8 +52,8 @@ def add_minus(x, y): AddTask = build_task( add_minus, outputs=[ - {"identifier": "Any", "name": "sum"}, - {"identifier": "Any", "name": "difference"}, + {"identifier": "workgraph.any", "name": "sum"}, + {"identifier": "workgraph.any", "name": "difference"}, ], ) assert issubclass(AddTask, Task) diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 83dba3af..a09e947c 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -184,9 +184,9 @@ def test_decorator_workfunction(decorated_add_multiply: Callable) -> None: def test_decorator_graph_builder(decorated_add_multiply_group: Callable) -> None: """Test graph build.""" wg = WorkGraph("test_graph_builder") - add1 = wg.add_task("AiiDAAdd", "add1", x=2, y=3) + add1 = wg.add_task("workgraph.test_add", "add1", x=2, y=3) add_multiply1 = wg.add_task(decorated_add_multiply_group, "add_multiply1", y=3, z=4) - sum_diff1 = wg.add_task("AiiDASumDiff", "sum_diff1") + sum_diff1 = wg.add_task("workgraph.test_sum_diff", "sum_diff1") wg.add_link(add1.outputs[0], add_multiply1.inputs["x"]) wg.add_link(add_multiply1.outputs["result"], sum_diff1.inputs["x"]) wg.submit(wait=True) diff --git a/tests/test_link.py b/tests/test_link.py index 34558431..e87c98b3 100644 --- a/tests/test_link.py +++ b/tests/test_link.py @@ -16,10 +16,10 @@ def sum(datas): return Float(total) wg = WorkGraph(name="test_multiply_link") - float1 = wg.add_task("AiiDANode", pk=Float(1.0).store().pk) - float2 = wg.add_task("AiiDANode", pk=Float(2.0).store().pk) - float3 = wg.add_task("AiiDANode", pk=Float(3.0).store().pk) - gather1 = wg.add_task("AiiDAGather", "gather1") + float1 = wg.add_task("workgraph.aiida_node", pk=Float(1.0).store().pk) + float2 = wg.add_task("workgraph.aiida_node", pk=Float(2.0).store().pk) + float3 = wg.add_task("workgraph.aiida_node", pk=Float(3.0).store().pk) + gather1 = wg.add_task("workgraph.gather", "gather1") sum1 = wg.add_task(sum, "sum1") wg.add_link(float1.outputs[0], gather1.inputs[0]) wg.add_link(float2.outputs[0], gather1.inputs[0]) diff --git a/tests/test_python.py b/tests/test_python.py index 579a30ab..da09b150 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -9,8 +9,8 @@ def test_decorator(fixture_localhost): @task.pythonjob( outputs=[ - {"identifier": "Any", "name": "sum"}, - {"identifier": "Any", "name": "diff"}, + {"identifier": "workgraph.any", "name": "sum"}, + {"identifier": "workgraph.any", "name": "diff"}, ] ) def add(x, y): @@ -97,8 +97,8 @@ def test_PythonJob_outputs(fixture_localhost): @task( outputs=[ - {"identifier": "Any", "name": "sum"}, - {"identifier": "Any", "name": "diff"}, + {"identifier": "workgraph.any", "name": "sum"}, + {"identifier": "workgraph.any", "name": "diff"}, ] ) def add(x, y): @@ -120,6 +120,7 @@ def add(x, y): assert wg.tasks["add"].outputs["diff"].value.value == -1 +@pytest.mark.usefixtures("started_daemon_client") def test_PythonJob_namespace_output(fixture_localhost): """Test function with namespace output and input.""" @@ -128,11 +129,11 @@ def test_PythonJob_namespace_output(fixture_localhost): outputs=[ { "name": "add_multiply", - "identifier": "Namespace", + "identifier": "workgraph.namespace", }, { "name": "add_multiply.add", - "identifier": "Namespace", + "identifier": "workgraph.namespace", }, {"name": "minus"}, ] @@ -167,7 +168,7 @@ def test_PythonJob_namespace_output_input(fixture_localhost): # output namespace @task( outputs=[ - {"identifier": "Namespace", "name": "add_multiply"}, + {"identifier": "workgraph.namespace", "name": "add_multiply"}, {"name": "add_multiply.add"}, {"name": "add_multiply.multiply"}, {"name": "minus"}, diff --git a/tests/test_shell.py b/tests/test_shell.py index 7cfb758a..7d550221 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -95,7 +95,7 @@ def parser(self, dirpath): nodes={"expression": job1.outputs["stdout"]}, parser=PickledData(parser), parser_outputs=[ - {"identifier": "Any", "name": "result"} + {"identifier": "workgraph.any", "name": "result"} ], # add a "result" output socket from the parser ) diff --git a/tests/test_socket.py b/tests/test_socket.py index 9a9198fd..94eaaa7f 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -25,7 +25,8 @@ def add(x: int, y: float): assert "is not" in str(excinfo.value) and int.__name__ in str(excinfo.value) # This should be successful add1.set({"x": 1, "y": 2.0}) - wg.submit(wait=True) + # wg.submit(wait=True) + wg.run() assert wg.state.upper() == "FINISHED" assert multiply1.outputs["result"].value == 6.0 diff --git a/tests/test_workchain.py b/tests/test_workchain.py index c6eed195..85f59b1f 100644 --- a/tests/test_workchain.py +++ b/tests/test_workchain.py @@ -31,7 +31,7 @@ def test_build_workchain(add_code): from aiida_workgraph import WorkGraph wg = WorkGraph(name="test_debug_math") - code1 = wg.add_task("AiiDACode", "code1", pk=add_code.pk) + code1 = wg.add_task("workgraph.aiida_code", "code1", pk=add_code.pk) multiply_add1 = wg.add_task( MultiplyAddWorkChain, "multiply_add1", diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index 2d262810..e8ed049e 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -78,7 +78,12 @@ def test_restart(wg_calcfunction): Load the workgraph, modify the task, and restart the workgraph. Only the modified node and its child tasks will be rerun.""" wg = wg_calcfunction - wg.add_task("AiiDASumDiff", "sumdiff3", x=4, y=wg.tasks["sumdiff2"].outputs["sum"]) + wg.add_task( + "workgraph.test_sum_diff", + "sumdiff3", + x=4, + y=wg.tasks["sumdiff2"].outputs["sum"], + ) wg.name = "test_restart_0" wg.submit(wait=True) wg1 = WorkGraph.load(wg.process.pk) @@ -97,7 +102,7 @@ def test_extend_workgraph(decorated_add_multiply_group): from aiida_workgraph import WorkGraph wg = WorkGraph("test_graph_build") - add1 = wg.add_task("AiiDAAdd", "add1", x=2, y=3) + add1 = wg.add_task("workgraph.test_add", "add1", x=2, y=3) add_multiply_wg = decorated_add_multiply_group(x=0, y=4, z=5) # extend workgraph wg.extend(add_multiply_wg, prefix="group_") From 2b568e7c5ef40eb96390a7bd84642dccb9bdea80 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Fri, 9 Aug 2024 22:46:40 +0200 Subject: [PATCH 07/11] `PythonJob` uses task name as the process label (#210) Provide an input `process_label` so that the user can set the process label manually. In the case of WorkGraph, the process label will be set as: `PythonJob<{task['name']}` --- aiida_workgraph/calculations/python.py | 8 +++++++- aiida_workgraph/engine/utils.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/aiida_workgraph/calculations/python.py b/aiida_workgraph/calculations/python.py index 2119dbca..92f2f3d2 100644 --- a/aiida_workgraph/calculations/python.py +++ b/aiida_workgraph/calculations/python.py @@ -50,6 +50,9 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] spec.input( "function_name", valid_type=Str, serializer=to_aiida_type, required=False ) + spec.input( + "process_label", valid_type=Str, serializer=to_aiida_type, required=False + ) spec.input_namespace( "function_kwargs", valid_type=Data, required=False ) # , serializer=serialize_to_aiida_nodes) @@ -137,7 +140,10 @@ def _build_process_label(self) -> str: :returns: The process label to use for ``ProcessNode`` instances. """ - return f"PythonJob<{self.inputs.function_name.value}>" + if self.inputs.process_label: + return self.inputs.process_label.value + else: + return f"PythonJob<{self.inputs.function_name.value}>" def prepare_for_submission(self, folder: Folder) -> CalcInfo: """Prepare the calculation for submission. diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index e62fc5f5..ee091c17 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -95,6 +95,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict: function_kwargs = serialize_to_aiida_nodes(function_kwargs) # transfer the args to kwargs inputs = { + "process_label": f"PythonJob<{task['name']}", "function_source_code": orm.Str(function_source_code), "function_name": orm.Str(function_name), "code": code, From 0976479fd64e9093a17dc8b07aca3f6770ee4e04 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Fri, 9 Aug 2024 23:17:45 +0200 Subject: [PATCH 08/11] `PythonJob`: stores the `inputs.pickle` file before passing to the `local_copy_list` (#211) --- aiida_workgraph/calculations/python.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/aiida_workgraph/calculations/python.py b/aiida_workgraph/calculations/python.py index 92f2f3d2..7a65c563 100644 --- a/aiida_workgraph/calculations/python.py +++ b/aiida_workgraph/calculations/python.py @@ -263,9 +263,10 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: dirpath = pathlib.Path(folder._abspath) with folder.open(filename, "wb") as handle: pickle.dump(input_values, handle) - # create a singlefiledata object for the pickled data - file_data = SinglefileData(file=f"{dirpath}/{filename}") - local_copy_list.append((file_data.uuid, file_data.filename, filename)) + # create a singlefiledata object for the pickled data + file_data = SinglefileData(file=f"{dirpath}/{filename}") + file_data.store() + local_copy_list.append((file_data.uuid, file_data.filename, filename)) codeinfo = CodeInfo() codeinfo.stdin_name = self.options.input_filename From 58ffacafa7a950a4b54ce15f832db5ebfcf4b436 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Fri, 9 Aug 2024 23:18:15 +0200 Subject: [PATCH 09/11] Release 0.3.15 --- aiida_workgraph/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiida_workgraph/__init__.py b/aiida_workgraph/__init__.py index 335a5ea7..05582eda 100644 --- a/aiida_workgraph/__init__.py +++ b/aiida_workgraph/__init__.py @@ -9,6 +9,6 @@ from .task import Task from .decorator import task, build_task -__version__ = "0.3.14" +__version__ = "0.3.15" __all__ = ["WorkGraph", "Task", "task", "build_task"] From e378b19a212d141d2b402f6eaf86a253c6db98f4 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Fri, 9 Aug 2024 23:23:43 +0200 Subject: [PATCH 10/11] Add an entry point for WorkGraphEngine to remove warning --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index fc9d672c..1584b1e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,9 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" [project.entry-points."aiida.parsers"] "workgraph.python" = "aiida_workgraph.calculations.python_parser:PythonParser" +[project.entry-points.'aiida.workflows'] +"workgraph.engine" = "aiida_workgraph.engine.workgraph:WorkGraphEngine" + [project.entry-points."aiida.data"] "workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData" "workgraph.ase.atoms.Atoms" = "aiida_workgraph.orm.atoms:AtomsData" From 1cdaa442e51eb57a2c102c21138c27c5bbc7c2a1 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Sat, 10 Aug 2024 17:54:50 +0200 Subject: [PATCH 11/11] Set the node' label as the workgraph/task's name. (#213) Set the node' label as the workgraph/task's name when the process node is created. This is very useful for querying. For example, for `WorkGraph` and `PythonJob`, the `node.label` will be `abc_name`.Note: `node.label` is mutable, thus user can modify it after. --- aiida_workgraph/calculations/python.py | 6 ++++++ aiida_workgraph/engine/utils.py | 2 +- aiida_workgraph/engine/workgraph.py | 1 + tests/test_python.py | 15 +++++++++------ tests/test_workgraph.py | 1 + 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/aiida_workgraph/calculations/python.py b/aiida_workgraph/calculations/python.py index 7a65c563..8a469863 100644 --- a/aiida_workgraph/calculations/python.py +++ b/aiida_workgraph/calculations/python.py @@ -145,6 +145,12 @@ def _build_process_label(self) -> str: else: return f"PythonJob<{self.inputs.function_name.value}>" + def on_create(self) -> None: + """Called when a Process is created.""" + + super().on_create() + self.node.label = self.inputs.process_label.value + def prepare_for_submission(self, folder: Folder) -> CalcInfo: """Prepare the calculation for submission. diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index ee091c17..6c2d5a95 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -95,7 +95,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict: function_kwargs = serialize_to_aiida_nodes(function_kwargs) # transfer the args to kwargs inputs = { - "process_label": f"PythonJob<{task['name']}", + "process_label": f"PythonJob<{task['name']}>", "function_source_code": orm.Str(function_source_code), "function_name": orm.Str(function_name), "code": code, diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index 063dc5d7..91dc801b 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -416,6 +416,7 @@ def on_create(self) -> None: ) saver = WorkGraphSaver(self.node, wgdata, restart_process=restart_process) saver.save() + self.node.label = wgdata["name"] def setup(self) -> None: # track if the awaitable callback is added to the runner diff --git a/tests/test_python.py b/tests/test_python.py index da09b150..d0a13906 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -24,23 +24,26 @@ def multiply(x: Any, y: Any) -> Any: wg = WorkGraph("test_PythonJob_outputs") wg.add_task( add, - name="add", + name="add1", x=1, y=2, computer="localhost", ) wg.add_task( decorted_multiply, - name="multiply", - x=wg.tasks["add"].outputs["sum"], + name="multiply1", + x=wg.tasks["add1"].outputs["sum"], y=3, computer="localhost", ) # wg.submit(wait=True) wg.run() - assert wg.tasks["add"].outputs["sum"].value.value == 3 - assert wg.tasks["add"].outputs["diff"].value.value == -1 - assert wg.tasks["multiply"].outputs["result"].value.value == 9 + assert wg.tasks["add1"].outputs["sum"].value.value == 3 + assert wg.tasks["add1"].outputs["diff"].value.value == -1 + assert wg.tasks["multiply1"].outputs["result"].value.value == 9 + # process_label and label + assert wg.tasks["add1"].node.process_label == "PythonJob" + assert wg.tasks["add1"].node.label == "add1" @pytest.mark.usefixtures("started_daemon_client") diff --git a/tests/test_workgraph.py b/tests/test_workgraph.py index e8ed049e..d2d6829f 100644 --- a/tests/test_workgraph.py +++ b/tests/test_workgraph.py @@ -39,6 +39,7 @@ def test_save_load(wg_calcfunction): wg.save() assert wg.process.process_state.value.upper() == "CREATED" assert wg.process.process_label == "WorkGraph" + assert wg.process.label == "test_save_load" wg2 = WorkGraph.load(wg.process.pk) assert len(wg.tasks) == len(wg2.tasks)