Skip to content

Commit c768d17

Browse files
committed
Address review comments
1 parent 3b9d1f0 commit c768d17

File tree

4 files changed

+106
-80
lines changed

4 files changed

+106
-80
lines changed

src/zenml/integrations/gcp/google_credentials_mixin.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class GoogleCredentialsConfigMixin(StackComponentConfig):
4848
class GoogleCredentialsMixin(StackComponent):
4949
"""StackComponent mixin to get Google Cloud Platform credentials."""
5050

51+
_gcp_credentials: Optional[Credentials] = None
52+
_gcp_project_id: Optional[str] = None
53+
5154
@property
5255
def config(self) -> GoogleCredentialsConfigMixin:
5356
"""Returns the `GoogleCredentialsConfigMixin` config.
@@ -57,6 +60,18 @@ def config(self) -> GoogleCredentialsConfigMixin:
5760
"""
5861
return cast(GoogleCredentialsConfigMixin, self._config)
5962

63+
@property
64+
def gcp_project_id(self) -> str:
65+
"""Get the GCP project ID.
66+
67+
Returns:
68+
The GCP project ID.
69+
"""
70+
if self._gcp_project_id is None:
71+
_, self._gcp_project_id = self._get_authentication()
72+
73+
return self._gcp_project_id
74+
6075
def _get_authentication(self) -> Tuple["Credentials", str]:
6176
"""Get GCP credentials and the project ID associated with the credentials.
6277
@@ -79,6 +94,12 @@ def _get_authentication(self) -> Tuple["Credentials", str]:
7994
GCPServiceConnector,
8095
)
8196

97+
if self.connector_has_expired():
98+
self._gcp_credentials = None
99+
100+
if self._gcp_credentials and self._gcp_project_id:
101+
return self._gcp_credentials, self._gcp_project_id
102+
82103
connector = self.get_connector()
83104
if connector:
84105
credentials = connector.connect()
@@ -90,6 +111,8 @@ def _get_authentication(self) -> Tuple["Credentials", str]:
90111
"trying to use the linked connector, but got "
91112
f"{type(credentials)}."
92113
)
114+
self._gcp_credentials = credentials
115+
self._gcp_project_id = connector.config.gcp_project_id
93116
return credentials, connector.config.gcp_project_id
94117

95118
if self.config.service_account_path:
@@ -111,4 +134,6 @@ def _get_authentication(self) -> Tuple["Credentials", str]:
111134
# If the project was set in the configuration, use it. Otherwise, use
112135
# the project that was used to authenticate.
113136
project_id = self.config.project if self.config.project else project_id
137+
self._gcp_credentials = credentials
138+
self._gcp_project_id = project_id
114139
return credentials, project_id

src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
144144
"""Orchestrator responsible for running pipelines on Vertex AI."""
145145

146146
_pipeline_root: str
147+
_job_service_client: Optional[aiplatform.gapic.JobServiceClient] = None
147148

148149
@property
149150
def config(self) -> VertexOrchestratorConfig:
@@ -261,6 +262,25 @@ def pipeline_directory(self) -> str:
261262
"""
262263
return os.path.join(self.root_directory, "pipelines")
263264

265+
def get_job_service_client(self) -> aiplatform.gapic.JobServiceClient:
266+
"""Get the job service client.
267+
268+
Returns:
269+
The job service client.
270+
"""
271+
if self.connector_has_expired():
272+
self._job_service_client = None
273+
274+
if self._job_service_client is None:
275+
credentials, _ = self._get_authentication()
276+
client_options = {
277+
"api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
278+
}
279+
self._job_service_client = aiplatform.gapic.JobServiceClient(
280+
credentials=credentials, client_options=client_options
281+
)
282+
return self._job_service_client
283+
264284
def _create_container_component(
265285
self,
266286
image: str,
@@ -696,14 +716,10 @@ def submit_dynamic_pipeline(
696716
network=self.config.network,
697717
)
698718

699-
credentials, project_id = self._get_authentication()
700-
client_options = {
701-
"api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
702-
}
703-
client = aiplatform.gapic.JobServiceClient(
704-
credentials=credentials, client_options=client_options
719+
client = self.get_job_service_client()
720+
parent = (
721+
f"projects/{self.gcp_project_id}/locations/{self.config.location}"
705722
)
706-
parent = f"projects/{project_id}/locations/{self.config.location}"
707723
job_model = client.create_custom_job(
708724
parent=parent, custom_job=job_request
709725
)
@@ -712,12 +728,16 @@ def submit_dynamic_pipeline(
712728
if settings.synchronous:
713729
wait_for_completion = lambda: monitor_job(
714730
job_id=job_model.name,
715-
credentials_source=self,
716-
client_options=client_options,
731+
get_client=self.get_job_service_client,
717732
)
718733

719-
self._initialize_vertex_client()
720-
job = aiplatform.CustomJob.get(job_model.name)
734+
credentials, project_id = self._get_authentication()
735+
job = aiplatform.CustomJob.get(
736+
job_model.name,
737+
project=project_id,
738+
location=self.config.location,
739+
credentials=credentials,
740+
)
721741
metadata = self.compute_metadata(job)
722742

723743
logger.info("View the Vertex job at %s", job._dashboard_uri())
@@ -765,14 +785,10 @@ def run_isolated_step(
765785
network=self.config.network,
766786
)
767787

768-
credentials, project_id = self._get_authentication()
769-
client_options = {
770-
"api_endpoint": self.config.location + VERTEX_ENDPOINT_SUFFIX
771-
}
772-
client = aiplatform.gapic.JobServiceClient(
773-
credentials=credentials, client_options=client_options
788+
client = self.get_job_service_client()
789+
parent = (
790+
f"projects/{self.gcp_project_id}/locations/{self.config.location}"
774791
)
775-
parent = f"projects/{project_id}/locations/{self.config.location}"
776792
logger.info(
777793
"Submitting custom job='%s', path='%s' to Vertex AI Training.",
778794
job_request["display_name"],
@@ -781,8 +797,7 @@ def run_isolated_step(
781797
job = client.create_custom_job(parent=parent, custom_job=job_request)
782798
monitor_job(
783799
job_id=job.name,
784-
credentials_source=self,
785-
client_options=client_options,
800+
get_client=self.get_job_service_client,
786801
)
787802

788803
def _upload_and_run_pipeline(
@@ -1060,15 +1075,6 @@ def _configure_container_resources(
10601075

10611076
return dynamic_component
10621077

1063-
def _initialize_vertex_client(self) -> None:
1064-
"""Initializes the Vertex client."""
1065-
credentials, project_id = self._get_authentication()
1066-
aiplatform.init(
1067-
project=project_id,
1068-
location=self.config.location,
1069-
credentials=credentials,
1070-
)
1071-
10721078
def fetch_status(
10731079
self, run: "PipelineRunResponse", include_steps: bool = False
10741080
) -> Tuple[
@@ -1102,8 +1108,6 @@ def fetch_status(
11021108
== run.stack.components[StackComponentType.ORCHESTRATOR][0].id
11031109
)
11041110

1105-
self._initialize_vertex_client()
1106-
11071111
# Fetch the status of the PipelineJob
11081112
if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
11091113
run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
@@ -1115,8 +1119,14 @@ def fetch_status(
11151119
"the status."
11161120
)
11171121

1122+
credentials, project_id = self._get_authentication()
11181123
if run.snapshot and run.snapshot.is_dynamic:
1119-
status = aiplatform.CustomJob.get(run_id).state
1124+
status = aiplatform.CustomJob.get(
1125+
run_id,
1126+
project=project_id,
1127+
location=self.config.location,
1128+
credentials=credentials,
1129+
).state
11201130

11211131
if status in [
11221132
JobState.JOB_STATE_QUEUED,
@@ -1143,7 +1153,12 @@ def fetch_status(
11431153
else:
11441154
pipeline_status = run.status
11451155
else:
1146-
status = aiplatform.PipelineJob.get(run_id).state
1156+
status = aiplatform.PipelineJob.get(
1157+
run_id,
1158+
project=project_id,
1159+
location=self.config.location,
1160+
credentials=credentials,
1161+
).state
11471162

11481163
# Map the potential outputs to ZenML ExecutionStatus. Potential values:
11491164
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/describe_pipeline_execution.html#

src/zenml/integrations/gcp/step_operators/vertex_step_operator.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
5757
ZenML entrypoint command in it.
5858
"""
5959

60+
_job_service_client: Optional[aiplatform.gapic.JobServiceClient] = None
61+
6062
def __init__(self, *args: Any, **kwargs: Any) -> None:
6163
"""Initializes the step operator and validates the accelerator type.
6264
@@ -150,6 +152,25 @@ def get_docker_builds(
150152

151153
return builds
152154

155+
def get_job_service_client(self) -> aiplatform.gapic.JobServiceClient:
156+
"""Get the job service client.
157+
158+
Returns:
159+
The job service client.
160+
"""
161+
if self.connector_has_expired():
162+
self._job_service_client = None
163+
164+
if self._job_service_client is None:
165+
credentials, _ = self._get_authentication()
166+
client_options = {
167+
"api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
168+
}
169+
self._job_service_client = aiplatform.gapic.JobServiceClient(
170+
credentials=credentials, client_options=client_options
171+
)
172+
return self._job_service_client
173+
153174
def launch(
154175
self,
155176
info: "StepRunInfo",
@@ -193,15 +214,10 @@ def launch(
193214
)
194215
logger.debug("Vertex AI Job=%s", job_request)
195216

196-
credentials, project_id = self._get_authentication()
197-
client_options = {
198-
"api_endpoint": self.config.region + VERTEX_ENDPOINT_SUFFIX
199-
}
200-
client = aiplatform.gapic.JobServiceClient(
201-
credentials=credentials, client_options=client_options
217+
client = self.get_job_service_client()
218+
parent = (
219+
f"projects/{self.gcp_project_id}/locations/{self.config.region}"
202220
)
203-
204-
parent = f"projects/{project_id}/locations/{self.config.region}"
205221
logger.info(
206222
"Submitting custom job='%s', path='%s' to Vertex AI Training.",
207223
job_request["display_name"],
@@ -215,6 +231,5 @@ def launch(
215231

216232
monitor_job(
217233
job_id=response.name,
218-
credentials_source=self,
219-
client_options=client_options,
234+
get_client=self.get_job_service_client,
220235
)

src/zenml/integrations/gcp/utils.py

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Vertex utilities."""
1515

1616
import time
17-
from typing import TYPE_CHECKING, Any, Dict, List, Optional
17+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
1818

1919
from google.api_core.exceptions import ServerError
2020
from google.cloud import aiplatform
@@ -25,9 +25,6 @@
2525
VERTEX_JOB_STATES_COMPLETED,
2626
VERTEX_JOB_STATES_FAILED,
2727
)
28-
from zenml.integrations.gcp.google_credentials_mixin import (
29-
GoogleCredentialsMixin,
30-
)
3128
from zenml.integrations.gcp.vertex_custom_job_parameters import (
3229
VertexCustomJobParameters,
3330
)
@@ -55,59 +52,33 @@ def validate_accelerator_type(accelerator_type: Optional[str] = None) -> None:
5552
)
5653

5754

58-
def get_job_service_client(
59-
credentials_source: GoogleCredentialsMixin,
60-
client_options: Optional[Dict[str, Any]] = None,
61-
) -> aiplatform.gapic.JobServiceClient:
62-
"""Gets a job service client.
63-
64-
Args:
65-
credentials_source: The component that provides the credentials to
66-
access the job.
67-
client_options: The client options to use for the job service client.
68-
69-
Returns:
70-
A job service client.
71-
"""
72-
credentials, _ = credentials_source._get_authentication()
73-
return aiplatform.gapic.JobServiceClient(
74-
credentials=credentials, client_options=client_options
75-
)
76-
77-
7855
def monitor_job(
7956
job_id: str,
80-
credentials_source: GoogleCredentialsMixin,
81-
client_options: Optional[Dict[str, Any]] = None,
57+
get_client: Callable[[], aiplatform.gapic.JobServiceClient],
8258
) -> None:
8359
"""Monitors a job until it is completed.
8460
8561
Args:
8662
job_id: The ID of the job to monitor.
87-
credentials_source: The component that provides the credentials to
88-
access the job.
89-
client_options: The client options to use for the job service client.
63+
get_client: A function that returns an authenticated job service client.
9064
9165
Raises:
9266
RuntimeError: If the job fails.
9367
"""
9468
retry_count = 0
95-
client = get_job_service_client(
96-
credentials_source=credentials_source, client_options=client_options
97-
)
69+
client = get_client()
9870

9971
while True:
10072
time.sleep(POLLING_INTERVAL_IN_SECONDS)
101-
if credentials_source.connector_has_expired():
102-
client = get_job_service_client(
103-
credentials_source=credentials_source,
104-
client_options=client_options,
105-
)
73+
# Fetch a fresh client in case the credentials have expired
74+
client = get_client()
10675

10776
try:
10877
response = client.get_custom_job(name=job_id)
10978
retry_count = 0
11079
except (ConnectionError, ServerError) as err:
80+
# Retry on connection errors, see also
81+
# https://github.com/googleapis/google-api-python-client/issues/218
11182
if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
11283
retry_count += 1
11384
logger.warning(

0 commit comments

Comments
 (0)