Skip to content

Commit

Permalink
Add PickledFunction data (#229)
Browse files Browse the repository at this point in the history
- Update the `GeneralData`
- Implement the `PickledFunction` class as a subclass of `orm.Data`. The class handles:
  - storing and retrieving the serialized function
  - maintaining relevant metadata such as the function name, function source code, module name, Python version, and pickle protocol.
- `PythonJob` adds a `function` input, which accepts a `PickledFunction`.
- Add docs: use the `PythonJob` task outside the WorkGraph to run a Python function on a remote computer, for example inside a `WorkChain` or as a single `CalcJob` calculation.
  • Loading branch information
superstar54 authored Aug 15, 2024
1 parent d6dfb84 commit 21eb8d0
Show file tree
Hide file tree
Showing 10 changed files with 349 additions and 159 deletions.
57 changes: 43 additions & 14 deletions aiida_workgraph/calculations/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RemoteData,
to_aiida_type,
)
from aiida_workgraph.orm.function_data import PickledFunction, to_pickled_function


__all__ = ("PythonJob",)
Expand All @@ -31,6 +32,7 @@ class PythonJob(CalcJob):

_DEFAULT_INPUT_FILE = "script.py"
_DEFAULT_OUTPUT_FILE = "aiida.out"
_DEFAULT_PARENT_FOLDER_NAME = "./parent_folder/"

_default_parser = "workgraph.python"

Expand All @@ -41,6 +43,12 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
:param spec: the calculation job process spec to define.
"""
super().define(spec)
spec.input(
"function",
valid_type=PickledFunction,
serializer=to_pickled_function,
required=False,
)
spec.input(
"function_source_code",
valid_type=Str,
Expand All @@ -57,8 +65,9 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
"function_kwargs", valid_type=Data, required=False
) # , serializer=serialize_to_aiida_nodes)
spec.input(
"output_info",
"function_outputs",
valid_type=List,
default=lambda: List(),
required=False,
serializer=to_aiida_type,
help="The information of the output ports",
Expand All @@ -72,7 +81,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
spec.input(
"parent_folder_name",
valid_type=Str,
default=lambda: Str("./parent_folder/"),
required=False,
serializer=to_aiida_type,
help="""Default name of the subfolder that you want to create in the working directory,
Expand Down Expand Up @@ -140,16 +148,36 @@ def _build_process_label(self) -> str:
:returns: The process label to use for ``ProcessNode`` instances.
"""
if self.inputs.process_label:
if "process_label" in self.inputs:
return self.inputs.process_label.value
else:
return f"PythonJob<{self.inputs.function_name.value}>"
data = self.get_function_data()
return f"PythonJob<{data['function_name']}>"

def on_create(self) -> None:
"""Called when a Process is created."""

super().on_create()
self.node.label = self.inputs.process_label.value
self.node.label = self._build_process_label()

def get_function_data(self) -> dict[str, t.Any]:
    """Return the source code and name of the function to run.

    If a ``function`` input is present, its ``metadata`` mapping is used and
    ``function_source_code`` is set to the import statements followed by the
    decorator-free source code. Otherwise the data is read from the separate
    ``function_source_code`` and ``function_name`` inputs.

    :returns: A dictionary with at least the keys ``function_source_code``
        and ``function_name``.
    """
    if "function" in self.inputs:
        # Work on a shallow copy: the mapping handed out by the input node
        # must not be modified as a side effect of assembling the source.
        metadata = dict(self.inputs.function.metadata)
        metadata["function_source_code"] = (
            metadata["import_statements"]
            + "\n"
            + metadata["function_source_code_without_decorator"]
        )
        return metadata
    return {
        "function_source_code": self.inputs.function_source_code.value,
        "function_name": self.inputs.function_name.value,
    }

def prepare_for_submission(self, folder: Folder) -> CalcInfo:
"""Prepare the calculation for submission.
Expand All @@ -169,21 +197,24 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
inputs = dict(self.inputs.function_kwargs)
else:
inputs = {}
# get the value of pickled function
function_source_code = self.inputs.function_source_code.value
if "parent_folder_name" in self.inputs:
parent_folder_name = self.inputs.parent_folder_name.value
else:
parent_folder_name = self._DEFAULT_PARENT_FOLDER_NAME
function_data = self.get_function_data()
# create python script to run the function
script = f"""
import pickle
# define the function
{function_source_code}
{function_data["function_source_code"]}
# load the inputs from the pickle file
with open('inputs.pickle', 'rb') as handle:
inputs = pickle.load(handle)
# run the function
result = {self.inputs.function_name.value}(**inputs)
result = {function_data["function_name"]}(**inputs)
# save the result as a pickle file
with open('results.pickle', 'wb') as handle:
pickle.dump(result, handle)
Expand Down Expand Up @@ -213,7 +244,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
(
source.computer.uuid,
str(dirpath),
self.inputs.parent_folder_name.value,
parent_folder_name,
)
)
elif isinstance(source, FolderData):
Expand All @@ -222,12 +253,10 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
if self.inputs.parent_output_folder is not None
else ""
)
local_copy_list.append(
(source.uuid, dirname, self.inputs.parent_folder_name.value)
)
local_copy_list.append((source.uuid, dirname, parent_folder_name))
elif isinstance(source, SinglefileData):
local_copy_list.append((source.uuid, source.filename, source.filename))
if self.inputs.upload_files:
if "upload_files" in self.inputs:
upload_files = self.inputs.upload_files
for key, source in upload_files.items():
# replace "_dot_" with "." in the key
Expand Down
17 changes: 9 additions & 8 deletions aiida_workgraph/calculations/python_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class PythonParser(Parser):
def parse(self, **kwargs):
"""Parse the contents of the output files stored in the `retrieved` output node.
The outputs could be a namespace, e.g.,
outputs=[
The function_outputs could be a namespace, e.g.,
function_outputs=[
{"identifier": "workgraph.namespace", "name": "add_multiply"},
{"name": "add_multiply.add"},
{"name": "add_multiply.multiply"},
Expand All @@ -19,11 +19,11 @@ def parse(self, **kwargs):
"""
import pickle

output_info = self.node.inputs.output_info.get_list()
# output_info exclude ['_wait', '_outputs', 'remote_folder', 'remote_stash', 'retrieved']
function_outputs = self.node.inputs.function_outputs.get_list()
# keep only the function_outputs entries other than the built-in ones: ['_wait', '_outputs', 'remote_folder', 'remote_stash', 'retrieved']
self.output_list = [
data
for data in output_info
for data in function_outputs
if data["name"]
not in [
"_wait",
Expand Down Expand Up @@ -86,7 +86,8 @@ def parse(self, **kwargs):
self.out(output["name"], output["value"])
except OSError:
return self.exit_codes.ERROR_READING_OUTPUT_FILE
except ValueError:
except ValueError as exception:
self.logger.error(exception)
return self.exit_codes.ERROR_INVALID_OUTPUT

def find_output(self, name):
Expand All @@ -100,15 +101,15 @@ def serialize_output(self, result, output):
"""Serialize outputs."""

name = output["name"]
if output["identifier"].upper() == "WORKGRAPH.NAMESPACE":
if output.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE":
if isinstance(result, dict):
serialized_result = {}
for key, value in result.items():
full_name = f"{name}.{key}"
full_name_output = self.find_output(full_name)
if (
full_name_output
and full_name_output["identifier"].upper()
and full_name_output.get("identifier", "Any").upper()
== "WORKGRAPH.NAMESPACE"
):
serialized_result[key] = self.serialize_output(
Expand Down
7 changes: 4 additions & 3 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from aiida_workgraph.utils import get_executor, serialize_function
from aiida_workgraph.utils import get_executor
from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain
from aiida import orm
from aiida.orm.nodes.process.calculation.calcfunction import CalcFunctionNode
from aiida.orm.nodes.process.workflow.workfunction import WorkFunctionNode
from aiida.engine.processes.ports import PortNamespace
import cloudpickle as pickle
from aiida_workgraph.task import Task
from aiida_workgraph.orm.function_data import PickledFunction

task_types = {
CalcFunctionNode: "CALCFUNCTION",
Expand Down Expand Up @@ -242,7 +243,7 @@ def build_task_from_AiiDA(
else outputs
)
# get the source code of the function
tdata["executor"] = serialize_function(executor)
tdata["executor"] = PickledFunction(executor).executor
# tdata["executor"]["type"] = tdata["task_type"]
# print("kwargs: ", kwargs)
# add built-in sockets
Expand Down Expand Up @@ -491,7 +492,7 @@ def generate_tdata(
"properties": properties,
"inputs": _inputs,
"outputs": task_outputs,
"executor": serialize_function(func),
"executor": PickledFunction(func).executor,
"catalog": catalog,
}
if additional_data:
Expand Down
4 changes: 2 additions & 2 deletions aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
+ task["executor"]["function_source_code_without_decorator"]
)
# outputs
output_info = task["outputs"]
function_outputs = task["outputs"]
# serialize the kwargs into AiiDA Data
function_kwargs = serialize_to_aiida_nodes(function_kwargs)
# transfer the args to kwargs
Expand All @@ -101,7 +101,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
"code": code,
"function_kwargs": function_kwargs,
"upload_files": new_upload_files,
"output_info": orm.List(output_info),
"function_outputs": orm.List(function_outputs),
"metadata": metadata,
**kwargs,
}
Expand Down
Loading

0 comments on commit 21eb8d0

Please sign in to comment.