Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/book/how-to/steps-pipelines/dynamic_pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,17 @@ if __name__ == "__main__":
The `depends_on` parameter tells ZenML which steps can be configured via the YAML file. This is particularly useful when you want to allow users to configure pipeline behavior without modifying code.

### Pass pipeline parameters when running snapshots from the server

When running a snapshot from the server (either via the UI or the SDK/Rest API), you can now pass pipeline parameters for your dynamic pipelines.

For example:
```python
from zenml.client import Client
Client().trigger_pipeline(snapshot_id=<ID>, run_configuration={"parameters": {"my_param": 3}})
```

## Limitations and Known Issues

### Logging
Expand Down
20 changes: 13 additions & 7 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3618,14 +3618,16 @@ def trigger_pipeline(
run_configuration
)

if run_configuration:
validate_run_config_is_runnable_from_server(run_configuration)

if template_id:
logger.warning(
"Triggering a run template is deprecated. Use "
"`Client().trigger_pipeline(snapshot_id=...)` instead."
)
if run_configuration:
validate_run_config_is_runnable_from_server(
run_configuration, is_dynamic=False
)

run = self.zen_store.run_template(
template_id=template_id,
run_configuration=run_configuration,
Expand All @@ -3638,13 +3640,13 @@ def trigger_pipeline(
"using stack associated with the snapshot instead."
)

snapshot_id = self.get_snapshot(
snapshot = self.get_snapshot(
name_id_or_prefix=snapshot_name_or_id,
pipeline_name_or_id=pipeline_name_or_id,
project=project,
allow_prefix_match=False,
hydrate=False,
).id
)
else:
if not pipeline_name_or_id:
raise RuntimeError(
Expand Down Expand Up @@ -3692,7 +3694,6 @@ def trigger_pipeline(
except ValueError:
continue

snapshot_id = snapshot.id
break
else:
raise RuntimeError(
Expand All @@ -3708,8 +3709,13 @@ def trigger_pipeline(
except RuntimeError:
pass

if run_configuration:
validate_run_config_is_runnable_from_server(
run_configuration, is_dynamic=snapshot.is_dynamic
)

run = self.zen_store.run_snapshot(
snapshot_id=snapshot_id,
snapshot_id=snapshot.id,
run_request=PipelineSnapshotRunRequest(
run_configuration=run_configuration,
step_run=step_run_id,
Expand Down
12 changes: 10 additions & 2 deletions src/zenml/execution/pipeline/dynamic/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def __init__(
self._index = index

def result(self) -> OutputArtifact:
"""Get the step run output artifact.
"""Get the output artifact this future represents.

Raises:
RuntimeError: If the future returned an invalid output.

Returns:
The step run output artifact.
The output artifact.
"""
result = self._wrapped.result()
if isinstance(result, OutputArtifact):
Expand Down Expand Up @@ -166,6 +166,14 @@ def artifacts(self) -> StepRunOutputs:
"""
return self._wrapped.result()

def result(self) -> StepRunOutputs:
"""Get the step run outputs this future represents.

Returns:
The step run outputs.
"""
return self._wrapped.result()

def load(self, disable_cache: bool = False) -> Any:
"""Get the step run output artifact data.

Expand Down
15 changes: 10 additions & 5 deletions src/zenml/pipelines/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,27 +186,32 @@ def validate_stack_is_runnable_from_server(

def validate_run_config_is_runnable_from_server(
run_configuration: "PipelineRunConfiguration",
is_dynamic: bool,
) -> None:
"""Validates that the run configuration can be used to run from the server.
Args:
run_configuration: The run configuration to validate.
is_dynamic: Whether the snapshot to run is dynamic.
Raises:
ValueError: If there are values in the run configuration that are not
allowed when running a pipeline from the server.
"""
if run_configuration.parameters:
if run_configuration.parameters and not is_dynamic:
raise ValueError(
"Can't set pipeline parameters when running pipeline via Rest API. "
"This likely requires refactoring your pipeline code to use step parameters "
"instead of pipeline parameters. For example, instead of: "
"Can't set pipeline parameters when running a static pipeline via "
"the REST API. You can either use dynamic pipelines for which you "
"can pass pipeline parameters, or refactore your pipeline code to "
"use step parameters instead of pipeline parameters. For example, "
"instead of: "
"```yaml "
"parameters: "
" param1: 1 "
" param2: 2 "
"``` "
"You'll need to modify your pipeline code to pass parameters directly to steps: "
"You'll need to modify your pipeline code to pass parameters "
"directly to steps: "
"```yaml "
"steps: "
" step1: "
Expand Down
21 changes: 18 additions & 3 deletions src/zenml/zen_server/pipeline_execution/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ def run_snapshot(

validate_stack_is_runnable_from_server(zen_store=zen_store(), stack=stack)
if request.run_configuration:
validate_run_config_is_runnable_from_server(request.run_configuration)
validate_run_config_is_runnable_from_server(
request.run_configuration, is_dynamic=snapshot.is_dynamic
)

snapshot_request = snapshot_request_from_source_snapshot(
source_snapshot=snapshot,
Expand Down Expand Up @@ -494,9 +496,13 @@ def snapshot_request_from_source_snapshot(
Returns:
The generated snapshot request.
"""
pipeline_update_exclude = {"name"}
if not source_snapshot.is_dynamic:
pipeline_update_exclude.add("parameters")

pipeline_update = config.model_dump(
include=set(PipelineConfiguration.model_fields),
exclude={"name", "parameters"},
exclude=pipeline_update_exclude,
exclude_unset=True,
exclude_none=True,
)
Expand All @@ -509,6 +515,14 @@ def snapshot_request_from_source_snapshot(
source_snapshot.pipeline_configuration, pipeline_update
)

pipeline_spec = source_snapshot.pipeline_spec
if pipeline_spec and pipeline_configuration.parameters:
# Also include the updated pipeline parameters in the pipeline spec, as
# the frontend and some other code still relies on the parameters in it
pipeline_spec = pipeline_spec.model_copy(
update={"parameters": pipeline_configuration.parameters}
)

steps = {}
step_config_updates = config.steps or {}
for invocation_id, step in source_snapshot.step_configurations.items():
Expand Down Expand Up @@ -586,6 +600,7 @@ def snapshot_request_from_source_snapshot(

return PipelineSnapshotRequest(
project=source_snapshot.project_id,
is_dynamic=source_snapshot.is_dynamic,
run_name_template=config.run_name or source_snapshot.run_name_template,
pipeline_configuration=pipeline_configuration,
step_configurations=steps,
Expand All @@ -603,7 +618,7 @@ def snapshot_request_from_source_snapshot(
template=template_id,
source_snapshot=source_snapshot_id,
pipeline_version_hash=source_snapshot.pipeline_version_hash,
pipeline_spec=source_snapshot.pipeline_spec,
pipeline_spec=pipeline_spec,
)


Expand Down
8 changes: 2 additions & 6 deletions src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,12 +482,7 @@ def to_model(
The response.
"""
runnable = False
if (
not self.is_dynamic
and self.build
and not self.build.is_local
and self.build.stack_id
):
if self.build and not self.build.is_local and self.build.stack_id:
runnable = True

deployable = False
Expand Down Expand Up @@ -553,6 +548,7 @@ def to_model(
)
config_schema = template_utils.generate_config_schema(
snapshot=self,
pipeline_configuration=pipeline_configuration,
step_configurations=all_step_configurations,
)

Expand Down
1 change: 1 addition & 0 deletions src/zenml/zen_stores/schemas/run_template_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def to_model(
)
config_schema = template_utils.generate_config_schema(
snapshot=self.source_snapshot,
pipeline_configuration=source_snapshot_model.pipeline_configuration,
step_configurations=source_snapshot_model.step_configurations,
)

Expand Down
42 changes: 38 additions & 4 deletions src/zenml/zen_stores/template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,13 @@ def generate_config_template(
for config in steps_configs.values():
config.get("settings", {}).pop("docker", None)

pipeline_config_exclude = {"schedule", "build"}
if not snapshot.is_dynamic:
pipeline_config_exclude.add("parameters")

pipeline_config = pipeline_configuration.model_dump(
include=set(PipelineRunConfiguration.model_fields),
exclude={"schedule", "build", "parameters"},
exclude=pipeline_config_exclude,
exclude_none=True,
exclude_defaults=True,
)
Expand All @@ -135,12 +139,14 @@ def generate_config_template(

def generate_config_schema(
snapshot: PipelineSnapshotSchema,
pipeline_configuration: "PipelineConfiguration",
step_configurations: Dict[str, "Step"],
) -> Dict[str, Any]:
"""Generate a run configuration schema for the snapshot.

Args:
snapshot: The snapshot schema.
pipeline_configuration: The pipeline configuration.
step_configurations: The step configurations.

Returns:
Expand Down Expand Up @@ -190,13 +196,17 @@ def generate_config_schema(
generic_step_fields: Dict[str, Any] = {}

for key, field_info in StepConfigurationUpdate.model_fields.items():
if key in [
step_config_exclude = [
"name",
"outputs",
"step_operator",
"experiment_tracker",
"parameters",
]:
]
if not snapshot.is_dynamic:
step_config_exclude.append("runtime")

if key in step_config_exclude:
continue

if field_info.annotation == Optional[SourceWithValidator]: # type: ignore[comparison-overlap]
Expand Down Expand Up @@ -226,7 +236,8 @@ def generate_config_schema(
all_steps: Dict[str, Any] = {}
all_steps_required = False
for step_name, step in step_configurations.items():
step_fields = generic_step_fields.copy()
step_fields: Dict[str, Any] = {}

if step.config.parameters:
parameter_fields: Dict[str, Any] = {}

Expand All @@ -249,6 +260,7 @@ def generate_config_schema(
FieldInfo(default=...),
)

step_fields.update(generic_step_fields)
step_model = create_model(step_name, **step_fields)

# Pydantic doesn't allow field names to start with an underscore
Expand All @@ -275,6 +287,28 @@ def generate_config_schema(

top_level_fields: Dict[str, Any] = {}

if snapshot.is_dynamic:
pipeline_parameter_fields: Dict[str, Any] = {}

for parameter_name in pipeline_configuration.parameters or {}:
# Pydantic doesn't allow field names to start with an underscore
sanitized_parameter_name = parameter_name.lstrip("_")
while sanitized_parameter_name in pipeline_parameter_fields:
sanitized_parameter_name = sanitized_parameter_name + "_"

pipeline_parameter_fields[sanitized_parameter_name] = (
Any,
FieldInfo(default=..., validation_alias=parameter_name),
)

parameters_class = create_model(
"Parameters", **pipeline_parameter_fields
)
top_level_fields["parameters"] = (
parameters_class,
FieldInfo(default=None),
)

for key, field_info in PipelineRunConfiguration.model_fields.items():
if key in ["schedule", "build", "steps", "settings", "parameters"]:
continue
Expand Down
Loading