Skip to content

Commit dbcfe3f

Browse files
Feature:4151 Heartbeat activation logic (#4216)
* Heartbeat activation logic * Remove unused compile_spec method * Fix heartbeat status conditions
1 parent 6733792 commit dbcfe3f

File tree

10 files changed

+189
-147
lines changed

10 files changed

+189
-147
lines changed

src/zenml/config/compiler.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,17 @@
2828

2929
from zenml import __version__
3030
from zenml.config.base_settings import BaseSettings, ConfigurationLevel
31-
from zenml.config.pipeline_configurations import PipelineConfiguration
3231
from zenml.config.pipeline_run_configuration import PipelineRunConfiguration
3332
from zenml.config.pipeline_spec import OutputSpec, PipelineSpec
3433
from zenml.config.settings_resolver import SettingsResolver
3534
from zenml.config.step_configurations import (
3635
InputSpec,
3736
Step,
37+
StepConfiguration,
3838
StepConfigurationUpdate,
3939
StepSpec,
4040
)
41+
from zenml.enums import StepRuntime
4142
from zenml.environment import get_run_environment_dict
4243
from zenml.exceptions import StackValidationError
4344
from zenml.models import PipelineSnapshotBase
@@ -138,7 +139,7 @@ def compile(
138139
invocation=invocation,
139140
stack=stack,
140141
step_config=(run_configuration.steps or {}).get(invocation_id),
141-
pipeline_configuration=pipeline.configuration,
142+
pipeline=pipeline,
142143
skip_input_validation=skip_input_validation,
143144
)
144145
for invocation_id, invocation in self._get_sorted_invocations(
@@ -177,38 +178,6 @@ def compile(
177178

178179
return snapshot
179180

180-
def compile_spec(self, pipeline: "Pipeline") -> PipelineSpec:
181-
"""Compiles a ZenML pipeline to a pipeline spec.
182-
183-
This method can be used when a pipeline spec is needed but the full
184-
snapshot including stack information is not required.
185-
186-
Args:
187-
pipeline: The pipeline to compile.
188-
189-
Returns:
190-
The compiled pipeline spec.
191-
"""
192-
logger.debug(
193-
"Compiling pipeline spec for pipeline `%s`.", pipeline.name
194-
)
195-
# Copy the pipeline before we connect the steps, so we don't mess with
196-
# the pipeline object/step objects in any way
197-
pipeline = copy.deepcopy(pipeline)
198-
199-
invocations = [
200-
self._get_step_spec(invocation=invocation)
201-
for _, invocation in self._get_sorted_invocations(
202-
pipeline=pipeline
203-
)
204-
]
205-
206-
pipeline_spec = self._compute_pipeline_spec(
207-
pipeline=pipeline, step_specs=invocations
208-
)
209-
logger.debug("Compiled pipeline spec: %s", pipeline_spec)
210-
return pipeline_spec
211-
212181
def _apply_run_configuration(
213182
self, pipeline: "Pipeline", config: PipelineRunConfiguration
214183
) -> None:
@@ -487,11 +456,13 @@ def _filter_and_validate_settings(
487456
def _get_step_spec(
488457
self,
489458
invocation: "StepInvocation",
459+
enable_heartbeat: bool,
490460
) -> StepSpec:
491461
"""Gets the spec for a step invocation.
492462
493463
Args:
494464
invocation: The invocation for which to get the spec.
465+
enable_heartbeat: Whether to enable the heartbeat.
495466
496467
Returns:
497468
The step spec.
@@ -508,14 +479,40 @@ def _get_step_spec(
508479
upstream_steps=sorted(invocation.upstream_steps),
509480
inputs=inputs,
510481
invocation_id=invocation.id,
482+
enable_heartbeat=enable_heartbeat,
511483
)
512484

485+
@staticmethod
486+
def _get_heartbeat_flag(
487+
pipeline: "Pipeline", stack: "Stack", step_config: "StepConfiguration"
488+
) -> bool:
489+
if stack.orchestrator.flavor == "local":
490+
return False
491+
elif not pipeline.is_dynamic:
492+
# containerized static pipeline
493+
return True
494+
else:
495+
# dynamic pipelines
496+
from zenml.execution.pipeline.dynamic.runner import (
497+
get_step_runtime,
498+
)
499+
500+
step_runtime = get_step_runtime(
501+
step_config=step_config,
502+
pipeline_docker_settings=pipeline.configuration.docker_settings,
503+
)
504+
if step_runtime == StepRuntime.ISOLATED:
505+
# dynamic pipelines & isolated execution
506+
return True
507+
# dynamic pipelines & inline execution
508+
return False
509+
513510
def _compile_step_invocation(
514511
self,
515512
invocation: "StepInvocation",
516513
stack: "Stack",
517514
step_config: Optional["StepConfigurationUpdate"],
518-
pipeline_configuration: "PipelineConfiguration",
515+
pipeline: "Pipeline",
519516
skip_input_validation: bool = False,
520517
) -> Step:
521518
"""Compiles a ZenML step.
@@ -524,7 +521,7 @@ def _compile_step_invocation(
524521
invocation: The step invocation to compile.
525522
stack: The stack on which the pipeline will be run.
526523
step_config: Run configuration for the step.
527-
pipeline_configuration: Configuration for the pipeline.
524+
pipeline: The pipeline that contains the step.
528525
skip_input_validation: If True, will skip the input validation.
529526
530527
Returns:
@@ -548,7 +545,6 @@ def _compile_step_invocation(
548545
convert_component_shortcut_settings_keys(
549546
step.configuration.settings, stack=stack
550547
)
551-
step_spec = self._get_step_spec(invocation=invocation)
552548
step_secrets = secret_utils.resolve_and_verify_secrets(
553549
step.configuration.secrets
554550
)
@@ -572,9 +568,18 @@ def _compile_step_invocation(
572568
)
573569
full_step_config = (
574570
step_configuration_overrides.apply_pipeline_configuration(
575-
pipeline_configuration=pipeline_configuration
571+
pipeline_configuration=pipeline.configuration
576572
)
577573
)
574+
575+
step_spec = self._get_step_spec(
576+
invocation=invocation,
577+
enable_heartbeat=self._get_heartbeat_flag(
578+
pipeline=pipeline,
579+
stack=stack,
580+
step_config=full_step_config,
581+
),
582+
)
578583
return Step(
579584
spec=step_spec,
580585
config=full_step_config,

src/zenml/config/step_configurations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ class StepSpec(FrozenBaseModel):
416416
upstream_steps: List[str]
417417
inputs: Dict[str, InputSpec] = {}
418418
invocation_id: str
419+
enable_heartbeat: bool = False
419420

420421
@model_validator(mode="before")
421422
@classmethod

src/zenml/execution/pipeline/dynamic/runner.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from zenml.artifacts.in_memory_cache import InMemoryArtifactCache
3535
from zenml.client import Client
3636
from zenml.config.compiler import Compiler
37-
from zenml.config.step_configurations import Step
3837
from zenml.enums import ExecutionMode, StepRuntime
3938
from zenml.execution.pipeline.dynamic.outputs import (
4039
ArtifactFuture,
@@ -68,7 +67,7 @@
6867

6968
if TYPE_CHECKING:
7069
from zenml.config import DockerSettings
71-
from zenml.config.step_configurations import Step
70+
from zenml.config.step_configurations import Step, StepConfiguration
7271
from zenml.steps import BaseStep
7372

7473

@@ -384,11 +383,12 @@ def _await_and_validate_input(input: Any) -> Any:
384383
model_artifacts_or_metadata={},
385384
client_lazy_loaders={},
386385
)
386+
387387
return Compiler()._compile_step_invocation(
388388
invocation=pipeline.invocations[invocation_id],
389389
stack=Client().active_stack,
390390
step_config=None,
391-
pipeline_configuration=pipeline.configuration,
391+
pipeline=pipeline,
392392
)
393393

394394

@@ -440,7 +440,10 @@ def _should_retry_locally(
440440
if step.config.step_operator:
441441
return True
442442

443-
runtime = get_step_runtime(step, pipeline_docker_settings)
443+
runtime = get_step_runtime(
444+
step_config=step.config,
445+
pipeline_docker_settings=pipeline_docker_settings,
446+
)
444447
if runtime == StepRuntime.INLINE or step.config.step_operator:
445448
return True
446449
else:
@@ -451,29 +454,30 @@ def _should_retry_locally(
451454

452455

453456
def get_step_runtime(
454-
step: "Step", pipeline_docker_settings: "DockerSettings"
457+
step_config: "StepConfiguration",
458+
pipeline_docker_settings: "DockerSettings",
455459
) -> StepRuntime:
456460
"""Determine if a step should be run in process.
457461
458462
Args:
459-
step: The step.
463+
step_config: The step configuration.
460464
pipeline_docker_settings: The Docker settings of the parent pipeline.
461465
462466
Returns:
463467
The runtime for the step.
464468
"""
465-
if step.config.step_operator:
469+
if step_config.step_operator:
466470
return StepRuntime.ISOLATED
467471

468472
if not Client().active_stack.orchestrator.can_run_isolated_steps:
469473
return StepRuntime.INLINE
470474

471-
runtime = step.config.runtime
475+
runtime = step_config.runtime
472476

473477
if runtime is None:
474-
if not step.config.resource_settings.empty:
478+
if not step_config.resource_settings.empty:
475479
runtime = StepRuntime.ISOLATED
476-
elif step.config.docker_settings != pipeline_docker_settings:
480+
elif step_config.docker_settings != pipeline_docker_settings:
477481
runtime = StepRuntime.ISOLATED
478482
else:
479483
runtime = StepRuntime.INLINE

src/zenml/orchestrators/step_launcher.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from zenml.orchestrators import utils as orchestrator_utils
4444
from zenml.orchestrators.step_runner import StepRunner
4545
from zenml.stack import Stack
46-
from zenml.steps import StepHeartBeatTerminationException, StepHeartbeatWorker
4746
from zenml.utils import env_utils, exception_utils, string_utils
4847
from zenml.utils.time_utils import utc_now
4948

@@ -108,23 +107,20 @@ def __init__(
108107
snapshot: PipelineSnapshotResponse,
109108
step: Step,
110109
orchestrator_run_id: str,
111-
heartbeat_enabled: bool = False,
112110
):
113111
"""Initializes the launcher.
114112
115113
Args:
116114
snapshot: The pipeline snapshot.
117115
step: The step to launch.
118116
orchestrator_run_id: The orchestrator pipeline run id.
119-
heartbeat_enabled: Flag - if set will start heartbeat thread worker
120117
121118
Raises:
122119
RuntimeError: If the snapshot has no associated stack.
123120
"""
124121
self._snapshot = snapshot
125122
self._step = step
126123
self._orchestrator_run_id = orchestrator_run_id
127-
self._heartbeat_enabled = heartbeat_enabled
128124

129125
if not snapshot.stack:
130126
raise RuntimeError(
@@ -433,9 +429,6 @@ def _run_step(
433429
step_run: The model of the current step run.
434430
force_write_logs: The context for the step logs.
435431
436-
Raises:
437-
StepHeartBeatTerminationException: if step heartbeat is enabled and the step is remotely stopped.
438-
KeyboardInterrupt: Will capture, evaluate and reraise keyboard interrupts.
439432
"""
440433
from zenml.deployers.server import runtime
441434

@@ -460,19 +453,6 @@ def _run_step(
460453

461454
start_time = time.time()
462455

463-
# To have a cross-platform compatible handling of main thread termination
464-
# we use Python's interrupt_main instead of termination signals (not Windows supported).
465-
# Since interrupt_main raises KeyboardInterrupt we want in this context to capture it
466-
# and handle it as a custom exception.
467-
468-
heartbeat_worker = StepHeartbeatWorker(step_id=step_run.id)
469-
470-
if self._heartbeat_enabled:
471-
logger.info(
472-
"Initiating heartbeat for step: %s", self._invocation_id
473-
)
474-
heartbeat_worker.start()
475-
476456
try:
477457
if self._step.config.step_operator:
478458
step_operator_name = None
@@ -497,7 +477,7 @@ def _run_step(
497477
)
498478

499479
step_runtime = get_step_runtime(
500-
step=self._step,
480+
step_config=self._step.config,
501481
pipeline_docker_settings=self._snapshot.pipeline_configuration.docker_settings,
502482
)
503483

@@ -524,22 +504,11 @@ def _run_step(
524504
self._run_step_with_dynamic_orchestrator(
525505
step_run_info=step_run_info
526506
)
527-
except KeyboardInterrupt:
528-
if heartbeat_worker.is_terminated:
529-
msg = f"Step {self._invocation_id} has been remotely stopped - terminating"
530-
logger.info(msg)
531-
output_utils.remove_artifact_dirs(
532-
artifact_uris=list(output_artifact_uris.values())
533-
)
534-
raise StepHeartBeatTerminationException(msg)
535-
raise
536507
except: # noqa: E722
537508
output_utils.remove_artifact_dirs(
538509
artifact_uris=list(output_artifact_uris.values())
539510
)
540511
raise
541-
finally:
542-
heartbeat_worker.stop()
543512

544513
duration = time.time() - start_time
545514
logger.info(

0 commit comments

Comments
 (0)