diff --git a/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md b/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md
new file mode 100644
index 00000000000..9503d61f066
--- /dev/null
+++ b/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md
@@ -0,0 +1,73 @@
+---
+description: Return existing artifacts from steps.
+---
+
+# Return existing artifacts from steps
+
+ZenML allows your pipeline steps to return existing artifact versions that are already registered in the ZenML server. This is particularly useful when you want to optimize the caching behavior of your pipelines.
+
+## Understanding caching behavior
+
+ZenML's caching mechanism uses the IDs of the step input artifacts (among other things) to determine whether a step needs to be re-run or can be cached. By default, when a step produces an output artifact, a new artifact version is registered - even if the output data is identical to a previous run.
+
+This means that steps downstream of a non-cached step will also be re-run, since their input artifact IDs will differ even if the underlying data hasn't changed. To enable better caching, you can return existing artifact versions from your steps instead of always creating new ones. This is useful if you want to run some computation early in the pipeline that decides whether the remaining steps can be cached:
+
+```python
+from typing import Annotated
+
+from zenml import log_metadata, pipeline, step
+from zenml.client import Client
+
+
+# We want to always run this step to decide whether the downstream
+# steps can be cached, so we disable caching for it
+@step(enable_cache=False)
+def compute_cache() -> Annotated[int, "cache_key"]:
+    # Replace this with your custom logic, for example computing a key
+    # from the date of the latest available data point
+    cache_key = 27
+
+    # Check whether an artifact version with the same cache key was
+    # already registered by a previous run
+    artifact_versions = Client().list_artifact_versions(
+        sort_by="desc:created",
+        size=1,
+        name="cache_key",
+        run_metadata={"cache_key_value": cache_key},
+    )
+
+    if artifact_versions:
+        # Return the existing artifact version instead of creating a new one
+        return artifact_versions[0]
+    else:
+        # Log the cache key as metadata on the artifact version so we can
+        # easily fetch it in subsequent runs
+        log_metadata(metadata={"cache_key_value": cache_key}, infer_artifact=True)
+        return cache_key
+
+
+@step
+def downstream_step(cache_key: int):
+    ...
+
+
+# Enable caching for the pipeline
+@pipeline(enable_cache=True)
+def my_pipeline():
+    cache_key = compute_cache()
+    downstream_step(cache_key)
+```
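+
+On the first run of this pipeline, no matching artifact version exists yet, so `compute_cache` registers a new one tagged with the cache key, and `downstream_step` executes. On subsequent runs with an unchanged cache key, `compute_cache` returns the existing artifact version, so `downstream_step` receives the same input artifact ID and can be cached. As a minimal sketch of the intended behavior (assuming the cache key computation yields the same value on both runs):
+
+```python
+# First run: `compute_cache` registers a new artifact version and
+# `downstream_step` executes
+my_pipeline()
+
+# Second run: `compute_cache` returns the existing artifact version, so
+# `downstream_step` sees the same input artifact ID and can be cached
+my_pipeline()
+```
+
+Note that the data type of the artifact version you return must match the output type annotation of the step; otherwise, ZenML raises a `StepInterfaceError`.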
diff --git a/docs/book/toc.md b/docs/book/toc.md
index ec51f95fc4b..b22ade2094b 100644
--- a/docs/book/toc.md
+++ b/docs/book/toc.md
@@ -165,6 +165,7 @@
   * [Datasets in ZenML](how-to/data-artifact-management/complex-usecases/datasets.md)
   * [Manage big data](how-to/data-artifact-management/complex-usecases/manage-big-data.md)
   * [Skipping materialization](how-to/data-artifact-management/complex-usecases/unmaterialized-artifacts.md)
+  * [Return existing artifacts from steps](how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md)
   * [Passing artifacts between pipelines](how-to/data-artifact-management/complex-usecases/passing-artifacts-between-pipelines.md)
   * [Register Existing Data as a ZenML Artifact](how-to/data-artifact-management/complex-usecases/registering-existing-data.md)
   * [Visualizing artifacts](how-to/data-artifact-management/visualize-artifacts/README.md)
diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py
index f18f1c649a2..3c55e094c55 100644
--- a/src/zenml/orchestrators/step_runner.py
+++ b/src/zenml/orchestrators/step_runner.py
@@ -42,6 +42,7 @@
 from zenml.logger import get_logger
 from zenml.logging.step_logging import StepLogsStorageContext, redirected
 from zenml.materializers.base_materializer import BaseMaterializer
+from zenml.models import ArtifactVersionResponse
 from zenml.models.v2.core.step_run import StepRunInputResponse
 from zenml.orchestrators.publish_utils import (
     publish_step_run_metadata,
@@ -540,12 +541,24 @@ def _validate_outputs(
         if is_union(get_origin(output_type)):
             output_type = get_args(output_type)
 
-        if not isinstance(return_value, output_type):
-            raise StepInterfaceError(
-                f"Wrong type for output '{output_name}' of step "
-                f"'{step_name}' (expected type: {output_type}, "
-                f"actual type: {type(return_value)})."
+        if isinstance(return_value, ArtifactVersionResponse):
+            artifact_data_type = source_utils.load(
+                return_value.data_type
             )
+            if not issubclass(artifact_data_type, output_type):
+                raise StepInterfaceError(
+                    f"Wrong type for artifact returned for output "
+                    f"'{output_name}' of step '{step_name}' (expected "
+                    f"type: {output_type}, actual type: "
+                    f"{artifact_data_type})."
+                )
+        else:
+            if not isinstance(return_value, output_type):
+                raise StepInterfaceError(
+                    f"Wrong type for output '{output_name}' of step "
+                    f"'{step_name}' (expected type: {output_type}, "
+                    f"actual type: {type(return_value)})."
+                )
         validated_outputs[output_name] = return_value
 
     return validated_outputs
@@ -575,9 +588,15 @@ def _store_output_artifacts(
             The IDs of the published output artifacts.
""" step_context = get_step_context() - artifact_requests = [] + artifact_requests = {} + + artifacts = {} for output_name, return_value in output_data.items(): + if isinstance(return_value, ArtifactVersionResponse): + artifacts[output_name] = return_value + continue + data_type = type(return_value) materializer_classes = output_materializers[output_name] if materializer_classes: @@ -652,12 +671,13 @@ def _store_output_artifacts( save_type=ArtifactSaveType.STEP_OUTPUT, metadata=user_metadata, ) - artifact_requests.append(artifact_request) + artifact_requests[output_name] = artifact_request responses = Client().zen_store.batch_create_artifact_versions( - artifact_requests + list(artifact_requests.values()) ) - return dict(zip(output_data.keys(), responses)) + artifacts.update(dict(zip(artifact_requests.keys(), responses))) + return artifacts def load_and_run_hook( self, diff --git a/tests/unit/steps/test_base_step.py b/tests/unit/steps/test_base_step.py index 434cdc8fc42..c42c0ff0f31 100644 --- a/tests/unit/steps/test_base_step.py +++ b/tests/unit/steps/test_base_step.py @@ -19,7 +19,8 @@ from pydantic import BaseModel from typing_extensions import Annotated -from zenml import pipeline, step +from zenml import pipeline, save_artifact, step +from zenml.client import Client from zenml.exceptions import StepInterfaceError from zenml.materializers import BuiltInMaterializer from zenml.materializers.base_materializer import BaseMaterializer @@ -1052,3 +1053,29 @@ def test_pipeline(): with does_not_raise(): test_pipeline() + + +@step +def step_that_returns_artifact_response( + artifact_name: str, artifact_version: Optional[str] = None +) -> int: + return Client().get_artifact_version(artifact_name, artifact_version) + + +def test_artifact_version_as_step_output(clean_client): + """Test passing an artifact version as step output.""" + int_artifact_name = "int_artifact" + save_artifact(1, name=int_artifact_name) + + str_artifact_name = "str_artifact" + save_artifact("asd", name=str_artifact_name) + + @pipeline + def test_pipeline(artifact_name: str): + step_that_returns_artifact_response(artifact_name=artifact_name) + + with does_not_raise(): + test_pipeline(artifact_name=int_artifact_name) + + with pytest.raises(StepInterfaceError): + test_pipeline(artifact_name=str_artifact_name)