Skip to content

Commit ebf41ac

Browse files
committed
Enable running dynamic snapshots
1 parent ce6c862 commit ebf41ac

File tree

6 files changed

+79
-24
lines changed

6 files changed

+79
-24
lines changed

src/zenml/client.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3618,14 +3618,16 @@ def trigger_pipeline(
36183618
run_configuration
36193619
)
36203620

3621-
if run_configuration:
3622-
validate_run_config_is_runnable_from_server(run_configuration)
3623-
36243621
if template_id:
36253622
logger.warning(
36263623
"Triggering a run template is deprecated. Use "
36273624
"`Client().trigger_pipeline(snapshot_id=...)` instead."
36283625
)
3626+
if run_configuration:
3627+
validate_run_config_is_runnable_from_server(
3628+
run_configuration, is_dynamic=False
3629+
)
3630+
36293631
run = self.zen_store.run_template(
36303632
template_id=template_id,
36313633
run_configuration=run_configuration,
@@ -3638,13 +3640,13 @@ def trigger_pipeline(
36383640
"using stack associated with the snapshot instead."
36393641
)
36403642

3641-
snapshot_id = self.get_snapshot(
3643+
snapshot = self.get_snapshot(
36423644
name_id_or_prefix=snapshot_name_or_id,
36433645
pipeline_name_or_id=pipeline_name_or_id,
36443646
project=project,
36453647
allow_prefix_match=False,
36463648
hydrate=False,
3647-
).id
3649+
)
36483650
else:
36493651
if not pipeline_name_or_id:
36503652
raise RuntimeError(
@@ -3692,7 +3694,6 @@ def trigger_pipeline(
36923694
except ValueError:
36933695
continue
36943696

3695-
snapshot_id = snapshot.id
36963697
break
36973698
else:
36983699
raise RuntimeError(
@@ -3708,8 +3709,13 @@ def trigger_pipeline(
37083709
except RuntimeError:
37093710
pass
37103711

3712+
if run_configuration:
3713+
validate_run_config_is_runnable_from_server(
3714+
run_configuration, is_dynamic=snapshot.is_dynamic
3715+
)
3716+
37113717
run = self.zen_store.run_snapshot(
3712-
snapshot_id=snapshot_id,
3718+
snapshot_id=snapshot.id,
37133719
run_request=PipelineSnapshotRunRequest(
37143720
run_configuration=run_configuration,
37153721
step_run=step_run_id,

src/zenml/pipelines/run_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,27 +186,32 @@ def validate_stack_is_runnable_from_server(
186186

187187
def validate_run_config_is_runnable_from_server(
188188
run_configuration: "PipelineRunConfiguration",
189+
is_dynamic: bool,
189190
) -> None:
190191
"""Validates that the run configuration can be used to run from the server.
191192
192193
Args:
193194
run_configuration: The run configuration to validate.
195+
is_dynamic: Whether the snapshot to run is dynamic.
194196
195197
Raises:
196198
ValueError: If there are values in the run configuration that are not
197199
allowed when running a pipeline from the server.
198200
"""
199-
if run_configuration.parameters:
201+
if run_configuration.parameters and not is_dynamic:
200202
raise ValueError(
201-
"Can't set pipeline parameters when running pipeline via Rest API. "
202-
"This likely requires refactoring your pipeline code to use step parameters "
203-
"instead of pipeline parameters. For example, instead of: "
203+
"Can't set pipeline parameters when running a static pipeline via "
204+
"the REST API. You can either use dynamic pipelines for which you "
205+
"can pass pipeline parameters, or refactore your pipeline code to "
206+
"use step parameters instead of pipeline parameters. For example, "
207+
"instead of: "
204208
"```yaml "
205209
"parameters: "
206210
" param1: 1 "
207211
" param2: 2 "
208212
"``` "
209-
"You'll need to modify your pipeline code to pass parameters directly to steps: "
213+
"You'll need to modify your pipeline code to pass parameters "
214+
"directly to steps: "
210215
"```yaml "
211216
"steps: "
212217
" step1: "

src/zenml/zen_server/pipeline_execution/utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ def run_snapshot(
202202

203203
validate_stack_is_runnable_from_server(zen_store=zen_store(), stack=stack)
204204
if request.run_configuration:
205-
validate_run_config_is_runnable_from_server(request.run_configuration)
205+
validate_run_config_is_runnable_from_server(
206+
request.run_configuration, is_dynamic=snapshot.is_dynamic
207+
)
206208

207209
snapshot_request = snapshot_request_from_source_snapshot(
208210
source_snapshot=snapshot,
@@ -494,9 +496,13 @@ def snapshot_request_from_source_snapshot(
494496
Returns:
495497
The generated snapshot request.
496498
"""
499+
pipeline_update_exclude = {"name"}
500+
if not source_snapshot.is_dynamic:
501+
pipeline_update_exclude.add("parameters")
502+
497503
pipeline_update = config.model_dump(
498504
include=set(PipelineConfiguration.model_fields),
499-
exclude={"name", "parameters"},
505+
exclude=pipeline_update_exclude,
500506
exclude_unset=True,
501507
exclude_none=True,
502508
)
@@ -509,6 +515,14 @@ def snapshot_request_from_source_snapshot(
509515
source_snapshot.pipeline_configuration, pipeline_update
510516
)
511517

518+
pipeline_spec = source_snapshot.pipeline_spec
519+
if pipeline_spec and pipeline_configuration.parameters:
520+
# Also include the updated pipeline parameters in the pipeline spec, as
521+
# the frontend and some other code still relies on the parameters in it
522+
pipeline_spec = pipeline_spec.model_copy(
523+
update={"parameters": pipeline_configuration.parameters}
524+
)
525+
512526
steps = {}
513527
step_config_updates = config.steps or {}
514528
for invocation_id, step in source_snapshot.step_configurations.items():
@@ -586,6 +600,7 @@ def snapshot_request_from_source_snapshot(
586600

587601
return PipelineSnapshotRequest(
588602
project=source_snapshot.project_id,
603+
is_dynamic=source_snapshot.is_dynamic,
589604
run_name_template=config.run_name or source_snapshot.run_name_template,
590605
pipeline_configuration=pipeline_configuration,
591606
step_configurations=steps,
@@ -603,7 +618,7 @@ def snapshot_request_from_source_snapshot(
603618
template=template_id,
604619
source_snapshot=source_snapshot_id,
605620
pipeline_version_hash=source_snapshot.pipeline_version_hash,
606-
pipeline_spec=source_snapshot.pipeline_spec,
621+
pipeline_spec=pipeline_spec,
607622
)
608623

609624

src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -482,12 +482,7 @@ def to_model(
482482
The response.
483483
"""
484484
runnable = False
485-
if (
486-
not self.is_dynamic
487-
and self.build
488-
and not self.build.is_local
489-
and self.build.stack_id
490-
):
485+
if self.build and not self.build.is_local and self.build.stack_id:
491486
runnable = True
492487

493488
deployable = False
@@ -553,6 +548,7 @@ def to_model(
553548
)
554549
config_schema = template_utils.generate_config_schema(
555550
snapshot=self,
551+
pipeline_configuration=pipeline_configuration,
556552
step_configurations=all_step_configurations,
557553
)
558554

src/zenml/zen_stores/schemas/run_template_schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def to_model(
318318
)
319319
config_schema = template_utils.generate_config_schema(
320320
snapshot=self.source_snapshot,
321+
pipeline_configuration=source_snapshot_model.pipeline_configuration,
321322
step_configurations=source_snapshot_model.step_configurations,
322323
)
323324

src/zenml/zen_stores/template_utils.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,13 @@ def generate_config_template(
116116
for config in steps_configs.values():
117117
config.get("settings", {}).pop("docker", None)
118118

119+
pipeline_config_exclude = {"schedule", "build"}
120+
if not snapshot.is_dynamic:
121+
pipeline_config_exclude.add("parameters")
122+
119123
pipeline_config = pipeline_configuration.model_dump(
120124
include=set(PipelineRunConfiguration.model_fields),
121-
exclude={"schedule", "build", "parameters"},
125+
exclude=pipeline_config_exclude,
122126
exclude_none=True,
123127
exclude_defaults=True,
124128
)
@@ -135,12 +139,14 @@ def generate_config_template(
135139

136140
def generate_config_schema(
137141
snapshot: PipelineSnapshotSchema,
142+
pipeline_configuration: "PipelineConfiguration",
138143
step_configurations: Dict[str, "Step"],
139144
) -> Dict[str, Any]:
140145
"""Generate a run configuration schema for the snapshot.
141146
142147
Args:
143148
snapshot: The snapshot schema.
149+
pipeline_configuration: The pipeline configuration.
144150
step_configurations: The step configurations.
145151
146152
Returns:
@@ -190,13 +196,17 @@ def generate_config_schema(
190196
generic_step_fields: Dict[str, Any] = {}
191197

192198
for key, field_info in StepConfigurationUpdate.model_fields.items():
193-
if key in [
199+
step_config_exclude = [
194200
"name",
195201
"outputs",
196202
"step_operator",
197203
"experiment_tracker",
198204
"parameters",
199-
]:
205+
]
206+
if not snapshot.is_dynamic:
207+
step_config_exclude.append("runtime")
208+
209+
if key in step_config_exclude:
200210
continue
201211

202212
if field_info.annotation == Optional[SourceWithValidator]: # type: ignore[comparison-overlap]
@@ -294,4 +304,26 @@ def generate_config_schema(
294304
FieldInfo(default=None),
295305
)
296306

307+
if snapshot.is_dynamic:
308+
pipeline_parameter_fields: Dict[str, Any] = {}
309+
310+
for parameter_name in pipeline_configuration.parameters or {}:
311+
# Pydantic doesn't allow field names to start with an underscore
312+
sanitized_parameter_name = parameter_name.lstrip("_")
313+
while sanitized_parameter_name in parameter_fields:
314+
sanitized_parameter_name = sanitized_parameter_name + "_"
315+
316+
pipeline_parameter_fields[sanitized_parameter_name] = (
317+
Any,
318+
FieldInfo(default=..., validation_alias=parameter_name),
319+
)
320+
321+
parameters_class = create_model(
322+
"Parameters", **pipeline_parameter_fields
323+
)
324+
top_level_fields["parameters"] = (
325+
parameters_class,
326+
FieldInfo(default=None),
327+
)
328+
297329
return create_model("Result", **top_level_fields).model_json_schema() # type: ignore[no-any-return]

0 commit comments

Comments
 (0)