@@ -144,6 +144,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
144144 """Orchestrator responsible for running pipelines on Vertex AI."""
145145
146146 _pipeline_root : str
147+ _job_service_client : Optional [aiplatform .gapic .JobServiceClient ] = None
147148
148149 @property
149150 def config (self ) -> VertexOrchestratorConfig :
@@ -261,6 +262,25 @@ def pipeline_directory(self) -> str:
261262 """
262263 return os .path .join (self .root_directory , "pipelines" )
263264
265+ def get_job_service_client (self ) -> aiplatform .gapic .JobServiceClient :
266+ """Get the job service client, reusing the cached instance when possible.
267+
268+ Returns:
269+ The cached job service client, or a newly created one if none is cached or `connector_has_expired()` reports true.
270+ """
271+ if self .connector_has_expired ():  # drop the cached client so it is rebuilt below
272+ self ._job_service_client = None
273+
274+ if self ._job_service_client is None :
275+ credentials , _ = self ._get_authentication ()  # second tuple element is not needed here
276+ client_options = {
277+ "api_endpoint" : self .config .location + VERTEX_ENDPOINT_SUFFIX
278+ }
279+ self ._job_service_client = aiplatform .gapic .JobServiceClient (
280+ credentials = credentials , client_options = client_options
281+ )
282+ return self ._job_service_client
283+
264284 def _create_container_component (
265285 self ,
266286 image : str ,
@@ -696,34 +716,38 @@ def submit_dynamic_pipeline(
696716 network = self .config .network ,
697717 )
698718
699- credentials , project_id = self ._get_authentication ()
700- client_options = {
701- "api_endpoint" : self .config .location + VERTEX_ENDPOINT_SUFFIX
702- }
703- client = aiplatform .gapic .JobServiceClient (
704- credentials = credentials , client_options = client_options
719+ client = self .get_job_service_client ()
720+ parent = (
721+ f"projects/{ self .gcp_project_id } /locations/{ self .config .location } "
705722 )
706- parent = f"projects/{ project_id } /locations/{ self .config .location } "
707723 job_model = client .create_custom_job (
708724 parent = parent , custom_job = job_request
709725 )
710726
711- wait_for_completion = None
727+ _wait_for_completion = None
712728 if settings .synchronous :
713- wait_for_completion = lambda : monitor_job (
714- job_id = job_model .name ,
715- credentials_source = self ,
716- client_options = client_options ,
717- )
718729
719- self ._initialize_vertex_client ()
720- job = aiplatform .CustomJob .get (job_model .name )
730+ def _wait_for_completion () -> None :
731+ logger .info ("Waiting for the VertexAI job to finish..." )
732+ monitor_job (
733+ job_id = job_model .name ,
734+ get_client = self .get_job_service_client ,
735+ )
736+ logger .info ("VertexAI job completed successfully." )
737+
738+ credentials , project_id = self ._get_authentication ()
739+ job = aiplatform .CustomJob .get (
740+ job_model .name ,
741+ project = project_id ,
742+ location = self .config .location ,
743+ credentials = credentials ,
744+ )
721745 metadata = self .compute_metadata (job )
722746
723747 logger .info ("View the Vertex job at %s" , job ._dashboard_uri ())
724748
725749 return SubmissionResult (
726- wait_for_completion = wait_for_completion ,
750+ wait_for_completion = _wait_for_completion ,
727751 metadata = metadata ,
728752 )
729753
@@ -765,14 +789,10 @@ def run_isolated_step(
765789 network = self .config .network ,
766790 )
767791
768- credentials , project_id = self ._get_authentication ()
769- client_options = {
770- "api_endpoint" : self .config .location + VERTEX_ENDPOINT_SUFFIX
771- }
772- client = aiplatform .gapic .JobServiceClient (
773- credentials = credentials , client_options = client_options
792+ client = self .get_job_service_client ()
793+ parent = (
794+ f"projects/{ self .gcp_project_id } /locations/{ self .config .location } "
774795 )
775- parent = f"projects/{ project_id } /locations/{ self .config .location } "
776796 logger .info (
777797 "Submitting custom job='%s', path='%s' to Vertex AI Training." ,
778798 job_request ["display_name" ],
@@ -781,8 +801,7 @@ def run_isolated_step(
781801 job = client .create_custom_job (parent = parent , custom_job = job_request )
782802 monitor_job (
783803 job_id = job .name ,
784- credentials_source = self ,
785- client_options = client_options ,
804+ get_client = self .get_job_service_client ,
786805 )
787806
788807 def _upload_and_run_pipeline (
@@ -1060,15 +1079,6 @@ def _configure_container_resources(
10601079
10611080 return dynamic_component
10621081
1063- def _initialize_vertex_client (self ) -> None :
1064- """Initializes the Vertex client."""
1065- credentials , project_id = self ._get_authentication ()
1066- aiplatform .init (
1067- project = project_id ,
1068- location = self .config .location ,
1069- credentials = credentials ,
1070- )
1071-
10721082 def fetch_status (
10731083 self , run : "PipelineRunResponse" , include_steps : bool = False
10741084 ) -> Tuple [
@@ -1102,8 +1112,6 @@ def fetch_status(
11021112 == run .stack .components [StackComponentType .ORCHESTRATOR ][0 ].id
11031113 )
11041114
1105- self ._initialize_vertex_client ()
1106-
11071115 # Fetch the status of the PipelineJob
11081116 if METADATA_ORCHESTRATOR_RUN_ID in run .run_metadata :
11091117 run_id = run .run_metadata [METADATA_ORCHESTRATOR_RUN_ID ]
@@ -1115,8 +1123,14 @@ def fetch_status(
11151123 "the status."
11161124 )
11171125
1126+ credentials , project_id = self ._get_authentication ()
11181127 if run .snapshot and run .snapshot .is_dynamic :
1119- status = aiplatform .CustomJob .get (run_id ).state
1128+ status = aiplatform .CustomJob .get (
1129+ run_id ,
1130+ project = project_id ,
1131+ location = self .config .location ,
1132+ credentials = credentials ,
1133+ ).state
11201134
11211135 if status in [
11221136 JobState .JOB_STATE_QUEUED ,
@@ -1143,7 +1157,12 @@ def fetch_status(
11431157 else :
11441158 pipeline_status = run .status
11451159 else :
1146- status = aiplatform .PipelineJob .get (run_id ).state
1160+ status = aiplatform .PipelineJob .get (
1161+ run_id ,
1162+ project = project_id ,
1163+ location = self .config .location ,
1164+ credentials = credentials ,
1165+ ).state
11471166
11481167 # Map the potential outputs to ZenML ExecutionStatus. Potential values:
11491168 # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PipelineState
0 commit comments