Skip to content

Commit b922f0d

Browse files
committed
Add support for dynamic pipelines to the Vertex orchestrator
1 parent 6733792 commit b922f0d

File tree

3 files changed

+392
-134
lines changed

3 files changed

+392
-134
lines changed

src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py

Lines changed: 159 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,17 +63,20 @@
6363
from kfp.compiler import Compiler
6464
from kfp.dsl.base_component import BaseComponent
6565

66+
from zenml import __version__
6667
from zenml.config.resource_settings import ResourceSettings
6768
from zenml.constants import (
6869
METADATA_ORCHESTRATOR_LOGS_URL,
6970
METADATA_ORCHESTRATOR_RUN_ID,
7071
METADATA_ORCHESTRATOR_URL,
72+
ORCHESTRATOR_DOCKER_IMAGE_KEY,
7173
)
7274
from zenml.entrypoints import StepEntrypointConfiguration
7375
from zenml.enums import ExecutionStatus, StackComponentType
7476
from zenml.integrations.gcp import GCP_ARTIFACT_STORE_FLAVOR
7577
from zenml.integrations.gcp.constants import (
7678
GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL,
79+
VERTEX_ENDPOINT_SUFFIX,
7780
)
7881
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
7982
VertexOrchestratorConfig,
@@ -82,6 +85,10 @@
8285
from zenml.integrations.gcp.google_credentials_mixin import (
8386
GoogleCredentialsMixin,
8487
)
88+
from zenml.integrations.gcp.utils import (
89+
build_job_request,
90+
monitor_job,
91+
)
8592
from zenml.integrations.gcp.vertex_custom_job_parameters import (
8693
VertexCustomJobParameters,
8794
)
@@ -90,11 +97,18 @@
9097
from zenml.metadata.metadata_types import MetadataType, Uri
9198
from zenml.orchestrators import ContainerizedOrchestrator, SubmissionResult
9299
from zenml.orchestrators.utils import get_orchestrator_run_name
100+
from zenml.pipelines.dynamic.entrypoint_configuration import (
101+
DynamicPipelineEntrypointConfiguration,
102+
)
93103
from zenml.stack.stack_validator import StackValidator
104+
from zenml.step_operators.step_operator_entrypoint_configuration import (
105+
StepOperatorEntrypointConfiguration,
106+
)
94107
from zenml.utils.io_utils import get_global_config_directory
95108

96109
if TYPE_CHECKING:
97110
from zenml.config.base_settings import BaseSettings
111+
from zenml.config.step_run_info import StepRunInfo
98112
from zenml.models import (
99113
PipelineRunResponse,
100114
PipelineSnapshotResponse,
@@ -621,6 +635,151 @@ def dynamic_pipeline() -> None:
621635
schedule=snapshot.schedule,
622636
)
623637

638+
def submit_dynamic_pipeline(
    self,
    snapshot: "PipelineSnapshotResponse",
    stack: "Stack",
    environment: Dict[str, str],
    placeholder_run: Optional["PipelineRunResponse"] = None,
) -> Optional[SubmissionResult]:
    """Submits a dynamic pipeline as a single Vertex AI custom job.

    The job's container runs the dynamic pipeline entrypoint for the
    given snapshot. The job request is assembled by `build_job_request`
    and sent through a `JobServiceClient` bound to the configured
    location.

    Args:
        snapshot: The pipeline snapshot to submit.
        stack: The stack the pipeline will run on. Not referenced by
            this implementation.
        environment: Environment variables to set in the orchestration
            environment.
        placeholder_run: An optional placeholder run. When given, its ID
            is forwarded to the entrypoint arguments.

    Raises:
        RuntimeError: If the snapshot contains a schedule.

    Returns:
        A submission result. Its `wait_for_completion` callback is only
        set when the `synchronous` setting is enabled; the callback
        hands the created job to `monitor_job`.
    """
    # Guard clause: scheduled dynamic pipelines are rejected up front.
    if snapshot.schedule:
        raise RuntimeError(
            "Scheduling dynamic pipelines is not supported for the "
            "Vertex orchestrator yet."
        )

    settings = cast(VertexOrchestratorSettings, self.get_settings(snapshot))

    # Container command line: entrypoint command followed by its arguments.
    entrypoint_config = DynamicPipelineEntrypointConfiguration
    base_command = entrypoint_config.get_entrypoint_command()
    extra_args = entrypoint_config.get_entrypoint_arguments(
        snapshot_id=snapshot.id,
        run_id=placeholder_run.id if placeholder_run else None,
    )

    image = self.get_image(snapshot=snapshot)

    # User-provided labels plus a marker carrying the ZenML version
    # (dots replaced by underscores).
    job_labels = {
        **settings.labels,
        "source": f"zenml-{__version__.replace('.', '_')}",
    }

    pipeline_config = snapshot.pipeline_configuration
    job_display_name = get_orchestrator_run_name(
        pipeline_name=pipeline_config.name
    )
    job_request = build_job_request(
        display_name=job_display_name,
        image=image,
        entrypoint_command=base_command + extra_args,
        custom_job_settings=(
            settings.custom_job_parameters or VertexCustomJobParameters()
        ),
        resource_settings=pipeline_config.resource_settings,
        environment=environment,
        labels=job_labels,
        encryption_spec_key_name=self.config.encryption_spec_key_name,
        service_account=self.config.workload_service_account,
        network=self.config.network,
    )

    credentials, project_id = self._get_authentication()
    location = self.config.location
    client_options = {"api_endpoint": location + VERTEX_ENDPOINT_SUFFIX}
    job_client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )

    parent = f"projects/{project_id}/locations/{location}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        job_request["display_name"],
        parent,
    )
    job = job_client.create_custom_job(parent=parent, custom_job=job_request)

    def _wait_for_job():  # type: ignore[no-untyped-def]
        # Deferred on purpose: only runs if the caller invokes the callback.
        return monitor_job(
            job_id=job.name,
            credentials_source=self,
            client_options=client_options,
        )

    return SubmissionResult(
        wait_for_completion=_wait_for_job if settings.synchronous else None,
    )
724+
725+
def run_isolated_step(
    self, step_run_info: "StepRunInfo", environment: Dict[str, str]
) -> None:
    """Runs an isolated step as a Vertex AI custom job.

    Builds a custom job whose container executes the step operator
    entrypoint for the given step, submits it through a
    `JobServiceClient` bound to the configured location, and then hands
    the created job to `monitor_job`.

    Args:
        step_run_info: The step run information.
        environment: The environment variables to set.
    """
    settings = cast(
        VertexOrchestratorSettings, self.get_settings(step_run_info)
    )

    # The step runs in the orchestrator image rather than a
    # step-operator-specific one.
    image = step_run_info.get_image(key=ORCHESTRATOR_DOCKER_IMAGE_KEY)
    command = StepOperatorEntrypointConfiguration.get_entrypoint_command()
    args = StepOperatorEntrypointConfiguration.get_entrypoint_arguments(
        step_name=step_run_info.pipeline_step_name,
        snapshot_id=step_run_info.snapshot.id,
        step_run_id=str(step_run_info.step_run_id),
    )

    # Copy before mutating so the settings object's labels stay untouched.
    labels = settings.labels.copy()
    labels["source"] = f"zenml-{__version__.replace('.', '_')}"

    job_request = build_job_request(
        display_name=(
            f"{step_run_info.run_name}-{step_run_info.pipeline_step_name}"
        ),
        image=image,
        entrypoint_command=command + args,
        custom_job_settings=settings.custom_job_parameters
        or VertexCustomJobParameters(),
        resource_settings=step_run_info.config.resource_settings,
        environment=environment,
        labels=labels,
        encryption_spec_key_name=self.config.encryption_spec_key_name,
        service_account=self.config.workload_service_account,
        network=self.config.network,
    )

    credentials, project_id = self._get_authentication()
    client_options = {
        "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
    }
    client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )
    parent = f"projects/{project_id}/locations/{self.config.location}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        job_request["display_name"],
        parent,
    )
    job = client.create_custom_job(parent=parent, custom_job=job_request)

    # NOTE(review): `monitor_job` presumably blocks until the job reaches
    # a terminal state and raises on failure -- confirm against
    # `zenml.integrations.gcp.utils`.
    monitor_job(
        job_id=job.name,
        credentials_source=self,
        client_options=client_options,
    )
782+
624783
def _upload_and_run_pipeline(
625784
self,
626785
pipeline_name: str,

0 commit comments

Comments
 (0)