From 21eb8d05bd2c6362770c174d18152f1781df30fd Mon Sep 17 00:00:00 2001
From: Xing Wang
Date: Thu, 15 Aug 2024 14:15:35 +0200
Subject: [PATCH] Add `PickledFunction` data (#229)

- Update the `GeneralData` class to store pickled values with cloudpickle and
  to record serializer metadata.
- Implement the `PickledFunction` class as a subclass of `orm.Data`. The class
  handles:
  - storing and retrieving a serialized function;
  - maintaining relevant metadata such as the function name, function source
    code, module name, Python version, and pickle protocol.
- `PythonJob` adds a `function` input, which accepts a `PickledFunction`.
- Add docs: use the `PythonJob` task outside the WorkGraph to run a Python
  function on a remote computer, for example inside a `WorkChain`, or as a
  single `CalcJob` calculation.
---
 aiida_workgraph/calculations/python.py        |  57 +++++--
 aiida_workgraph/calculations/python_parser.py |  17 +-
 aiida_workgraph/decorator.py                  |   7 +-
 aiida_workgraph/engine/utils.py               |   4 +-
 aiida_workgraph/orm/function_data.py          | 155 ++++++++++++++++++
 aiida_workgraph/orm/general_data.py           |  75 ++++-----
 aiida_workgraph/utils/__init__.py             |  89 ----------
 docs/source/built-in/pythonjob.ipynb          |  97 ++++++++++-
 pyproject.toml                                |   1 +
 tests/test_python.py                          |   6 +-
 10 files changed, 349 insertions(+), 159 deletions(-)
 create mode 100644 aiida_workgraph/orm/function_data.py

diff --git a/aiida_workgraph/calculations/python.py b/aiida_workgraph/calculations/python.py
index 8a469863..3935d116 100644
--- a/aiida_workgraph/calculations/python.py
+++ b/aiida_workgraph/calculations/python.py
@@ -17,6 +17,7 @@
     RemoteData,
     to_aiida_type,
 )
+from aiida_workgraph.orm.function_data import PickledFunction, to_pickled_function
 
 __all__ = ("PythonJob",)
 
@@ -31,6 +32,7 @@ class PythonJob(CalcJob):
 
     _DEFAULT_INPUT_FILE = "script.py"
     _DEFAULT_OUTPUT_FILE = "aiida.out"
+    _DEFAULT_PARENT_FOLDER_NAME = "./parent_folder/"
 
     _default_parser = "workgraph.python"
 
@@ -41,6 +43,12 @@ def define(cls, spec: CalcJobProcessSpec) -> None:  # type: ignore[override]
         :param spec: the calculation job process spec to define.
         """
         super().define(spec)
+        spec.input(
+            "function",
+            valid_type=PickledFunction,
+            serializer=to_pickled_function,
+            required=False,
+        )
         spec.input(
             "function_source_code",
             valid_type=Str,
@@ -57,8 +65,9 @@ def define(cls, spec: CalcJobProcessSpec) -> None:  # type: ignore[override]
             "function_kwargs", valid_type=Data, required=False
         )  # , serializer=serialize_to_aiida_nodes)
         spec.input(
-            "output_info",
+            "function_outputs",
             valid_type=List,
+            default=lambda: List(),
             required=False,
             serializer=to_aiida_type,
             help="The information of the output ports",
@@ -72,7 +81,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None:  # type: ignore[override]
         spec.input(
             "parent_folder_name",
             valid_type=Str,
-            default=lambda: Str("./parent_folder/"),
             required=False,
             serializer=to_aiida_type,
             help="""Default name of the subfolder that you want to create in the working directory,
@@ -140,16 +148,36 @@ def _build_process_label(self) -> str:
 
         :returns: The process label to use for ``ProcessNode`` instances.
         """
-        if self.inputs.process_label:
+        if "process_label" in self.inputs:
             return self.inputs.process_label.value
         else:
-            return f"PythonJob<{self.inputs.function_name.value}>"
+            data = self.get_function_data()
+            return f"PythonJob<{data['function_name']}>"
 
     def on_create(self) -> None:
         """Called when a Process is created."""
 
         super().on_create()
-        self.node.label = self.inputs.process_label.value
+        self.node.label = self._build_process_label()
+
+    def get_function_data(self) -> dict[str, t.Any]:
+        """Get the function data.
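+
+        Prefer the pickled ``function`` input when it is provided; otherwise
+        fall back to the explicit ``function_source_code`` and ``function_name`` inputs.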
+ + :returns: The function data. + """ + if "function" in self.inputs: + metadata = self.inputs.function.metadata + metadata["function_source_code"] = ( + metadata["import_statements"] + + "\n" + + metadata["function_source_code_without_decorator"] + ) + return metadata + else: + return { + "function_source_code": self.inputs.function_source_code.value, + "function_name": self.inputs.function_name.value, + } def prepare_for_submission(self, folder: Folder) -> CalcInfo: """Prepare the calculation for submission. @@ -169,21 +197,24 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: inputs = dict(self.inputs.function_kwargs) else: inputs = {} - # get the value of pickled function - function_source_code = self.inputs.function_source_code.value + if "parent_folder_name" in self.inputs: + parent_folder_name = self.inputs.parent_folder_name.value + else: + parent_folder_name = self._DEFAULT_PARENT_FOLDER_NAME + function_data = self.get_function_data() # create python script to run the function script = f""" import pickle # define the function -{function_source_code} +{function_data["function_source_code"]} # load the inputs from the pickle file with open('inputs.pickle', 'rb') as handle: inputs = pickle.load(handle) # run the function -result = {self.inputs.function_name.value}(**inputs) +result = {function_data["function_name"]}(**inputs) # save the result as a pickle file with open('results.pickle', 'wb') as handle: pickle.dump(result, handle) @@ -213,7 +244,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: ( source.computer.uuid, str(dirpath), - self.inputs.parent_folder_name.value, + parent_folder_name, ) ) elif isinstance(source, FolderData): @@ -222,12 +253,10 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: if self.inputs.parent_output_folder is not None else "" ) - local_copy_list.append( - (source.uuid, dirname, self.inputs.parent_folder_name.value) - ) + local_copy_list.append((source.uuid, dirname, parent_folder_name)) elif isinstance(source, SinglefileData): local_copy_list.append((source.uuid, source.filename, source.filename)) - if self.inputs.upload_files: + if "upload_files" in self.inputs: upload_files = self.inputs.upload_files for key, source in upload_files.items(): # replace "_dot_" with "." in the key diff --git a/aiida_workgraph/calculations/python_parser.py b/aiida_workgraph/calculations/python_parser.py index 6c02c4c3..4284f6ab 100644 --- a/aiida_workgraph/calculations/python_parser.py +++ b/aiida_workgraph/calculations/python_parser.py @@ -9,8 +9,8 @@ class PythonParser(Parser): def parse(self, **kwargs): """Parse the contents of the output files stored in the `retrieved` output node. 
-
-        The outputs could be a namespce, e.g.,
-        outputs=[
+
+        The function_outputs could be a namespace, e.g.,
+        function_outputs=[
             {"identifier": "workgraph.namespace", "name": "add_multiply"},
             {"name": "add_multiply.add"},
             {"name": "add_multiply.multiply"},
@@ -19,11 +19,11 @@ def parse(self, **kwargs):
         """
         import pickle
 
-        output_info = self.node.inputs.output_info.get_list()
-        # output_info exclude ['_wait', '_outputs', 'remote_folder', 'remote_stash', 'retrieved']
+        function_outputs = self.node.inputs.function_outputs.get_list()
+        # function_outputs excludes ['_wait', '_outputs', 'remote_folder', 'remote_stash', 'retrieved']
         self.output_list = [
             data
-            for data in output_info
+            for data in function_outputs
             if data["name"]
             not in [
                 "_wait",
@@ -86,7 +86,8 @@ def parse(self, **kwargs):
                     self.out(output["name"], output["value"])
         except OSError:
             return self.exit_codes.ERROR_READING_OUTPUT_FILE
-        except ValueError:
+        except ValueError as exception:
+            self.logger.error(exception)
             return self.exit_codes.ERROR_INVALID_OUTPUT
 
     def find_output(self, name):
@@ -100,7 +101,7 @@ def serialize_output(self, result, output):
         """Serialize outputs."""
 
         name = output["name"]
-        if output["identifier"].upper() == "WORKGRAPH.NAMESPACE":
+        if output.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE":
             if isinstance(result, dict):
                 serialized_result = {}
                 for key, value in result.items():
@@ -108,7 +109,7 @@ def serialize_output(self, result, output):
                     full_name_output = self.find_output(full_name)
                     if (
                         full_name_output
-                        and full_name_output["identifier"].upper()
+                        and full_name_output.get("identifier", "Any").upper()
                         == "WORKGRAPH.NAMESPACE"
                     ):
                         serialized_result[key] = self.serialize_output(
diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py
index 95a21655..68c53fdb 100644
--- a/aiida_workgraph/decorator.py
+++ b/aiida_workgraph/decorator.py
@@ -1,5 +1,5 @@
 from typing import Any, Callable, Dict, List, Optional, Union, Tuple
-from aiida_workgraph.utils import get_executor, serialize_function
+from aiida_workgraph.utils import get_executor
 from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain
 from aiida import orm
 from aiida.orm.nodes.process.calculation.calcfunction import CalcFunctionNode
@@ -7,6 +7,7 @@
 from aiida.engine.processes.ports import PortNamespace
 import cloudpickle as pickle
 from aiida_workgraph.task import Task
+from aiida_workgraph.orm.function_data import PickledFunction
 
 task_types = {
     CalcFunctionNode: "CALCFUNCTION",
@@ -242,7 +243,7 @@ def build_task_from_AiiDA(
         else outputs
     )
     # get the source code of the function
-    tdata["executor"] = serialize_function(executor)
+    tdata["executor"] = PickledFunction(executor).executor
     # tdata["executor"]["type"] = tdata["task_type"]
     # print("kwargs: ", kwargs)
     # add built-in sockets
@@ -491,7 +492,7 @@ def generate_tdata(
         "properties": properties,
         "inputs": _inputs,
         "outputs": task_outputs,
-        "executor": serialize_function(func),
+        "executor": PickledFunction(func).executor,
         "catalog": catalog,
     }
     if additional_data:
diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py
index 6c2d5a95..90b4d89c 100644
--- a/aiida_workgraph/engine/utils.py
+++ b/aiida_workgraph/engine/utils.py
@@ -90,7 +90,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
         + task["executor"]["function_source_code_without_decorator"]
     )
     # outputs
-    output_info = task["outputs"]
+    function_outputs = task["outputs"]
     # serialize the kwargs into AiiDA Data
     function_kwargs = serialize_to_aiida_nodes(function_kwargs)
     # transfer the args to kwargs
@@ -101,7 +101,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
         "code": code,
         "function_kwargs": function_kwargs,
         "upload_files": new_upload_files,
-        "output_info": orm.List(output_info),
+        "function_outputs": orm.List(function_outputs),
         "metadata": metadata,
         **kwargs,
     }
diff --git a/aiida_workgraph/orm/function_data.py b/aiida_workgraph/orm/function_data.py
new file mode 100644
index 00000000..c279387f
--- /dev/null
+++ b/aiida_workgraph/orm/function_data.py
@@ -0,0 +1,155 @@
+import inspect
+import textwrap
+from typing import Callable, Dict, Any, get_type_hints, _SpecialForm
+from .general_data import GeneralData
+
+
+class PickledFunction(GeneralData):
+    """Data class to represent a pickled Python function."""
+
+    def __init__(self, value=None, **kwargs):
+        """Initialize a PickledFunction node instance.
+
+        :param value: a Python function
+        """
+        super().__init__(**kwargs)
+        if not callable(value):
+            raise ValueError("value must be a callable Python function")
+        self.set_value(value)
+        self.set_attribute(value)
+
+    def __str__(self):
+        return (
+            f"PickledFunction<{self.base.attributes.get('function_name')}> pk={self.pk}"
+        )
+
+    @property
+    def metadata(self):
+        """Return a dictionary of metadata."""
+        return {
+            "function_name": self.base.attributes.get("function_name"),
+            "import_statements": self.base.attributes.get("import_statements"),
+            "function_source_code": self.base.attributes.get("function_source_code"),
+            "function_source_code_without_decorator": self.base.attributes.get(
+                "function_source_code_without_decorator"
+            ),
+            "type": "function",
+            "is_pickle": True,
+        }
+
+    @property
+    def executor(self):
+        """Return the executor for this node."""
+        data = self.metadata
+        with self.base.repository.open(self.FILENAME, mode="rb") as f:
+            executor = f.read()
+        data["executor"] = executor
+        return data
+
+    def set_attribute(self, value):
+        """Set the contents of this node by pickling the provided function.
+
+        :param value: The Python function to pickle and store.
+        """
+        # Serialize the function and extract metadata
+        serialized_data = self.serialize_function(value)
+
+        # Store relevant metadata
+        self.base.attributes.set("function_name", serialized_data["function_name"])
+        self.base.attributes.set(
+            "import_statements", serialized_data["import_statements"]
+        )
+        self.base.attributes.set(
+            "function_source_code", serialized_data["function_source_code"]
+        )
+        self.base.attributes.set(
+            "function_source_code_without_decorator",
+            serialized_data["function_source_code_without_decorator"],
+        )
+
+    @classmethod
+    def serialize_function(cls, func: Callable) -> Dict[str, Any]:
+        """Serialize a function for storage or transmission."""
+        try:
+            # we need to save the source code explicitly, because in the case of a Jupyter
+            # notebook, the source code is not saved in the pickle file
+            source_code = inspect.getsource(func)
+            # Split the source into lines for processing
+            source_code_lines = source_code.split("\n")
+            function_source_code = "\n".join(source_code_lines)
+            # Find the first line of the actual function definition
+            for i, line in enumerate(source_code_lines):
+                if line.strip().startswith("def "):
+                    break
+            function_source_code_without_decorator = "\n".join(source_code_lines[i:])
+            function_source_code_without_decorator = textwrap.dedent(
+                function_source_code_without_decorator
+            )
+            # we also need to include the necessary imports for the types used in the type hints.
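+            # For example, ``def f(x: ndarray) -> List[Atoms]`` yields
+            # ``from numpy import ndarray``, ``from typing import List`` and
+            # ``from ase.atoms import Atoms`` (see ``get_required_imports`` below).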
+ try: + required_imports = cls.get_required_imports(func) + except Exception as e: + required_imports = {} + print( + f"Failed to get required imports for function {func.__name__}: {e}" + ) + # Generate import statements + import_statements = "\n".join( + f"from {module} import {', '.join(types)}" + for module, types in required_imports.items() + ) + except Exception as e: + print(f"Failed to serialize function {func.__name__}: {e}") + function_source_code = "" + function_source_code_without_decorator = "" + import_statements = "" + return { + "function_name": func.__name__, + "function_source_code": function_source_code, + "function_source_code_without_decorator": function_source_code_without_decorator, + "import_statements": import_statements, + } + + @classmethod + def get_required_imports(cls, func: Callable) -> Dict[str, set]: + """Retrieve type hints and the corresponding modules.""" + type_hints = get_type_hints(func) + imports = {} + + def add_imports(type_hint): + if isinstance( + type_hint, _SpecialForm + ): # Handle special forms like Any, Union, Optional + module_name = "typing" + type_name = type_hint._name or str(type_hint) + elif hasattr( + type_hint, "__origin__" + ): # This checks for higher-order types like List, Dict + module_name = type_hint.__module__ + type_name = getattr(type_hint, "_name", None) or getattr( + type_hint.__origin__, "__name__", None + ) + for arg in getattr(type_hint, "__args__", []): + if arg is type(None): # noqa: E721 + continue + add_imports(arg) # Recursively add imports for each argument + elif hasattr(type_hint, "__module__"): + module_name = type_hint.__module__ + type_name = type_hint.__name__ + else: + return # If no module or origin, we can't import it, e.g., for literals + + if type_name is not None: + if module_name not in imports: + imports[module_name] = set() + imports[module_name].add(type_name) + + for _, type_hint in type_hints.items(): + add_imports(type_hint) + + return imports + + +def to_pickled_function(value): + """Convert a Python function to a `PickledFunction` instance.""" + return PickledFunction(value) diff --git a/aiida_workgraph/orm/general_data.py b/aiida_workgraph/orm/general_data.py index 9a317215..bcf86b1e 100644 --- a/aiida_workgraph/orm/general_data.py +++ b/aiida_workgraph/orm/general_data.py @@ -1,6 +1,9 @@ """`Data` sub class to represent any data using pickle.""" from aiida import orm +import sys +import cloudpickle +from pickle import UnpicklingError class Dict(orm.Dict): @@ -16,12 +19,14 @@ def value(self): class GeneralData(orm.Data): - """`Data to represent a pickled value.""" + """Data to represent a pickled value using cloudpickle.""" + + FILENAME = "value.pkl" # Class attribute to store the filename def __init__(self, value=None, **kwargs): - """Initialise a ``General`` node instance. + """Initialize a `GeneralData` node instance. - :param value: list to initialise the ``List`` node from + :param value: raw Python value to initialize the `GeneralData` node from. """ super().__init__(**kwargs) self.set_value(value) @@ -33,53 +38,49 @@ def __str__(self): def value(self): """Return the contents of this node. - :return: a value + :return: The unpickled value. """ return self.get_value() @value.setter def value(self, value): - return self.set_value(value) + self.set_value(value) def get_value(self): - """Return the contents of this node. + """Return the contents of this node, unpickling the stored value. - :return: a value + :return: The unpickled value. 
""" - import cloudpickle as pickle - - def get_value_from_file(self): - filename = "value.pkl" - # Open a handle in binary read mode as the arrays are written as binary files as well - with self.base.repository.open(filename, mode="rb") as f: - return pickle.loads(f.read()) # pylint: disable=unexpected-keyword-arg - - # Return with proper caching if the node is stored, otherwise always re-read from disk - return get_value_from_file(self) + return self._get_value_from_file() + + def _get_value_from_file(self): + """Read the pickled value from file and return it.""" + try: + with self.base.repository.open(self.FILENAME, mode="rb") as f: + return cloudpickle.loads(f.read()) # Deserialize the value + except (UnpicklingError, ValueError) as e: + raise ImportError( + "Failed to load the pickled value. This may be due to an incompatible pickle protocol. " + "Please ensure that the correct environment and cloudpickle version are being used." + ) from e + except ModuleNotFoundError as e: + raise ImportError( + "Failed to load the pickled value. This may be due to a missing module. " + "Please ensure that the correct environment and cloudpickle version are being used." + ) from e def set_value(self, value): - """Set the contents of this node. + """Set the contents of this node by pickling the provided value. - :param value: the value to set + :param value: The Python value to pickle and store. """ - import cloudpickle as pickle - import sys + # Serialize the value and store it + serialized_value = cloudpickle.dumps(value) + self.base.repository.put_object_from_bytes(serialized_value, self.FILENAME) - self.base.repository.put_object_from_bytes(pickle.dumps(value), "value.pkl") + # Store relevant metadata python_version = f"{sys.version_info.major}.{sys.version_info.minor}" self.base.attributes.set("python_version", python_version) - - def _using_value_reference(self): - """This function tells the class if we are using a list reference. This - means that calls to self.get_value return a reference rather than a copy - of the underlying list and therefore self.set_value need not be called. - This knwoledge is essential to make sure this class is performant. - - Currently the implementation assumes that if the node needs to be - stored then it is using the attributes cache which is a reference. - - :return: True if using self.get_value returns a reference to the - underlying sequence. False otherwise. 
- :rtype: bool - """ - return self.is_stored + self.base.attributes.set("serializer_module", cloudpickle.__name__) + self.base.attributes.set("serializer_version", cloudpickle.__version__) + self.base.attributes.set("pickle_protocol", cloudpickle.DEFAULT_PROTOCOL) diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index 49c4a134..cc04c555 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -526,95 +526,6 @@ def recursive_to_dict(attr_dict): return attr_dict -def get_required_imports(func): - """Retrieve type hints and the corresponding module""" - from typing import get_type_hints, _SpecialForm - - type_hints = get_type_hints(func) - imports = {} - - def add_imports(type_hint): - if isinstance( - type_hint, _SpecialForm - ): # Handle special forms like Any, Union, Optional - module_name = "typing" - type_name = type_hint._name or str(type_hint) - elif hasattr( - type_hint, "__origin__" - ): # This checks for higher-order types like List, Dict - module_name = type_hint.__module__ - type_name = getattr(type_hint, "_name", None) or getattr( - type_hint.__origin__, "__name__", None - ) - for arg in getattr(type_hint, "__args__", []): - if arg is type(None): # noqa: E721 - continue - add_imports(arg) # Recursively add imports for each argument - elif hasattr(type_hint, "__module__"): - module_name = type_hint.__module__ - type_name = type_hint.__name__ - else: - return # If no module or origin, we can't import it, e.g., for literals - - if type_name is not None: - if module_name not in imports: - imports[module_name] = set() - imports[module_name].add(type_name) - - for _, type_hint in type_hints.items(): - add_imports(type_hint) - - return imports - - -def serialize_function(func: Callable) -> Dict[str, Any]: - """Serialize a function for storage or transmission.""" - import inspect - import textwrap - import cloudpickle as pickle - - try: - # we need save the source code explicitly, because in the case of jupyter notebook, - # the source code is not saved in the pickle file - source_code = inspect.getsource(func) - # Split the source into lines for processing - source_code_lines = source_code.split("\n") - function_source_code = "\n".join(source_code_lines) - # Find the first line of the actual function definition - for i, line in enumerate(source_code_lines): - if line.strip().startswith("def "): - break - function_source_code_without_decorator = "\n".join(source_code_lines[i:]) - function_source_code_without_decorator = textwrap.dedent( - function_source_code_without_decorator - ) - # we also need to include the necessary imports for the types used in the type hints. 
- try: - required_imports = get_required_imports(func) - except Exception as e: - required_imports = {} - print(f"Failed to get required imports for function {func.__name__}: {e}") - # Generate import statements - import_statements = "\n".join( - f"from {module} import {', '.join(types)}" - for module, types in required_imports.items() - ) - except Exception as e: - print(f"Failed to serialize function {func.__name__}: {e}") - function_source_code = "" - function_source_code_without_decorator = "" - import_statements = "" - return { - "executor": pickle.dumps(func), - "type": "function", - "is_pickle": True, - "function_name": func.__name__, - "function_source_code": function_source_code, - "function_source_code_without_decorator": function_source_code_without_decorator, - "import_statements": import_statements, - } - - def workgraph_to_short_json( wgdata: Dict[str, Union[str, List, Dict]] ) -> Dict[str, Union[str, Dict]]: diff --git a/docs/source/built-in/pythonjob.ipynb b/docs/source/built-in/pythonjob.ipynb index 662a85c7..9d86bdd3 100644 --- a/docs/source/built-in/pythonjob.ipynb +++ b/docs/source/built-in/pythonjob.ipynb @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "c6b83fb5", "metadata": {}, "outputs": [ @@ -33,7 +33,7 @@ "Profile" ] }, - "execution_count": 1, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -2398,7 +2398,98 @@ "}\n", "```\n", "\n", - "Save the configuration file as `workgraph.json` in the aiida configuration directory (by default, `~/.aiida` directory)." + "Save the configuration file as `workgraph.json` in the aiida configuration directory (by default, `~/.aiida` directory).\n", + "\n", + "\n", + "## Use PythonJob outside WorkGraph\n", + "One can use the `PythonJob` task outside the WorkGraph to run a Python function on a remote computer. 
For example, inside a `WorkChain`, or as a single `CalcJob` calculation.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9a1fa5e6",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Result:  {'remote_folder': , 'retrieved': , 'add': }\n",
+      "Node:  uuid: b2057442-452f-4d13-9d64-5e42c72a162e (pk: 107693) (aiida.calculations:workgraph.python)\n"
+     ]
+    }
+   ],
+   "source": [
+    "from aiida import orm, load_profile\n",
+    "from aiida.engine import run_get_node\n",
+    "from aiida_workgraph.calculations.python import PythonJob\n",
+    "\n",
+    "load_profile()\n",
+    "\n",
+    "python_code = orm.load_code(\"python3@localhost\")\n",
+    "\n",
+    "def add(x, y):\n",
+    "    return x + y\n",
+    "\n",
+    "result, node = run_get_node(PythonJob, code=python_code,\n",
+    "                            function=add,\n",
+    "                            function_kwargs={\"x\": orm.Int(1), \"y\": orm.Int(2)},\n",
+    "                            function_outputs=[{\"name\": \"add\"}])\n",
+    "\n",
+    "print(\"Result: \", result)\n",
+    "print(\"Node: \", node)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4fb22545",
+   "metadata": {},
+   "source": [
+    "You can see more details on any process, including its inputs and outputs, using the verdi command:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "86e74979",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[22mProperty     Value\n",
+      "-----------  ------------------------------------\n",
+      "type         PythonJob\n",
+      "state        Finished [0]\n",
+      "pk           107693\n",
+      "uuid         b2057442-452f-4d13-9d64-5e42c72a162e\n",
+      "label        PythonJob\n",
+      "description\n",
+      "ctime        2024-08-15 12:38:26.453629+02:00\n",
+      "mtime        2024-08-15 12:38:29.440261+02:00\n",
+      "computer     [1] localhost\n",
+      "\n",
+      "Inputs                  PK  Type\n",
+      "------------------  ------  ---------------\n",
+      "function_kwargs\n",
+      "    x               107689  Int\n",
+      "    y               107690  Int\n",
+      "code                 42316  InstalledCode\n",
+      "function            107688  PickledFunction\n",
+      "function_outputs    107691  List\n",
+      "parent_folder_name  107692  Str\n",
+      "\n",
+      "Outputs            PK  Type\n",
+      "-------------  ------  ----------\n",
+      "add            107697  Int\n",
+      "remote_folder  107695  RemoteData\n",
+      "retrieved      107696  FolderData\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "%verdi process show {node.pk}"
+   ]
+  }
  ],
diff --git a/pyproject.toml b/pyproject.toml
index 7e8316ca..b733668a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -85,6 +85,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
 
 [project.entry-points."aiida.data"]
 "workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData"
+"workgraph.pickled_function" = "aiida_workgraph.orm.function_data:PickledFunction"
 "workgraph.ase.atoms.Atoms" = "aiida_workgraph.orm.atoms:AtomsData"
 "workgraph.builtins.int" = "aiida.orm.nodes.data.int:Int"
 "workgraph.builtins.float" = "aiida.orm.nodes.data.float:Float"
diff --git a/tests/test_python.py b/tests/test_python.py
index ec7ec60e..1043069a 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -76,7 +76,7 @@ def test_PythonJob_typing():
     """Test function with typing."""
     from numpy import array
     from ase import Atoms
-    from aiida_workgraph.utils import get_required_imports
+    from aiida_workgraph.orm.function_data import PickledFunction
     from typing import List
 
     def generate_structures(
@@ -96,14 +96,14 @@ def generate_structures_2(
     ) -> list[Atoms]:
         pass
 
-    modules = get_required_imports(generate_structures)
+    modules = PickledFunction.get_required_imports(generate_structures)
     assert modules == {
         "ase.atoms": {"Atoms"},
         "typing": {"List"},
         "builtins": {"list"},
         "numpy": {"array"},
{"array"}, } - modules = get_required_imports(generate_structures_2) + modules = PickledFunction.get_required_imports(generate_structures_2) assert modules == {"ase.atoms": {"Atoms"}, "builtins": {"list", "str"}}