Skip to content

Commit 1e36533

Browse files
committed
Address review comments
1 parent 3b9d1f0 commit 1e36533

File tree

4 files changed

+115
-85
lines changed

4 files changed

+115
-85
lines changed

src/zenml/integrations/gcp/google_credentials_mixin.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class GoogleCredentialsConfigMixin(StackComponentConfig):
4848
class GoogleCredentialsMixin(StackComponent):
4949
"""StackComponent mixin to get Google Cloud Platform credentials."""
5050

51+
_gcp_credentials: Optional["Credentials"] = None
52+
_gcp_project_id: Optional[str] = None
53+
5154
@property
5255
def config(self) -> GoogleCredentialsConfigMixin:
5356
"""Returns the `GoogleCredentialsConfigMixin` config.
@@ -57,6 +60,18 @@ def config(self) -> GoogleCredentialsConfigMixin:
5760
"""
5861
return cast(GoogleCredentialsConfigMixin, self._config)
5962

63+
@property
def gcp_project_id(self) -> str:
    """Return the GCP project ID, resolving it on first access.

    The value is read from `_gcp_project_id` when it is already set.
    Otherwise it is taken from the second element of the tuple returned by
    `_get_authentication()` and stored for subsequent calls.

    Returns:
        The GCP project ID.
    """
    # Fast path: the project ID has already been resolved.
    if self._gcp_project_id is not None:
        return self._gcp_project_id

    # Only the project ID is needed here; the credentials are discarded.
    authentication = self._get_authentication()
    self._gcp_project_id = authentication[1]
    return self._gcp_project_id
74+
6075
def _get_authentication(self) -> Tuple["Credentials", str]:
6176
"""Get GCP credentials and the project ID associated with the credentials.
6277
@@ -79,6 +94,12 @@ def _get_authentication(self) -> Tuple["Credentials", str]:
7994
GCPServiceConnector,
8095
)
8196

97+
if self.connector_has_expired():
98+
self._gcp_credentials = None
99+
100+
if self._gcp_credentials and self._gcp_project_id:
101+
return self._gcp_credentials, self._gcp_project_id
102+
82103
connector = self.get_connector()
83104
if connector:
84105
credentials = connector.connect()
@@ -90,6 +111,8 @@ def _get_authentication(self) -> Tuple["Credentials", str]:
90111
"trying to use the linked connector, but got "
91112
f"{type(credentials)}."
92113
)
114+
self._gcp_credentials = credentials
115+
self._gcp_project_id = connector.config.gcp_project_id
93116
return credentials, connector.config.gcp_project_id
94117

95118
if self.config.service_account_path:
@@ -111,4 +134,6 @@ def _get_authentication(self) -> Tuple["Credentials", str]:
111134
# If the project was set in the configuration, use it. Otherwise, use
112135
# the project that was used to authenticate.
113136
project_id = self.config.project if self.config.project else project_id
137+
self._gcp_credentials = credentials
138+
self._gcp_project_id = project_id
114139
return credentials, project_id

src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
144144
"""Orchestrator responsible for running pipelines on Vertex AI."""
145145

146146
_pipeline_root: str
147+
_job_service_client: Optional[aiplatform.gapic.JobServiceClient] = None
147148

148149
@property
149150
def config(self) -> VertexOrchestratorConfig:
@@ -261,6 +262,25 @@ def pipeline_directory(self) -> str:
261262
"""
262263
return os.path.join(self.root_directory, "pipelines")
263264

265+
def get_job_service_client(self) -> aiplatform.gapic.JobServiceClient:
    """Get a job service client bound to the current GCP credentials.

    The client is cached on the instance. It is rebuilt when the linked
    connector reports that it has expired, and also when the credentials
    returned by `_get_authentication()` are no longer the object the cached
    client was created with.

    Returns:
        The job service client.
    """
    if self.connector_has_expired():
        self._job_service_client = None

    # Other code paths call `_get_authentication()` directly and may replace
    # the credentials before this method runs again. NOTE(review): once that
    # happens the expiry check above presumably no longer fires, so the
    # cached client is additionally validated against the credentials object
    # it was built with.
    credentials, _ = self._get_authentication()
    built_with = getattr(self, "_job_service_client_credentials", None)

    if self._job_service_client is None or built_with is not credentials:
        client_options = {
            "api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
        }
        self._job_service_client = aiplatform.gapic.JobServiceClient(
            credentials=credentials, client_options=client_options
        )
        # Remember which credentials back the cached client.
        self._job_service_client_credentials = credentials
    return self._job_service_client
283+
264284
def _create_container_component(
265285
self,
266286
image: str,
@@ -696,34 +716,38 @@ def submit_dynamic_pipeline(
696716
network=self.config.network,
697717
)
698718

699-
credentials, project_id = self._get_authentication()
700-
client_options = {
701-
"api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
702-
}
703-
client = aiplatform.gapic.JobServiceClient(
704-
credentials=credentials, client_options=client_options
719+
client = self.get_job_service_client()
720+
parent = (
721+
f"projects/{self.gcp_project_id}/locations/{self.config.location}"
705722
)
706-
parent = f"projects/{project_id}/locations/{self.config.location}"
707723
job_model = client.create_custom_job(
708724
parent=parent, custom_job=job_request
709725
)
710726

711-
wait_for_completion = None
727+
_wait_for_completion = None
712728
if settings.synchronous:
713-
wait_for_completion = lambda: monitor_job(
714-
job_id=job_model.name,
715-
credentials_source=self,
716-
client_options=client_options,
717-
)
718729

719-
self._initialize_vertex_client()
720-
job = aiplatform.CustomJob.get(job_model.name)
730+
def _wait_for_completion() -> None:
731+
logger.info("Waiting for the VertexAI job to finish...")
732+
monitor_job(
733+
job_id=job_model.name,
734+
get_client=self.get_job_service_client,
735+
)
736+
logger.info("VertexAI job completed successfully.")
737+
738+
credentials, project_id = self._get_authentication()
739+
job = aiplatform.CustomJob.get(
740+
job_model.name,
741+
project=project_id,
742+
location=self.config.location,
743+
credentials=credentials,
744+
)
721745
metadata = self.compute_metadata(job)
722746

723747
logger.info("View the Vertex job at %s", job._dashboard_uri())
724748

725749
return SubmissionResult(
726-
wait_for_completion=wait_for_completion,
750+
wait_for_completion=_wait_for_completion,
727751
metadata=metadata,
728752
)
729753

@@ -765,14 +789,10 @@ def run_isolated_step(
765789
network=self.config.network,
766790
)
767791

768-
credentials, project_id = self._get_authentication()
769-
client_options = {
770-
"api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
771-
}
772-
client = aiplatform.gapic.JobServiceClient(
773-
credentials=credentials, client_options=client_options
792+
client = self.get_job_service_client()
793+
parent = (
794+
f"projects/{self.gcp_project_id}/locations/{self.config.location}"
774795
)
775-
parent = f"projects/{project_id}/locations/{self.config.location}"
776796
logger.info(
777797
"Submitting custom job='%s', path='%s' to Vertex AI Training.",
778798
job_request["display_name"],
@@ -781,8 +801,7 @@ def run_isolated_step(
781801
job = client.create_custom_job(parent=parent, custom_job=job_request)
782802
monitor_job(
783803
job_id=job.name,
784-
credentials_source=self,
785-
client_options=client_options,
804+
get_client=self.get_job_service_client,
786805
)
787806

788807
def _upload_and_run_pipeline(
@@ -1060,15 +1079,6 @@ def _configure_container_resources(
10601079

10611080
return dynamic_component
10621081

1063-
def _initialize_vertex_client(self) -> None:
1064-
"""Initializes the Vertex client."""
1065-
credentials, project_id = self._get_authentication()
1066-
aiplatform.init(
1067-
project=project_id,
1068-
location=self.config.location,
1069-
credentials=credentials,
1070-
)
1071-
10721082
def fetch_status(
10731083
self, run: "PipelineRunResponse", include_steps: bool = False
10741084
) -> Tuple[
@@ -1102,8 +1112,6 @@ def fetch_status(
11021112
== run.stack.components[StackComponentType.ORCHESTRATOR][0].id
11031113
)
11041114

1105-
self._initialize_vertex_client()
1106-
11071115
# Fetch the status of the PipelineJob
11081116
if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
11091117
run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
@@ -1115,8 +1123,14 @@ def fetch_status(
11151123
"the status."
11161124
)
11171125

1126+
credentials, project_id = self._get_authentication()
11181127
if run.snapshot and run.snapshot.is_dynamic:
1119-
status = aiplatform.CustomJob.get(run_id).state
1128+
status = aiplatform.CustomJob.get(
1129+
run_id,
1130+
project=project_id,
1131+
location=self.config.location,
1132+
credentials=credentials,
1133+
).state
11201134

11211135
if status in [
11221136
JobState.JOB_STATE_QUEUED,
@@ -1143,7 +1157,12 @@ def fetch_status(
11431157
else:
11441158
pipeline_status = run.status
11451159
else:
1146-
status = aiplatform.PipelineJob.get(run_id).state
1160+
status = aiplatform.PipelineJob.get(
1161+
run_id,
1162+
project=project_id,
1163+
location=self.config.location,
1164+
credentials=credentials,
1165+
).state
11471166

11481167
# Map the potential outputs to ZenML ExecutionStatus. Potential values:
11491168
# https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PipelineState

src/zenml/integrations/gcp/step_operators/vertex_step_operator.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
5757
ZenML entrypoint command in it.
5858
"""
5959

60+
_job_service_client: Optional[aiplatform.gapic.JobServiceClient] = None
61+
6062
def __init__(self, *args: Any, **kwargs: Any) -> None:
6163
"""Initializes the step operator and validates the accelerator type.
6264
@@ -150,6 +152,25 @@ def get_docker_builds(
150152

151153
return builds
152154

155+
def get_job_service_client(self) -> aiplatform.gapic.JobServiceClient:
    """Get a job service client bound to the current GCP credentials.

    The client is cached on the instance. It is rebuilt when the linked
    connector reports that it has expired, and also when the credentials
    returned by `_get_authentication()` are no longer the object the cached
    client was created with.

    Returns:
        The job service client.
    """
    if self.connector_has_expired():
        self._job_service_client = None

    # The credentials may be replaced by a direct `_get_authentication()`
    # call made elsewhere. NOTE(review): once that happens the expiry check
    # above presumably no longer fires, so the cached client is additionally
    # validated against the credentials object it was built with.
    credentials, _ = self._get_authentication()
    built_with = getattr(self, "_job_service_client_credentials", None)

    if self._job_service_client is None or built_with is not credentials:
        client_options = {
            "api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
        }
        self._job_service_client = aiplatform.gapic.JobServiceClient(
            credentials=credentials, client_options=client_options
        )
        # Remember which credentials back the cached client.
        self._job_service_client_credentials = credentials
    return self._job_service_client
173+
153174
def launch(
154175
self,
155176
info: "StepRunInfo",
@@ -193,15 +214,10 @@ def launch(
193214
)
194215
logger.debug("Vertex AI Job=%s", job_request)
195216

196-
credentials, project_id = self._get_authentication()
197-
client_options = {
198-
"api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
199-
}
200-
client = aiplatform.gapic.JobServiceClient(
201-
credentials=credentials, client_options=client_options
217+
client = self.get_job_service_client()
218+
parent = (
219+
f"projects/{self.gcp_project_id}/locations/{self.config.region}"
202220
)
203-
204-
parent = f"projects/{project_id}/locations/{self.config.region}"
205221
logger.info(
206222
"Submitting custom job='%s', path='%s' to Vertex AI Training.",
207223
job_request["display_name"],
@@ -215,6 +231,5 @@ def launch(
215231

216232
monitor_job(
217233
job_id=response.name,
218-
credentials_source=self,
219-
client_options=client_options,
234+
get_client=self.get_job_service_client,
220235
)

src/zenml/integrations/gcp/utils.py

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Vertex utilities."""
1515

1616
import time
17-
from typing import TYPE_CHECKING, Any, Dict, List, Optional
17+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
1818

1919
from google.api_core.exceptions import ServerError
2020
from google.cloud import aiplatform
@@ -25,9 +25,6 @@
2525
VERTEX_JOB_STATES_COMPLETED,
2626
VERTEX_JOB_STATES_FAILED,
2727
)
28-
from zenml.integrations.gcp.google_credentials_mixin import (
29-
GoogleCredentialsMixin,
30-
)
3128
from zenml.integrations.gcp.vertex_custom_job_parameters import (
3229
VertexCustomJobParameters,
3330
)
@@ -55,59 +52,33 @@ def validate_accelerator_type(accelerator_type: Optional[str] = None) -> None:
5552
)
5653

5754

58-
def get_job_service_client(
59-
credentials_source: GoogleCredentialsMixin,
60-
client_options: Optional[Dict[str, Any]] = None,
61-
) -> aiplatform.gapic.JobServiceClient:
62-
"""Gets a job service client.
63-
64-
Args:
65-
credentials_source: The component that provides the credentials to
66-
access the job.
67-
client_options: The client options to use for the job service client.
68-
69-
Returns:
70-
A job service client.
71-
"""
72-
credentials, _ = credentials_source._get_authentication()
73-
return aiplatform.gapic.JobServiceClient(
74-
credentials=credentials, client_options=client_options
75-
)
76-
77-
7855
def monitor_job(
7956
job_id: str,
80-
credentials_source: GoogleCredentialsMixin,
81-
client_options: Optional[Dict[str, Any]] = None,
57+
get_client: Callable[[], aiplatform.gapic.JobServiceClient],
8258
) -> None:
8359
"""Monitors a job until it is completed.
8460
8561
Args:
8662
job_id: The ID of the job to monitor.
87-
credentials_source: The component that provides the credentials to
88-
access the job.
89-
client_options: The client options to use for the job service client.
63+
get_client: A function that returns an authenticated job service client.
9064
9165
Raises:
9266
RuntimeError: If the job fails.
9367
"""
9468
retry_count = 0
95-
client = get_job_service_client(
96-
credentials_source=credentials_source, client_options=client_options
97-
)
69+
client = get_client()
9870

9971
while True:
10072
time.sleep(POLLING_INTERVAL_IN_SECONDS)
101-
if credentials_source.connector_has_expired():
102-
client = get_job_service_client(
103-
credentials_source=credentials_source,
104-
client_options=client_options,
105-
)
73+
# Fetch a fresh client in case the credentials have expired
74+
client = get_client()
10675

10776
try:
10877
response = client.get_custom_job(name=job_id)
10978
retry_count = 0
11079
except (ConnectionError, ServerError) as err:
80+
# Retry on connection errors, see also
81+
# https://github.com/googleapis/google-api-python-client/issues/218
11182
if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
11283
retry_count += 1
11384
logger.warning(

0 commit comments

Comments
 (0)