|
63 | 63 | from kfp.compiler import Compiler |
64 | 64 | from kfp.dsl.base_component import BaseComponent |
65 | 65 |
|
| 66 | +from zenml import __version__ |
66 | 67 | from zenml.config.resource_settings import ResourceSettings |
67 | 68 | from zenml.constants import ( |
68 | 69 | METADATA_ORCHESTRATOR_LOGS_URL, |
69 | 70 | METADATA_ORCHESTRATOR_RUN_ID, |
70 | 71 | METADATA_ORCHESTRATOR_URL, |
| 72 | + ORCHESTRATOR_DOCKER_IMAGE_KEY, |
71 | 73 | ) |
72 | 74 | from zenml.entrypoints import StepEntrypointConfiguration |
73 | 75 | from zenml.enums import ExecutionStatus, StackComponentType |
74 | 76 | from zenml.integrations.gcp import GCP_ARTIFACT_STORE_FLAVOR |
75 | 77 | from zenml.integrations.gcp.constants import ( |
76 | 78 | GKE_ACCELERATOR_NODE_SELECTOR_CONSTRAINT_LABEL, |
| 79 | + VERTEX_ENDPOINT_SUFFIX, |
77 | 80 | ) |
78 | 81 | from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import ( |
79 | 82 | VertexOrchestratorConfig, |
|
82 | 85 | from zenml.integrations.gcp.google_credentials_mixin import ( |
83 | 86 | GoogleCredentialsMixin, |
84 | 87 | ) |
| 88 | +from zenml.integrations.gcp.utils import ( |
| 89 | + build_job_request, |
| 90 | + monitor_job, |
| 91 | +) |
85 | 92 | from zenml.integrations.gcp.vertex_custom_job_parameters import ( |
86 | 93 | VertexCustomJobParameters, |
87 | 94 | ) |
|
90 | 97 | from zenml.metadata.metadata_types import MetadataType, Uri |
91 | 98 | from zenml.orchestrators import ContainerizedOrchestrator, SubmissionResult |
92 | 99 | from zenml.orchestrators.utils import get_orchestrator_run_name |
| 100 | +from zenml.pipelines.dynamic.entrypoint_configuration import ( |
| 101 | + DynamicPipelineEntrypointConfiguration, |
| 102 | +) |
93 | 103 | from zenml.stack.stack_validator import StackValidator |
| 104 | +from zenml.step_operators.step_operator_entrypoint_configuration import ( |
| 105 | + StepOperatorEntrypointConfiguration, |
| 106 | +) |
94 | 107 | from zenml.utils.io_utils import get_global_config_directory |
95 | 108 |
|
96 | 109 | if TYPE_CHECKING: |
97 | 110 | from zenml.config.base_settings import BaseSettings |
| 111 | + from zenml.config.step_run_info import StepRunInfo |
98 | 112 | from zenml.models import ( |
99 | 113 | PipelineRunResponse, |
100 | 114 | PipelineSnapshotResponse, |
@@ -621,6 +635,151 @@ def dynamic_pipeline() -> None: |
621 | 635 | schedule=snapshot.schedule, |
622 | 636 | ) |
623 | 637 |
|
def submit_dynamic_pipeline(
    self,
    snapshot: "PipelineSnapshotResponse",
    stack: "Stack",
    environment: Dict[str, str],
    placeholder_run: Optional["PipelineRunResponse"] = None,
) -> Optional[SubmissionResult]:
    """Submits a dynamic pipeline as a Vertex AI custom job.

    The job runs the dynamic pipeline entrypoint inside the image
    resolved for the snapshot. If the orchestrator settings request
    synchronous execution, the returned result carries a callback that
    invokes `monitor_job` for the created job.

    Args:
        snapshot: The pipeline snapshot to submit.
        stack: The stack the pipeline will run on.
        environment: Environment variables to set in the orchestration
            environment.
        placeholder_run: An optional placeholder run.

    Raises:
        RuntimeError: If the snapshot contains a schedule.

    Returns:
        Optional submission result.
    """
    if snapshot.schedule:
        raise RuntimeError(
            "Scheduling dynamic pipelines is not supported for the "
            "Vertex orchestrator yet."
        )

    orchestrator_settings = cast(
        VertexOrchestratorSettings, self.get_settings(snapshot)
    )

    # Command line that the custom job container will execute.
    entrypoint = DynamicPipelineEntrypointConfiguration
    entrypoint_command = entrypoint.get_entrypoint_command()
    entrypoint_args = entrypoint.get_entrypoint_arguments(
        snapshot_id=snapshot.id,
        run_id=placeholder_run.id if placeholder_run else None,
    )

    container_image = self.get_image(snapshot=snapshot)

    # Tag the job with the ZenML version (dots replaced by underscores)
    # without mutating the labels stored on the settings object.
    job_labels = {
        **orchestrator_settings.labels,
        "source": f"zenml-{__version__.replace('.', '_')}",
    }

    pipeline_config = snapshot.pipeline_configuration
    display_name = get_orchestrator_run_name(
        pipeline_name=pipeline_config.name
    )
    # Fall back to default custom job parameters when none are configured.
    job_parameters = (
        orchestrator_settings.custom_job_parameters
        or VertexCustomJobParameters()
    )
    job_spec = build_job_request(
        display_name=display_name,
        image=container_image,
        entrypoint_command=entrypoint_command + entrypoint_args,
        custom_job_settings=job_parameters,
        resource_settings=pipeline_config.resource_settings,
        environment=environment,
        labels=job_labels,
        encryption_spec_key_name=self.config.encryption_spec_key_name,
        service_account=self.config.workload_service_account,
        network=self.config.network,
    )

    credentials, project_id = self._get_authentication()
    endpoint_options = {
        "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
    }
    job_client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=endpoint_options
    )
    parent_path = f"projects/{project_id}/locations/{self.config.location}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        job_spec["display_name"],
        parent_path,
    )
    submitted_job = job_client.create_custom_job(
        parent=parent_path, custom_job=job_spec
    )

    if not orchestrator_settings.synchronous:
        # Asynchronous submission: nothing to wait on.
        return SubmissionResult(wait_for_completion=None)

    return SubmissionResult(
        wait_for_completion=lambda: monitor_job(
            job_id=submitted_job.name,
            credentials_source=self,
            client_options=endpoint_options,
        ),
    )
| 724 | + |
def run_isolated_step(
    self, step_run_info: "StepRunInfo", environment: Dict[str, str]
) -> None:
    """Runs a single step as its own Vertex AI custom job.

    The job executes the step operator entrypoint in the orchestrator
    image of the step. After the job is created, `monitor_job` is called
    for it.

    Args:
        step_run_info: The step run information.
        environment: The environment variables to set.
    """
    step_settings = cast(
        VertexOrchestratorSettings, self.get_settings(step_run_info)
    )

    container_image = step_run_info.get_image(
        key=ORCHESTRATOR_DOCKER_IMAGE_KEY
    )

    # Command line that the custom job container will execute.
    entrypoint = StepOperatorEntrypointConfiguration
    entrypoint_command = entrypoint.get_entrypoint_command()
    entrypoint_args = entrypoint.get_entrypoint_arguments(
        step_name=step_run_info.pipeline_step_name,
        snapshot_id=step_run_info.snapshot.id,
        step_run_id=str(step_run_info.step_run_id),
    )

    # Tag the job with the ZenML version (dots replaced by underscores)
    # without mutating the labels stored on the settings object.
    job_labels = {
        **step_settings.labels,
        "source": f"zenml-{__version__.replace('.', '_')}",
    }

    display_name = (
        f"{step_run_info.run_name}-{step_run_info.pipeline_step_name}"
    )
    # Fall back to default custom job parameters when none are configured.
    job_parameters = (
        step_settings.custom_job_parameters or VertexCustomJobParameters()
    )
    job_spec = build_job_request(
        display_name=display_name,
        image=container_image,
        entrypoint_command=entrypoint_command + entrypoint_args,
        custom_job_settings=job_parameters,
        resource_settings=step_run_info.config.resource_settings,
        environment=environment,
        labels=job_labels,
        encryption_spec_key_name=self.config.encryption_spec_key_name,
        service_account=self.config.workload_service_account,
        network=self.config.network,
    )

    credentials, project_id = self._get_authentication()
    endpoint_options = {
        "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
    }
    job_client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=endpoint_options
    )
    parent_path = f"projects/{project_id}/locations/{self.config.location}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        job_spec["display_name"],
        parent_path,
    )
    submitted_job = job_client.create_custom_job(
        parent=parent_path, custom_job=job_spec
    )
    monitor_job(
        job_id=submitted_job.name,
        credentials_source=self,
        client_options=endpoint_options,
    )
| 782 | + |
624 | 783 | def _upload_and_run_pipeline( |
625 | 784 | self, |
626 | 785 | pipeline_name: str, |
|
0 commit comments