From 60fc646b17c104579f85a276eac4a0e4da1858ef Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 6 Feb 2025 12:13:59 +0100 Subject: [PATCH 1/4] Allow returning raw artifact --- src/zenml/orchestrators/step_runner.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index f18f1c649a2..0d7d16e7e4d 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -42,6 +42,11 @@ from zenml.logger import get_logger from zenml.logging.step_logging import StepLogsStorageContext, redirected from zenml.materializers.base_materializer import BaseMaterializer +from zenml.models import ( + ArtifactVersionResponse, + PipelineRunResponse, + StepRunResponse, +) from zenml.models.v2.core.step_run import StepRunInputResponse from zenml.orchestrators.publish_utils import ( publish_step_run_metadata, @@ -536,6 +541,8 @@ def _validate_outputs( output_type = output_annotation.resolved_annotation if output_type is Any: pass + elif isinstance(return_value, ArtifactVersionResponse): + pass else: if is_union(get_origin(output_type)): output_type = get_args(output_type) @@ -575,9 +582,15 @@ def _store_output_artifacts( The IDs of the published output artifacts. 
""" step_context = get_step_context() - artifact_requests = [] + artifact_requests = {} + + artifacts = {} for output_name, return_value in output_data.items(): + if isinstance(return_value, ArtifactVersionResponse): + artifacts[output_name] = return_value + continue + data_type = type(return_value) materializer_classes = output_materializers[output_name] if materializer_classes: @@ -652,12 +665,13 @@ def _store_output_artifacts( save_type=ArtifactSaveType.STEP_OUTPUT, metadata=user_metadata, ) - artifact_requests.append(artifact_request) + artifact_requests[output_name] = artifact_request responses = Client().zen_store.batch_create_artifact_versions( - artifact_requests + list(artifact_requests.values()) ) - return dict(zip(output_data.keys(), responses)) + artifacts.update(dict(zip(artifact_requests.keys(), responses))) + return artifacts def load_and_run_hook( self, From 70dfc63c5170105a94622fd36a0fe9ecd7207534 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 6 Feb 2025 14:06:18 +0100 Subject: [PATCH 2/4] Add tests --- src/zenml/orchestrators/step_runner.py | 30 +++++++++++++++----------- tests/unit/steps/test_base_step.py | 30 +++++++++++++++++++++++--- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 0d7d16e7e4d..3c55e094c55 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -42,11 +42,7 @@ from zenml.logger import get_logger from zenml.logging.step_logging import StepLogsStorageContext, redirected from zenml.materializers.base_materializer import BaseMaterializer -from zenml.models import ( - ArtifactVersionResponse, - PipelineRunResponse, - StepRunResponse, -) +from zenml.models import ArtifactVersionResponse from zenml.models.v2.core.step_run import StepRunInputResponse from zenml.orchestrators.publish_utils import ( publish_step_run_metadata, @@ -541,18 +537,28 @@ def _validate_outputs( output_type 
= output_annotation.resolved_annotation if output_type is Any: pass - elif isinstance(return_value, ArtifactVersionResponse): - pass else: if is_union(get_origin(output_type)): output_type = get_args(output_type) - if not isinstance(return_value, output_type): - raise StepInterfaceError( - f"Wrong type for output '{output_name}' of step " - f"'{step_name}' (expected type: {output_type}, " - f"actual type: {type(return_value)})." + if isinstance(return_value, ArtifactVersionResponse): + artifact_data_type = source_utils.load( + return_value.data_type ) + if not issubclass(artifact_data_type, output_type): + raise StepInterfaceError( + f"Wrong type for artifact returned for output " + f"'{output_name}' of step '{step_name}' (expected " + f"type: {output_type}, actual type: " + f"{artifact_data_type})." + ) + else: + if not isinstance(return_value, output_type): + raise StepInterfaceError( + f"Wrong type for output '{output_name}' of step " + f"'{step_name}' (expected type: {output_type}, " + f"actual type: {type(return_value)})." 
+ ) validated_outputs[output_name] = return_value return validated_outputs diff --git a/tests/unit/steps/test_base_step.py b/tests/unit/steps/test_base_step.py index 434cdc8fc42..9d8ccd237ec 100644 --- a/tests/unit/steps/test_base_step.py +++ b/tests/unit/steps/test_base_step.py @@ -18,12 +18,12 @@ import pytest from pydantic import BaseModel from typing_extensions import Annotated - -from zenml import pipeline, step +from zenml.client import Client +from zenml import pipeline, step, save_artifact from zenml.exceptions import StepInterfaceError from zenml.materializers import BuiltInMaterializer from zenml.materializers.base_materializer import BaseMaterializer -from zenml.models import ArtifactVersionResponse +from zenml.models import ArtifactVersionResponse, ArtifactVersionRequest from zenml.steps import BaseStep @@ -1052,3 +1052,27 @@ def test_pipeline(): with does_not_raise(): test_pipeline() + + +@step +def step_that_returns_artifact_response(artifact_name: str, artifact_version: Optional[str] = None) -> int: + return Client().get_artifact_version(artifact_name, artifact_version) + + +def test_artifact_version_as_step_output(clean_client): + """Test passing an artifact version as step output.""" + int_artifact_name = "int_artifact" + save_artifact(1, name=int_artifact_name) + + str_artifact_name = "str_artifact" + save_artifact("asd", name=str_artifact_name) + + @pipeline + def test_pipeline(artifact_name: str): + step_that_returns_artifact_response(artifact_name=artifact_name) + + with does_not_raise(): + test_pipeline(artifact_name=int_artifact_name) + + with pytest.raises(StepInterfaceError): + test_pipeline(artifact_name=str_artifact_name) From 04efe9ace64d707dfc9888c15b8b8ccfb3f71255 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 6 Feb 2025 14:38:11 +0100 Subject: [PATCH 3/4] Docs --- .../return-existing-artifacts-from-steps.md | 56 +++++++++++++++++++ docs/book/toc.md | 1 + tests/unit/steps/test_base_step.py | 9 ++- 3 files changed, 63 
insertions(+), 3 deletions(-) create mode 100644 docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md diff --git a/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md b/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md new file mode 100644 index 00000000000..21af78e414a --- /dev/null +++ b/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md @@ -0,0 +1,56 @@ +--- +description: Return existing artifacts from steps. +--- + +# Return existing artifacts from steps + +ZenML allows you to return existing artifact versions that are already registered in the ZenML server from your pipeline steps. This is particularly useful when you want to optimize caching behavior in your pipelines. + +## Understanding caching behavior + +ZenML's caching mechanism uses the IDs of the step input artifacts (among other things) to determine whether a step needs to be re-run or can be cached. By default, when a step produces an output artifact, a new artifact version is registered - even if the output data is identical to a previous run. + +This means that steps downstream of a non-cached step will also need to be re-run, since their input artifact IDs will be different, even if the underlying data hasn't changed. To enable better caching, you can return existing artifact versions from your steps instead of always creating new ones. This is useful if you want to do some computation in early parts of the pipeline that decides whether the remaining steps of the pipeline can be cached. 
+ +```python +from zenml import pipeline, step, log_metadata +from zenml.client import Client +from typing import Annotated + +# We want to always run this step to decide whether the +# downstream steps can be cached, so we disable caching for it +@step(enable_cache=False) +def compute_cache() -> Annotated[int, "cache_key"]: + # Replace this with your custom logic, for example compute a key + # from the date of the latest avaialable data point + cache_key = 27 + + artifact_versions = Client().list_artifact_versions( + sort_by="desc:created", + size=1, + name="cache_key", + run_metadata={"cache_key_value": cache_key}, + ) + + if artifact_versions: + return artifact_versions[0] + else: + # Log the cache key as metadata on the artifact version so we can easily + # fetch it later in subsequent runs + log_metadata(metadata={"cache_key_value": cache_key}, infer_artifact=True) + return cache_key + + +@step +def downstream_step(cache_key: int): + ... + +# Enable caching for the pipeline +@pipeline(enable_cache=True) +def my_pipeline(): + cache_key = compute_cache() + downstream_step(cache_key) +``` + + +
ZenML Scarf
diff --git a/docs/book/toc.md b/docs/book/toc.md index ec51f95fc4b..b22ade2094b 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -165,6 +165,7 @@ * [Datasets in ZenML](how-to/data-artifact-management/complex-usecases/datasets.md) * [Manage big data](how-to/data-artifact-management/complex-usecases/manage-big-data.md) * [Skipping materialization](how-to/data-artifact-management/complex-usecases/unmaterialized-artifacts.md) + * [Return existing artifacts from steps](how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md) * [Passing artifacts between pipelines](how-to/data-artifact-management/complex-usecases/passing-artifacts-between-pipelines.md) * [Register Existing Data as a ZenML Artifact](how-to/data-artifact-management/complex-usecases/registering-existing-data.md) * [Visualizing artifacts](how-to/data-artifact-management/visualize-artifacts/README.md) diff --git a/tests/unit/steps/test_base_step.py b/tests/unit/steps/test_base_step.py index 9d8ccd237ec..c42c0ff0f31 100644 --- a/tests/unit/steps/test_base_step.py +++ b/tests/unit/steps/test_base_step.py @@ -18,12 +18,13 @@ import pytest from pydantic import BaseModel from typing_extensions import Annotated + +from zenml import pipeline, save_artifact, step from zenml.client import Client -from zenml import pipeline, step, save_artifact from zenml.exceptions import StepInterfaceError from zenml.materializers import BuiltInMaterializer from zenml.materializers.base_materializer import BaseMaterializer -from zenml.models import ArtifactVersionResponse, ArtifactVersionRequest +from zenml.models import ArtifactVersionResponse from zenml.steps import BaseStep @@ -1055,7 +1056,9 @@ def test_pipeline(): @step -def step_that_returns_artifact_response(artifact_name: str, artifact_version: Optional[str] = None) -> int: +def step_that_returns_artifact_response( + artifact_name: str, artifact_version: Optional[str] = None +) -> int: return Client().get_artifact_version(artifact_name, 
artifact_version) From 334aa7822eea43d4e83bb41a9712e5955c06e577 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 6 Feb 2025 15:12:47 +0100 Subject: [PATCH 4/4] Typo --- .../complex-usecases/return-existing-artifacts-from-steps.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md b/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md index 21af78e414a..9503d61f066 100644 --- a/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md +++ b/docs/book/how-to/data-artifact-management/complex-usecases/return-existing-artifacts-from-steps.md @@ -22,7 +22,7 @@ from typing import Annotated @step(enable_cache=False) def compute_cache() -> Annotated[int, "cache_key"]: # Replace this with your custom logic, for example compute a key - # from the date of the latest avaialable data point + # from the date of the latest available data point cache_key = 27 artifact_versions = Client().list_artifact_versions(