Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow returning existing artifacts from steps #3347

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
from zenml.logger import get_logger
from zenml.logging.step_logging import StepLogsStorageContext, redirected
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.models import (
ArtifactVersionResponse,
PipelineRunResponse,
StepRunResponse,
)
from zenml.models.v2.core.step_run import StepRunInputResponse
from zenml.orchestrators.publish_utils import (
publish_step_run_metadata,
Expand Down Expand Up @@ -536,6 +541,8 @@ def _validate_outputs(
output_type = output_annotation.resolved_annotation
if output_type is Any:
pass
elif isinstance(return_value, ArtifactVersionResponse):
pass
else:
if is_union(get_origin(output_type)):
output_type = get_args(output_type)
Expand Down Expand Up @@ -575,9 +582,15 @@ def _store_output_artifacts(
The IDs of the published output artifacts.
"""
step_context = get_step_context()
artifact_requests = []
artifact_requests = {}

artifacts = {}

for output_name, return_value in output_data.items():
if isinstance(return_value, ArtifactVersionResponse):
artifacts[output_name] = return_value
continue

data_type = type(return_value)
materializer_classes = output_materializers[output_name]
if materializer_classes:
Expand Down Expand Up @@ -652,12 +665,13 @@ def _store_output_artifacts(
save_type=ArtifactSaveType.STEP_OUTPUT,
metadata=user_metadata,
)
artifact_requests.append(artifact_request)
artifact_requests[output_name] = artifact_request

responses = Client().zen_store.batch_create_artifact_versions(
artifact_requests
list(artifact_requests.values())
)
return dict(zip(output_data.keys(), responses))
artifacts.update(dict(zip(artifact_requests.keys(), responses)))
return artifacts

def load_and_run_hook(
self,
Expand Down
Loading