diff --git a/docs/book/component-guide/model-deployers/vertex.md b/docs/book/component-guide/model-deployers/vertex.md
new file mode 100644
index 00000000000..df8fa1c2ff0
--- /dev/null
+++ b/docs/book/component-guide/model-deployers/vertex.md
@@ -0,0 +1,187 @@
+# Vertex AI Model Deployer
+
+[Vertex AI](https://cloud.google.com/vertex-ai) provides managed infrastructure for deploying machine learning models at scale. The Vertex AI Model Deployer in ZenML lets you deploy models to Vertex AI endpoints, providing a scalable and fully managed solution for model serving.
+
+## When to use it?
+
+Use the Vertex AI Model Deployer when:
+
+- You are leveraging Google Cloud Platform (GCP) and wish to integrate with its native ML serving infrastructure.
+- You need enterprise-grade model serving capabilities complete with autoscaling and GPU acceleration.
+- You require a fully managed solution that abstracts away the operational overhead of serving models.
+- You need to deploy models directly from your Vertex AI Model Registry, or even from other registries or artifacts.
+- You want seamless integration with GCP services like Cloud Logging, IAM, and VPC.
+
+This deployer is especially useful for production deployments, high-availability serving, and dynamic scaling based on workload.
+
+{% hint style="info" %}
+The Vertex AI Model Deployer requires a Vertex AI Model Registry in your ZenML stack; the deployer's stack validator checks for one. The registry allows you to register models with detailed metadata and configuration and then deploy a specific version seamlessly.
+{% endhint %}
+
+## How to deploy it?
+
+The Vertex AI Model Deployer is enabled via the ZenML GCP integration. First, install the integration:
+
+```shell
+zenml integration install gcp -y
+```
+
+### Authentication and Service Connector Configuration
+
+The deployer requires proper GCP authentication. The recommended approach is to use the ZenML Service Connector:
+
+```shell
+# Register the service connector with a service account key
+zenml service-connector register vertex_deployer_connector \
+  --type gcp \
+  --auth-method=service-account \
+  --project_id=<PROJECT_ID> \
+  --service_account_json=@vertex-deployer-sa.json \
+  --resource-type gcp-generic

+# Register the model deployer and connect it to the service connector
+zenml model-deployer register vertex_deployer \
+  --flavor=vertex \
+  --location=us-central1 \
+  --connector vertex_deployer_connector
+```
+
+{% hint style="info" %}
+The service account used for deployment must have the following roles:
+- `Vertex AI User` to enable model deployments
+- `Vertex AI Service Agent` for model endpoint management
+- `Storage Object Viewer` if the model artifacts reside in Google Cloud Storage
+{% endhint %}
+
+## How to use it
+
+A complete usage example is available in the [ZenML Examples repository](https://github.com/zenml-io/zenml-projects/tree/main/vertex-registry-and-deployer).
+
+### Deploying a Model in a Pipeline
+
+Below is an example of a deployment step that uses these configuration options. In this example, the deployment configuration supports:
+
+- **Model versioning**: Explicitly provide the model version (using the full resource name from the model registry).
+- **Display name and sync mode**: Fields such as `display_name` (for a friendly endpoint name) and `sync` (to wait for deployment completion) are available.
+- **Traffic configuration**: Route a certain percentage (e.g., 100%) of traffic to this deployment.
+- **Advanced options**: You can still specify custom container settings, resource specifications (including GPU options), and explanation configuration via shared classes from `vertex_base_config.py`.
+
+```python
+from typing import Optional
+
+from typing_extensions import Annotated
+
+from zenml import ArtifactConfig, get_step_context, step
+from zenml.client import Client
+from zenml.integrations.gcp.services.vertex_deployment import (
+    VertexDeploymentConfig,
+    VertexDeploymentService,
+)
+
+
+@step(enable_cache=False)
+def model_deployer(
+    model_registry_uri: str,
+    is_promoted: bool = False,
+) -> Annotated[
+    Optional[VertexDeploymentService],
+    ArtifactConfig(name="vertex_deployment", is_deployment_artifact=True),
+]:
+    """Model deployer step.
+
+    Args:
+        model_registry_uri: The full resource name of the model in the registry.
+        is_promoted: Flag indicating if the model is promoted to production.
+
+    Returns:
+        The deployed model service, or `None` if the model was not promoted.
+    """
+    if not is_promoted:
+        # Skip deployment if the model is not promoted.
+        return None
+
+    zenml_client = Client()
+    current_model = get_step_context().model
+    model_deployer = zenml_client.active_stack.model_deployer
+
+    # Create deployment configuration with advanced options.
+    vertex_deployment_config = VertexDeploymentConfig(
+        location="europe-west1",
+        name=current_model.name,  # Unique endpoint name in Vertex AI.
+        display_name="zenml-vertex-quickstart",
+        model_name=model_registry_uri,  # Fully qualified model name (from the model registry).
+        model_version=current_model.version,  # Specify the model version explicitly.
+        description="An example of deploying a model using the Vertex AI Model Deployer",
+        sync=True,  # Wait for the deployment to complete before proceeding.
+        traffic_percentage=100,  # Route 100% of traffic to this model version.
+        # (Optional) Advanced configurations:
+        # container=VertexAIContainerSpec(
+        #     image_uri="your-custom-image:latest",
+        #     ports=[8080],
+        #     env={"ENV_VAR": "value"},
+        # ),
+        # resources=VertexAIResourceSpec(
+        #     accelerator_type="NVIDIA_TESLA_T4",
+        #     accelerator_count=1,
+        #     machine_type="n1-standard-4",
+        #     min_replica_count=1,
+        #     max_replica_count=3,
+        # ),
+        # explanation=VertexAIExplanationSpec(
+        #     metadata={"method": "integrated-gradients"},
+        #     parameters={"num_integral_steps": 50},
+        # ),
+    )
+
+    service = model_deployer.deploy_model(
+        config=vertex_deployment_config,
+        service_type=VertexDeploymentService.SERVICE_TYPE,
+    )
+
+    return service
+```
+
+*Example: [`model_deployer.py`](../../examples/vertex-registry-and-deployer/steps/model_deployer.py)*
+
+### Configuration Options
+
+The Vertex AI Model Deployer leverages a comprehensive configuration system defined in the shared base configuration and deployer-specific settings:
+
+- **Basic Settings:**
+  - `location`: The GCP region for deployment (e.g., "us-central1" or "europe-west1").
+  - `name`: Unique identifier for the deployed endpoint.
+  - `display_name`: A human-friendly name for the endpoint.
+  - `model_name`: The fully qualified model name from the model registry.
+  - `model_version`: The version of the model to deploy.
+  - `description`: A textual description of the deployment.
+  - `sync`: A flag indicating whether the deployment should wait until completion.
+  - `traffic_percentage`: The percentage of incoming traffic to route to this deployment.
+
+- **Container and Resource Configuration:**
+  - Configurations provided via [VertexAIContainerSpec](../../integrations/gcp/flavors/vertex_base_config.py) allow you to specify a custom serving container image, HTTP routes (`predict_route`, `health_route`), environment variables, and port exposure.
+  - [VertexAIResourceSpec](../../integrations/gcp/flavors/vertex_base_config.py) lets you override the default machine type, number of replicas, and even GPU options.
+
+- **Advanced Settings:**
+  - Service account, network configuration, and customer-managed encryption keys.
+  - Model explanation settings via `VertexAIExplanationSpec` if you need integrated model interpretability.
+
+These options are defined across the [Vertex AI Base Config](../../integrations/gcp/flavors/vertex_base_config.py) and the deployer-specific configuration in [VertexModelDeployerFlavor](../../integrations/gcp/flavors/vertex_model_deployer_flavor.py).
+
+### Limitations and Considerations
+
+1. **Stack Requirements:**
+   - A Vertex AI Model Registry must be part of the stack; the deployer's stack validator enforces this.
+   - Compatible with both local and remote orchestrators.
+   - Requires valid GCP credentials and permissions.
+
+2. **Authentication:**
+   - Best practice is to use service connectors for secure and managed authentication.
+   - Supports multiple authentication methods (service accounts, local credentials).
+
+3. **Costs:**
+   - Vertex AI endpoints incur costs based on machine type and uptime.
+   - Use autoscaling (via the configured `min_replica_count` and `max_replica_count`) to manage cost.
+
+4. **Region Consistency:**
+   - Ensure that the model and the deployment are created in the same GCP region.
+
+For more details, please refer to the [SDK docs](https://sdkdocs.zenml.io) and the relevant implementation files:
+- [`vertex_model_deployer.py`](../../integrations/gcp/model_deployers/vertex_model_deployer.py)
+- [`vertex_base_config.py`](../../integrations/gcp/flavors/vertex_base_config.py)
+- [`vertex_model_deployer_flavor.py`](../../integrations/gcp/flavors/vertex_model_deployer_flavor.py)
\ No newline at end of file
diff --git a/docs/book/component-guide/model-registries/vertex.md b/docs/book/component-guide/model-registries/vertex.md
new file mode 100644
index 00000000000..f4e32ffb514
--- /dev/null
+++ b/docs/book/component-guide/model-registries/vertex.md
@@ -0,0 +1,207 @@
+# Vertex AI Model Registry
+
+[Vertex AI](https://cloud.google.com/vertex-ai) is Google Cloud's unified ML platform that helps you build, deploy, and scale ML models. The Vertex AI Model Registry is a centralized repository for managing your ML models throughout their lifecycle. With ZenML's Vertex AI Model Registry integration, you can register model versions with extended configuration options, track metadata, and seamlessly deploy your models using Vertex AI's managed infrastructure.
+
+## When would you want to use it?
+
+You should consider using the Vertex AI Model Registry when:
+
+- You're already using Google Cloud Platform (GCP) and want to leverage its native ML infrastructure.
+- You need enterprise-grade model management with fine-grained access control.
+- You want to track model lineage and metadata in a centralized location.
+- You're building ML pipelines that integrate with other Vertex AI services.
+- You need to deploy models with custom configurations such as custom container images, resource specifications, and additional metadata.
+
+This registry is particularly useful in scenarios where you:
+- Build production ML pipelines that require deployment to Vertex AI endpoints.
+- Manage multiple versions of models across development, staging, and production.
+- Need to register model versions with detailed configuration for robust deployment.
+
+{% hint style="warning" %}
+**Important:** The Vertex AI Model Registry implementation only supports the model **version** interface, not the model interface. This means that you cannot directly register or update models; you only work with model versions (deleting a model removes the model together with all of its versions). A model container is automatically created with the first version, and subsequent uploads with the same display name create new versions.
+{% endhint %}
+
+## How do you deploy it?
+
+The Vertex AI Model Registry flavor is enabled through the ZenML GCP integration. First, install the integration:
+
+```shell
+zenml integration install gcp -y
+```
+
+### Authentication and Service Connector Configuration
+
+Vertex AI requires proper GCP authentication. The recommended configuration is via the ZenML Service Connector, which supports both service-account-based authentication and local gcloud credentials.
+
+1. **Using a GCP Service Connector with a service account (Recommended):**
+   ```shell
+   # Register the service connector with a service account key
+   zenml service-connector register vertex_registry_connector \
+     --type gcp \
+     --auth-method=service-account \
+     --project_id=<PROJECT_ID> \
+     --service_account_json=@vertex-registry-sa.json \
+     --resource-type gcp-generic
+
+   # Register the model registry
+   zenml model-registry register vertex_registry \
+     --flavor=vertex \
+     --location=us-central1
+
+   # Connect the model registry to the service connector
+   zenml model-registry connect vertex_registry --connector vertex_registry_connector
+   ```
+2. **Using local gcloud credentials:**
+   ```shell
+   # Register the model registry using local gcloud auth
+   zenml model-registry register vertex_registry \
+     --flavor=vertex \
+     --location=us-central1
+   ```
+
+{% hint style="info" %}
+The service account needs the following roles:
+- `Vertex AI User` role for creating and managing model versions.
+- `Storage Object Viewer` role if accessing models stored in Google Cloud Storage.
+{% endhint %}
+
+## How do you use it?
+
+### Registering Models inside a Pipeline with Extended Configuration
+
+The Vertex AI Model Registry supports extended configuration options via the `VertexAIModelConfig` class (defined in [vertex_base_config.py](../../integrations/gcp/flavors/vertex_base_config.py)). This means you can specify additional details for your deployments, such as:
+
+- **Container configuration**: Use the `VertexAIContainerSpec` to define a custom serving container (e.g., specifying the `image_uri`, `predict_route`, `health_route`, and exposed ports).
+- **Resource configuration**: Use the `VertexAIResourceSpec` to specify compute resources like `machine_type`, `min_replica_count`, and `max_replica_count`.
+- **Additional metadata and labels**: Annotate your model registrations with pipeline details, stage information, and custom labels.
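+
+If you want to experiment with these settings outside of a pipeline, a configuration object can also be built on its own. The snippet below is a minimal sketch; the region, image URI, and label values are placeholders rather than required values:
+
+```python
+from zenml.integrations.gcp.flavors.vertex_base_config import (
+    VertexAIContainerSpec,
+    VertexAIModelConfig,
+    VertexAIResourceSpec,
+)
+
+# Minimal sketch: all values below are placeholders.
+model_config = VertexAIModelConfig(
+    location="us-central1",
+    container=VertexAIContainerSpec(
+        image_uri="europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
+        predict_route="predict",
+        health_route="health",
+        ports=[8080],
+    ),
+    resources=VertexAIResourceSpec(
+        machine_type="n1-standard-4",
+        min_replica_count=1,
+        max_replica_count=1,
+    ),
+    labels={"env": "staging"},
+)
+```
+
+The same classes are used inside a pipeline step in the full example that follows.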
+
+Below is an example of how you might register a model version in your ZenML pipeline:
+
+```python
+from typing_extensions import Annotated
+
+from zenml import ArtifactConfig, get_step_context, step
+from zenml.client import Client
+from zenml.integrations.gcp.flavors.vertex_base_config import (
+    VertexAIContainerSpec,
+    VertexAIModelConfig,
+    VertexAIResourceSpec,
+)
+from zenml.logger import get_logger
+from zenml.model_registries.base_model_registry import (
+    ModelRegistryModelMetadata,
+)
+
+logger = get_logger(__name__)
+
+
+@step(enable_cache=False)
+def model_register(
+    is_promoted: bool = False,
+) -> Annotated[str, ArtifactConfig(name="model_registry_uri")]:
+    """Model registration step.
+
+    Registers a model version in the Vertex AI Model Registry with extended configuration
+    and returns the full resource name of the registered model.
+
+    The extended configuration includes settings for the container, resources, and metadata,
+    which can then be reused in subsequent model deployments.
+    """
+    if is_promoted:
+        # Get the current model from the step context
+        current_model = get_step_context().model
+
+        client = Client()
+        model_registry = client.active_stack.model_registry
+        # Create an extended model configuration using Vertex AI base settings
+        model_config = VertexAIModelConfig(
+            location="europe-west1",
+            container=VertexAIContainerSpec(
+                image_uri="europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
+                predict_route="predict",
+                health_route="health",
+                ports=[8080],
+            ),
+            resources=VertexAIResourceSpec(
+                machine_type="n1-standard-4",
+                min_replica_count=1,
+                max_replica_count=1,
+            ),
+            labels={"env": "production"},
+            description="Extended model configuration for Vertex AI",
+        )
+
+        # Register the model version with the extended configuration as metadata
+        model_version = model_registry.register_model_version(
+            name=current_model.name,
+            version=str(current_model.version),
+            model_source_uri=current_model.get_model_artifact("sklearn_classifier").uri,
+            description="ZenML model version registered with extended configuration",
+            metadata=ModelRegistryModelMetadata(
+                zenml_pipeline_name=get_step_context().pipeline.name,
+                zenml_pipeline_run_uuid=str(get_step_context().pipeline_run.id),
+                zenml_step_name=get_step_context().step_run.name,
+            ),
+            config=model_config,
+        )
+        logger.info(f"Model version {model_version.version} registered in Model Registry")
+
+        # Return the full resource name of the registered model
+        return model_version.registered_model.name
+    else:
+        return ""
+```
+
+*Example: [`model_register.py`](../../examples/vertex-registry-and-deployer/steps/model_register.py)*
+
+### Working with Model Versions
+
+Since the Vertex AI Model Registry supports only version-level operations, here are some commands to manage model versions:
+
+```shell
+# List all model versions
+zenml model-registry models list-versions <MODEL_NAME>
+
+# Get details of a specific model version
+zenml model-registry models get-version <MODEL_NAME> -v <VERSION>
+
+# Delete a model version
+zenml model-registry models delete-version <MODEL_NAME> -v <VERSION>
+```
+
+### Configuration Options
+
+The Vertex AI Model Registry accepts several configuration options, including extended settings:
+
+- **location**: The GCP region where your resources will be created (e.g., "us-central1" or "europe-west1").
+- **project_id**: (Optional) A GCP project ID override.
+- **credentials**: (Optional) GCP credentials configuration.
+- **container**: (Optional) Detailed container settings (defined via `VertexAIContainerSpec`) for the model's serving container, such as:
+  - `image_uri`
+  - `predict_route`
+  - `health_route`
+  - `ports`
+- **resources**: (Optional) Compute resource settings (using `VertexAIResourceSpec`) like `machine_type`, `min_replica_count`, and `max_replica_count`.
+- **labels** and **metadata**: Additional annotation data for organizing and tracking your model versions.
+
+These configuration options are specified in the [Vertex AI Base Config](../../integrations/gcp/flavors/vertex_base_config.py) and further extended in the [Vertex AI Model Registry Flavor](../../integrations/gcp/flavors/vertex_model_registry_flavor.py).
+
+### Key Differences from Other Model Registries
+
+1. **Version-Only Interface**: Vertex AI only supports version-level operations for model registration.
+2. **Authentication**: Uses GCP service connectors and local credentials integrated via ZenML.
+3. **Extended Configuration**: Register model versions with detailed settings for container, resources, and metadata through `VertexAIModelConfig`.
+4. **Managed Service**: As a fully managed service, Vertex AI handles infrastructure management while you focus on your ML models.
+
+## Limitations
+
+- The methods `register_model()` and `update_model()` are not implemented; models are created implicitly when their first version is registered, and operations are otherwise performed on model versions.
+- It is recommended to specify a serving container image URI rather than rely on the default scikit-learn container to ensure compatibility with Vertex AI endpoints.
+- All models registered through this integration are automatically labeled with `managed_by="zenml"` for consistent tracking.
+
+For more detailed information, check out the [SDK docs](https://sdkdocs.zenml.io/latest/integration_code_docs/integrations-gcp/#zenml.integrations.gcp.model_registry).
+
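+As a quick end-to-end illustration, the sketch below looks up a registered version and recovers the fully qualified resource name that the Vertex AI Model Deployer expects as `model_name`. This is a minimal sketch, assuming a Vertex AI model registry is part of the active stack; the model name and version are placeholders:
+
+```python
+from zenml.client import Client
+
+client = Client()
+model_registry = client.active_stack.model_registry
+
+# Fetch a previously registered version (placeholder name and version).
+model_version = model_registry.get_model_version(
+    name="my-sklearn-classifier", version="1"
+)
+
+# The fully qualified resource name can be passed as `model_name` in the
+# `VertexDeploymentConfig` used by the Vertex AI Model Deployer.
+print(model_version.registered_model.name)
+print(model_version.version)
+```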
\ No newline at end of file diff --git a/docs/book/toc.md b/docs/book/toc.md index ec51f95fc4b..e74dc579755 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -278,6 +278,7 @@ * [Develop a custom experiment tracker](component-guide/experiment-trackers/custom.md) * [Model Deployers](component-guide/model-deployers/model-deployers.md) * [MLflow](component-guide/model-deployers/mlflow.md) + * [VertexAI](component-guide/model-deployers/vertex.md) * [Seldon](component-guide/model-deployers/seldon.md) * [BentoML](component-guide/model-deployers/bentoml.md) * [Hugging Face](component-guide/model-deployers/huggingface.md) @@ -310,6 +311,7 @@ * [Develop a Custom Annotator](component-guide/annotators/custom.md) * [Model Registries](component-guide/model-registries/model-registries.md) * [MLflow Model Registry](component-guide/model-registries/mlflow.md) + * [VertexAI](component-guide/model-registries/vertex.md) * [Develop a Custom Model Registry](component-guide/model-registries/custom.md) * [Feature Stores](component-guide/feature-stores/feature-stores.md) * [Feast](component-guide/feature-stores/feast.md) diff --git a/src/zenml/cli/model_registry.py b/src/zenml/cli/model_registry.py index c326dfcd9a3..fe99601a01c 100644 --- a/src/zenml/cli/model_registry.py +++ b/src/zenml/cli/model_registry.py @@ -18,6 +18,7 @@ import click +from zenml import __version__ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli from zenml.enums import StackComponentType @@ -642,7 +643,7 @@ def register_model_version( # Parse metadata metadata = dict(metadata) if metadata else {} registered_metadata = ModelRegistryModelMetadata(**dict(metadata)) - registered_metadata.zenml_version = zenml_version + registered_metadata.zenml_version = zenml_version or __version__ registered_metadata.zenml_run_name = zenml_run_name registered_metadata.zenml_pipeline_name = zenml_pipeline_name registered_metadata.zenml_step_name = zenml_step_name diff --git a/src/zenml/integrations/gcp/__init__.py b/src/zenml/integrations/gcp/__init__.py index 231d3b9f62e..f4163d71e8f 100644 --- a/src/zenml/integrations/gcp/__init__.py +++ b/src/zenml/integrations/gcp/__init__.py @@ -34,6 +34,11 @@ GCP_VERTEX_ORCHESTRATOR_FLAVOR = "vertex" GCP_VERTEX_STEP_OPERATOR_FLAVOR = "vertex" +# Model deployer constants +VERTEX_MODEL_REGISTRY_FLAVOR = "vertex" +VERTEX_MODEL_DEPLOYER_FLAVOR = "vertex" +VERTEX_SERVICE_ARTIFACT = "vertex_deployment_service" + # Service connector constants GCP_CONNECTOR_TYPE = "gcp" GCP_RESOURCE_TYPE = "gcp-generic" @@ -74,6 +79,8 @@ def flavors(cls) -> List[Type[Flavor]]: VertexExperimentTrackerFlavor, VertexOrchestratorFlavor, VertexStepOperatorFlavor, + VertexModelDeployerFlavor, + VertexModelRegistryFlavor, ) return [ @@ -82,6 +89,8 @@ def flavors(cls) -> List[Type[Flavor]]: VertexExperimentTrackerFlavor, VertexOrchestratorFlavor, VertexStepOperatorFlavor, + VertexModelRegistryFlavor, + VertexModelDeployerFlavor, ] diff --git a/src/zenml/integrations/gcp/flavors/__init__.py b/src/zenml/integrations/gcp/flavors/__init__.py index e70f4937594..b78b574a80e 100644 --- a/src/zenml/integrations/gcp/flavors/__init__.py +++ b/src/zenml/integrations/gcp/flavors/__init__.py @@ -33,6 +33,14 @@ VertexStepOperatorConfig, VertexStepOperatorFlavor, ) +from zenml.integrations.gcp.flavors.vertex_model_deployer_flavor import ( + VertexModelDeployerConfig, + VertexModelDeployerFlavor, +) +from zenml.integrations.gcp.flavors.vertex_model_registry_flavor import ( + VertexAIModelRegistryConfig, + VertexModelRegistryFlavor, 
+) __all__ = [ "GCPArtifactStoreFlavor", @@ -45,4 +53,8 @@ "VertexOrchestratorConfig", "VertexStepOperatorFlavor", "VertexStepOperatorConfig", + "VertexModelDeployerFlavor", + "VertexModelDeployerConfig", + "VertexModelRegistryFlavor", + "VertexAIModelRegistryConfig", ] diff --git a/src/zenml/integrations/gcp/flavors/vertex_base_config.py b/src/zenml/integrations/gcp/flavors/vertex_base_config.py new file mode 100644 index 00000000000..e2872411ba6 --- /dev/null +++ b/src/zenml/integrations/gcp/flavors/vertex_base_config.py @@ -0,0 +1,199 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Shared configuration classes for Vertex AI components.""" + +from typing import Any, Dict, Optional, Sequence + +from pydantic import BaseModel, Field + +from zenml.config.base_settings import BaseSettings + + +class VertexAIContainerSpec(BaseModel): + """Container specification for Vertex AI models and endpoints.""" + + image_uri: Optional[str] = Field( + None, description="Docker image URI for model serving" + ) + command: Optional[Sequence[str]] = Field( + None, description="Container command to run" + ) + args: Optional[Sequence[str]] = Field( + None, description="Container command arguments" + ) + env: Optional[Dict[str, str]] = Field( + None, description="Environment variables" + ) + ports: Optional[Sequence[int]] = Field( + None, description="Container ports to expose" + ) + predict_route: Optional[str] = Field( + None, description="HTTP path for prediction requests" + ) + health_route: Optional[str] = Field( + None, description="HTTP path for health check requests" + ) + + +class VertexAIResourceSpec(BaseModel): + """Resource specification for Vertex AI deployments.""" + + machine_type: Optional[str] = Field( + None, description="Compute instance machine type" + ) + accelerator_type: Optional[str] = Field( + None, description="Hardware accelerator type" + ) + accelerator_count: Optional[int] = Field( + None, description="Number of accelerators" + ) + min_replica_count: Optional[int] = Field( + 1, description="Minimum number of replicas" + ) + max_replica_count: Optional[int] = Field( + 1, description="Maximum number of replicas" + ) + + +class VertexAIExplanationSpec(BaseModel): + """Explanation configuration for Vertex AI models.""" + + metadata: Optional[Dict[str, Any]] = Field( + None, description="Explanation metadata" + ) + parameters: Optional[Dict[str, Any]] = Field( + None, description="Explanation parameters" + ) + + +class VertexAIBaseConfig(BaseModel): + """Base configuration shared by Vertex AI components. 
+ + Reference: + - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.models + - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints + """ + + # Basic settings + location: str = Field( + "us-central1", description="GCP region for Vertex AI resources" + ) + project_id: Optional[str] = Field( + None, description="Optional project ID override" + ) + + # Container configuration + container: Optional[VertexAIContainerSpec] = Field( + None, description="Container configuration" + ) + + # Resource configuration + resources: Optional[VertexAIResourceSpec] = Field( + None, description="Resource configuration" + ) + + # Service configuration + service_account: Optional[str] = Field( + None, description="Service account email" + ) + network: Optional[str] = Field(None, description="VPC network") + + # Security + encryption_spec_key_name: Optional[str] = Field( + None, description="Customer-managed encryption key" + ) + + # Monitoring and logging + enable_access_logging: Optional[bool] = Field( + None, description="Enable access logging" + ) + disable_container_logging: Optional[bool] = Field( + None, description="Disable container logging" + ) + + # Model explanation + explanation: Optional[VertexAIExplanationSpec] = Field( + None, description="Model explanation configuration" + ) + + # Labels and metadata + labels: Optional[Dict[str, str]] = Field( + None, description="Resource labels" + ) + metadata: Optional[Dict[str, str]] = Field( + None, description="Custom metadata" + ) + + +class VertexAIModelConfig(VertexAIBaseConfig): + """Configuration specific to Vertex AI Models.""" + + # Model metadata + display_name: Optional[str] = None + description: Optional[str] = None + version_description: Optional[str] = None + version_aliases: Optional[Sequence[str]] = None + + # Model artifacts + artifact_uri: Optional[str] = None + model_source_spec: Optional[Dict[str, Any]] = None + + # Model versioning + is_default_version: Optional[bool] = None + + # Model formats + supported_deployment_resources_types: Optional[Sequence[str]] = None + supported_input_storage_formats: Optional[Sequence[str]] = None + supported_output_storage_formats: Optional[Sequence[str]] = None + + # Training metadata + training_pipeline_display_name: Optional[str] = None + training_pipeline_id: Optional[str] = None + + # Model optimization + model_source_info: Optional[Dict[str, str]] = None + original_model_info: Optional[Dict[str, str]] = None + containerized_model_optimization: Optional[Dict[str, Any]] = None + + +class VertexAIEndpointConfig(VertexAIBaseConfig): + """Configuration specific to Vertex AI Endpoints.""" + + # Endpoint metadata + display_name: Optional[str] = None + description: Optional[str] = None + + # Traffic configuration + traffic_split: Optional[Dict[str, int]] = None + traffic_percentage: Optional[int] = 0 + + # Autoscaling + autoscaling_target_cpu_utilization: Optional[float] = None + autoscaling_target_accelerator_duty_cycle: Optional[float] = None + + # Deployment + sync: Optional[bool] = False + deploy_request_timeout: Optional[int] = None + existing_endpoint: Optional[str] = None + + +class VertexAIBaseSettings(BaseSettings): + """Base settings for Vertex AI components.""" + + location: str = Field( + "us-central1", description="Default GCP region for Vertex AI resources" + ) + project_id: Optional[str] = Field( + None, description="Optional project ID override" + ) diff --git a/src/zenml/integrations/gcp/flavors/vertex_model_deployer_flavor.py 
b/src/zenml/integrations/gcp/flavors/vertex_model_deployer_flavor.py new file mode 100644 index 00000000000..7c450f51b09 --- /dev/null +++ b/src/zenml/integrations/gcp/flavors/vertex_model_deployer_flavor.py @@ -0,0 +1,132 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Vertex AI model deployer flavor.""" + +from typing import TYPE_CHECKING, Optional, Type + +from zenml.integrations.gcp import ( + GCP_RESOURCE_TYPE, + VERTEX_MODEL_DEPLOYER_FLAVOR, +) +from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIEndpointConfig, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsConfigMixin, +) +from zenml.model_deployers.base_model_deployer import ( + BaseModelDeployerConfig, + BaseModelDeployerFlavor, +) +from zenml.models.v2.misc.service_connector_type import ( + ServiceConnectorRequirements, +) + +if TYPE_CHECKING: + from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( + VertexModelDeployer, + ) + + +class VertexModelDeployerConfig( + BaseModelDeployerConfig, + GoogleCredentialsConfigMixin, + VertexAIEndpointConfig, +): + """Configuration for the Vertex AI model deployer. + + This configuration combines: + - Base model deployer configuration + - Google Cloud authentication + - Vertex AI endpoint configuration + """ + + +class VertexModelDeployerFlavor(BaseModelDeployerFlavor): + """Vertex AI model deployer flavor.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The name of the flavor. + """ + return VERTEX_MODEL_DEPLOYER_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements( + resource_type=GCP_RESOURCE_TYPE, + ) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/gcp.png" + + @property + def config_class(self) -> Type[VertexModelDeployerConfig]: + """Returns `VertexModelDeployerConfig` config class. + + Returns: + The config class. 
+ """ + return VertexModelDeployerConfig + + @property + def implementation_class(self) -> Type["VertexModelDeployer"]: + """Implementation class for this flavor. + + Returns: + The implementation class. + """ + from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( + VertexModelDeployer, + ) + + return VertexModelDeployer diff --git a/src/zenml/integrations/gcp/flavors/vertex_model_registry_flavor.py b/src/zenml/integrations/gcp/flavors/vertex_model_registry_flavor.py new file mode 100644 index 00000000000..8524d407e2f --- /dev/null +++ b/src/zenml/integrations/gcp/flavors/vertex_model_registry_flavor.py @@ -0,0 +1,130 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""VertexAI model registry flavor.""" + +from typing import TYPE_CHECKING, Optional, Type + +from zenml.integrations.gcp import ( + GCP_RESOURCE_TYPE, + VERTEX_MODEL_REGISTRY_FLAVOR, +) +from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIModelConfig, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsConfigMixin, +) +from zenml.model_registries.base_model_registry import ( + BaseModelRegistryConfig, + BaseModelRegistryFlavor, +) +from zenml.models import ServiceConnectorRequirements + +if TYPE_CHECKING: + from zenml.integrations.gcp.model_registries import ( + VertexAIModelRegistry, + ) + + +class VertexAIModelRegistryConfig( + BaseModelRegistryConfig, + GoogleCredentialsConfigMixin, + VertexAIModelConfig, +): + """Configuration for the VertexAI model registry. + + This configuration combines: + - Base model registry configuration + - Google Cloud authentication + - Vertex AI model configuration + """ + + +class VertexModelRegistryFlavor(BaseModelRegistryFlavor): + """Model registry flavor for VertexAI models.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The name of the flavor. + """ + return VERTEX_MODEL_REGISTRY_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements( + resource_type=GCP_RESOURCE_TYPE, + ) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. 
+ """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/gcp.png" + + @property + def config_class(self) -> Type[VertexAIModelRegistryConfig]: + """Returns `VertexAIModelRegistryConfig` config class. + + Returns: + The config class. + """ + return VertexAIModelRegistryConfig + + @property + def implementation_class(self) -> Type["VertexAIModelRegistry"]: + """Implementation class for this flavor. + + Returns: + The implementation class. + """ + from zenml.integrations.gcp.model_registries import ( + VertexAIModelRegistry, + ) + + return VertexAIModelRegistry diff --git a/src/zenml/integrations/gcp/model_deployers/__init__.py b/src/zenml/integrations/gcp/model_deployers/__init__.py new file mode 100644 index 00000000000..203f57c096f --- /dev/null +++ b/src/zenml/integrations/gcp/model_deployers/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization of the Vertex AI model deployers.""" + +from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( # noqa + VertexModelDeployer, +) + +__all__ = ["VertexModelDeployer"] diff --git a/src/zenml/integrations/gcp/model_deployers/vertex_model_deployer.py b/src/zenml/integrations/gcp/model_deployers/vertex_model_deployer.py new file mode 100644 index 00000000000..0256dde9368 --- /dev/null +++ b/src/zenml/integrations/gcp/model_deployers/vertex_model_deployer.py @@ -0,0 +1,260 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the Vertex AI Model Deployer.""" + +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, cast +from uuid import UUID + +from google.cloud import aiplatform + +from zenml.analytics.enums import AnalyticsEvent +from zenml.analytics.utils import track_handler +from zenml.client import Client +from zenml.enums import StackComponentType +from zenml.integrations.gcp.flavors.vertex_model_deployer_flavor import ( + VertexModelDeployerConfig, + VertexModelDeployerFlavor, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsMixin, +) +from zenml.integrations.gcp.model_registries.vertex_model_registry import ( + VertexAIModelRegistry, +) +from zenml.integrations.gcp.services.vertex_deployment import ( + VertexDeploymentConfig, + VertexDeploymentService, +) +from zenml.logger import get_logger +from zenml.model_deployers import BaseModelDeployer +from zenml.model_deployers.base_model_deployer import ( + DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + BaseModelDeployerFlavor, +) +from zenml.services import BaseService, ServiceConfig +from zenml.stack.stack import Stack +from zenml.stack.stack_validator import StackValidator + +logger = get_logger(__name__) + + +class VertexModelDeployer(BaseModelDeployer, GoogleCredentialsMixin): + """Vertex AI endpoint model deployer.""" + + NAME: ClassVar[str] = "Vertex AI" + FLAVOR: ClassVar[Type["BaseModelDeployerFlavor"]] = ( + VertexModelDeployerFlavor + ) + + @property + def config(self) -> VertexModelDeployerConfig: + """Returns the `VertexModelDeployerConfig` config. + + Returns: + The configuration. + """ + return cast(VertexModelDeployerConfig, self._config) + + def _init_vertex_client( + self, + credentials: Optional[Any] = None, + ) -> None: + """Initialize Vertex AI client with proper credentials. + + Args: + credentials: Optional credentials to use + """ + if not credentials: + credentials, project_id = self._get_authentication() + + # Initialize with per-instance credentials + aiplatform.init( + project=project_id, + location=self.config.location, + credentials=credentials, + ) + + @property + def validator(self) -> Optional[StackValidator]: + """Validates that the stack contains a Vertex AI model registry. + + Returns: + A StackValidator instance. + """ + + def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]: + """Validates stack requirements. + + Args: + stack: The stack to validate. + + Returns: + A tuple of (is_valid, error_message). + """ + model_registry = stack.model_registry + if not isinstance(model_registry, VertexAIModelRegistry): + return False, ( + "The Vertex AI model deployer requires a Vertex AI model " + "registry to be present in the stack. Please add a Vertex AI " + "model registry to the stack." + ) + + return True, "" + + return StackValidator( + required_components={ + StackComponentType.MODEL_REGISTRY, + }, + custom_validation_function=_validate_stack_requirements, + ) + + def _create_deployment_service( + self, id: UUID, timeout: int, config: VertexDeploymentConfig + ) -> VertexDeploymentService: + """Creates a new VertexAIDeploymentService. 
+ + Args: + id: the UUID of the model to be deployed + timeout: timeout in seconds for deployment operations + config: deployment configuration + + Returns: + The VertexDeploymentService instance + """ + # Initialize client with fresh credentials + self._init_vertex_client() + + # Create service instance + service = VertexDeploymentService(uuid=id, config=config) + logger.info("Creating Vertex AI deployment service with ID %s", id) + + # Start the service + service.start(timeout=timeout) + return service + + def perform_deploy_model( + self, + id: UUID, + config: ServiceConfig, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Deploy a model to Vertex AI. + + Args: + id: the UUID of the service to be created + config: deployment configuration + timeout: timeout for deployment operations + + Returns: + The deployment service instance + """ + with track_handler(AnalyticsEvent.MODEL_DEPLOYED) as analytics_handler: + config = cast(VertexDeploymentConfig, config) + + # Create and start deployment service + service = self._create_deployment_service( + id=id, config=config, timeout=timeout + ) + + # Track analytics + client = Client() + stack = client.active_stack + stack_metadata = { + component_type.value: component.flavor + for component_type, component in stack.components.items() + } + analytics_handler.metadata = { + "store_type": client.zen_store.type.value, + **stack_metadata, + } + + return service + + def perform_stop_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> BaseService: + """Stop a Vertex AI deployment service. + + Args: + service: The service to stop + timeout: Timeout for stop operation + force: Whether to force stop + + Returns: + The stopped service + """ + # Initialize client with fresh credentials + self._init_vertex_client() + + service.stop(timeout=timeout, force=force) + return service + + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Start a Vertex AI deployment service. + + Args: + service: The service to start + timeout: Timeout for start operation + + Returns: + The started service + """ + # Initialize client with fresh credentials + self._init_vertex_client() + + service.start(timeout=timeout) + return service + + def perform_delete_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> None: + """Delete a Vertex AI deployment service. + + Args: + service: The service to delete + timeout: Timeout for delete operation + force: Whether to force delete + """ + # Initialize client with fresh credentials + self._init_vertex_client() + + service = cast(VertexDeploymentService, service) + service.stop(timeout=timeout, force=force) + + @staticmethod + def get_model_server_info( # type: ignore[override] + service_instance: "VertexDeploymentService", + ) -> Dict[str, Optional[str]]: + """Get information about the deployed model server. 
+ + Args: + service_instance: The deployment service instance + + Returns: + Dict containing server information + """ + return { + "prediction_url": service_instance.get_prediction_url(), + "status": service_instance.status.state.value, + } diff --git a/src/zenml/integrations/gcp/model_registries/__init__.py b/src/zenml/integrations/gcp/model_registries/__init__.py new file mode 100644 index 00000000000..672c7c19619 --- /dev/null +++ b/src/zenml/integrations/gcp/model_registries/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization of the Vertex AI model registry.""" + +from zenml.integrations.gcp.model_registries.vertex_model_registry import ( + VertexAIModelRegistry +) + +__all__ = ["VertexAIModelRegistry"] diff --git a/src/zenml/integrations/gcp/model_registries/vertex_model_registry.py b/src/zenml/integrations/gcp/model_registries/vertex_model_registry.py new file mode 100644 index 00000000000..3fa8eb25e2b --- /dev/null +++ b/src/zenml/integrations/gcp/model_registries/vertex_model_registry.py @@ -0,0 +1,673 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Vertex AI model registry integration for ZenML.""" + +import base64 +import re +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple, cast + +from google.cloud import aiplatform + +from zenml.client import Client +from zenml.integrations.gcp.flavors.vertex_model_registry_flavor import ( + VertexAIModelRegistryConfig, +) +from zenml.integrations.gcp.google_credentials_mixin import ( + GoogleCredentialsMixin, +) +from zenml.logger import get_logger +from zenml.model_registries.base_model_registry import ( + BaseModelRegistry, + ModelRegistryModelMetadata, + ModelVersionStage, + RegisteredModel, + RegistryModelVersion, +) + +logger = get_logger(__name__) + +# Constants for Vertex AI limitations +MAX_LABEL_COUNT = 64 +MAX_LABEL_KEY_LENGTH = 63 +MAX_LABEL_VALUE_LENGTH = 63 +MAX_DISPLAY_NAME_LENGTH = 128 + + +class VertexAIModelRegistry(BaseModelRegistry, GoogleCredentialsMixin): + """Register models using Vertex AI.""" + + @property + def config(self) -> VertexAIModelRegistryConfig: + """Returns the config of the model registry. + + Returns: + The configuration. + """ + return cast(VertexAIModelRegistryConfig, self._config) + + def _sanitize_label(self, value: str) -> str: + """Sanitize a label value to comply with Vertex AI requirements. 
+ + Args: + value: The label value to sanitize + + Returns: + Sanitized label value + """ + if not value: + return "" + + # Convert to lowercase + value = value.lower() + + # Replace any character that's not lowercase letter, number, dash or underscore + value = re.sub(r"[^a-z0-9\-_]", "-", value) + + # Ensure it starts with a letter/number by prepending 'x' if needed + if not value[0].isalnum(): + value = f"x{value}" + + # Truncate to 63 chars to stay under limit + return value[:63] + + def _get_deployer_id(self) -> str: + """Get the current ZenML server/deployer ID for multi-tenancy support. + + Returns: + The deployer ID string + """ + from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( + VertexModelDeployer, + ) + + client = Client() + model_deployer = client.active_stack.model_deployer + if not isinstance(model_deployer, VertexModelDeployer): + raise ValueError("VertexModelDeployer is not active in the stack.") + return str(model_deployer.id) + + def _encode_name_version(self, name: str, version: str) -> str: + """Encode model name and version into a Vertex AI compatible format. + + Args: + name: Model name + version: Model version + + Returns: + Encoded string suitable for Vertex AI + """ + # Base64 encode to handle special characters while preserving uniqueness + encoded = base64.b64encode(f"{name}:{version}".encode()).decode() + # Make it URL and label safe + encoded = encoded.replace("+", "-").replace("/", "_").replace("=", "") + return encoded[:MAX_DISPLAY_NAME_LENGTH] + + def _decode_name_version(self, encoded: str) -> Tuple[str, str]: + """Decode model name and version from encoded format. + + Args: + encoded: The encoded string + + Returns: + Tuple of (name, version) + """ + # Add back padding + padding = 4 - (len(encoded) % 4) + if padding != 4: + encoded += "=" * padding + # Restore special chars + encoded = encoded.replace("-", "+").replace("_", "/") + try: + decoded = base64.b64decode(encoded).decode() + name, version = decoded.split(":", 1) + return name, version + except Exception as e: + logger.warning( + f"Failed to decode name/version from {encoded}: {e}" + ) + return encoded, "unknown" + + def _prepare_labels( + self, + metadata: Optional[Dict[str, str]] = None, + stage: Optional[ModelVersionStage] = None, + ) -> Dict[str, str]: + """Prepare labels for Vertex AI model. + + Args: + metadata: Optional metadata to include as labels + stage: Optional model version stage + + Returns: + Dictionary of sanitized labels + """ + labels = {} + + # Add base labels + labels["managed_by"] = "zenml" + labels["deployer_id"] = self._sanitize_label(self._get_deployer_id()) + + # Add stage if provided + if stage: + labels["stage"] = self._sanitize_label(stage.value) + + # Process metadata if provided + if metadata: + # If metadata is not a dict (e.g. 
a pydantic model), convert it using .dict() + if not isinstance(metadata, dict): + try: + metadata = metadata.dict() + except Exception as e: + logger.warning(f"Unable to convert metadata to dict: {e}") + metadata = {} + for key, value in metadata.items(): + # Skip None values + if value is None: + continue + # Convert complex objects to string + if isinstance(value, (dict, list)): + value = ( + "x" # Simplify complex objects to avoid length issues + ) + # Sanitize both key and value + sanitized_key = self._sanitize_label(str(key)) + sanitized_value = self._sanitize_label(str(value)) + # Only add if both key and value are valid + if sanitized_key and sanitized_value: + labels[sanitized_key] = sanitized_value + + # Ensure we don't exceed 64 labels + if len(labels) > 64: + # Keep essential labels and truncate the rest + essential_labels = { + k: labels[k] + for k in ["managed_by", "deployer_id", "stage"] + if k in labels + } + # Add remaining labels up to limit + remaining_slots = 64 - len(essential_labels) + other_labels = { + k: v + for i, (k, v) in enumerate(labels.items()) + if k not in essential_labels and i < remaining_slots + } + labels = {**essential_labels, **other_labels} + + return labels + + def _get_model_id(self, name: str) -> str: + """Get the full Vertex AI model ID. + + Args: + name: Model name + + Returns: + Full model ID in format: projects/{project}/locations/{location}/models/{model} + """ + _, project_id = self._get_authentication() + return f"projects/{project_id}/locations/{self.config.location}/models/{name}" + + def _get_model_version_id(self, model_id: str, version: str) -> str: + """Get the full Vertex AI model version ID. + + Args: + model_id: Full model ID + version: Version string + + Returns: + Full model version ID in format: {model_id}/versions/{version} + """ + return f"{model_id}/versions/{version}" + + def _init_vertex_model( + self, name: str, version: Optional[str] = None + ) -> Optional[aiplatform.Model]: + """Initialize a single Vertex AI model with proper credentials. + + This method returns one Vertex AI model based on the given name (and optional version). + + Args: + name: The model name. + version: The model version (optional). + + Returns: + A single Vertex AI model instance or None if initialization fails. + """ + credentials, project_id = self._get_authentication() + location = self.config.location + kwargs = { + "location": location, + "project": project_id, + "credentials": credentials, + } + + if name.startswith("projects/"): + kwargs["model_name"] = name + else: + # Attempt to find an existing model by display_name + existing_models = aiplatform.Model.list( + filter=f"display_name={name}", + project=self.config.project_id or project_id, + location=location, + ) + if existing_models: + kwargs["model_name"] = existing_models[0].resource_name + else: + model_id = self._get_model_id(name) + if version: + model_id = self._get_model_version_id(model_id, version) + kwargs["model_name"] = model_id + + try: + return aiplatform.Model(**kwargs) + except Exception as e: + logger.warning(f"Failed to initialize model: {e}") + return None + + def register_model( + self, + name: str, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + ) -> RegisteredModel: + """Register a model to the Vertex AI model registry.""" + raise NotImplementedError( + "Vertex AI does not support registering models, you can only register model versions, skipping model registration..." 
+ ) + + def delete_model( + self, + name: str, + ) -> None: + """Delete a model and all of its versions from the Vertex AI model registry.""" + try: + model = self._init_vertex_model(name=name) + if isinstance(model, aiplatform.Model): + model.delete() + logger.info(f"Deleted model '{name}' and all its versions.") + except Exception as e: + raise RuntimeError(f"Failed to delete model: {str(e)}") + + def update_model( + self, + name: str, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + remove_metadata: Optional[List[str]] = None, + ) -> RegisteredModel: + """Update a model in the Vertex AI model registry.""" + raise NotImplementedError( + "Vertex AI does not support updating models, you can only update model versions, skipping model registration..." + ) + + def get_model(self, name: str) -> RegisteredModel: + """Get a model from the Vertex AI model registry by name without needing a version.""" + try: + # Fetch by display_name, and use unique labels to ensure multi-tenancy + model = aiplatform.Model(display_name=name) + return RegisteredModel( + name=model.display_name, + description=model.description, + metadata=model.labels, + ) + except Exception as e: + raise RuntimeError(f"Failed to get model: {str(e)}") + + def list_models( + self, + name: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + ) -> List[RegisteredModel]: + """List models in the Vertex AI model registry.""" + credentials, project_id = self._get_authentication() + location = self.config.location + # Always filter with ZenML-specific labels (including deployer id for multi-tenancy) + filter_expr = "labels.managed_by=zenml" + + if name: + filter_expr += f" AND display_name={name}" + if metadata: + for key, value in metadata.items(): + filter_expr += f" AND labels.{key}={value}" + try: + all_models = aiplatform.Model.list( + project=project_id, + location=location, + filter=filter_expr, + credentials=credentials, + ) + # Deduplicate by display_name so only one entry per "logical" model is returned. + unique_models = {model.display_name: model for model in all_models} + return [ + RegisteredModel( + name=parent_model.display_name, + description=parent_model.description, + metadata=parent_model.labels, + ) + for parent_model in unique_models.values() + ] + except Exception as e: + raise RuntimeError(f"Failed to list models: {str(e)}") + + def register_model_version( + self, + name: str, + version: Optional[str] = None, + model_source_uri: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[ModelRegistryModelMetadata] = None, + **kwargs: Any, + ) -> RegistryModelVersion: + """Register a model version to the Vertex AI model registry. 
+
+        Args:
+            name: Model name
+            version: Model version
+            model_source_uri: URI to model artifacts
+            description: Model description
+            metadata: Model metadata (expected to be a ModelRegistryModelMetadata or
+                equivalent serializable dict)
+            **kwargs: Additional arguments
+
+        Returns:
+            RegistryModelVersion instance
+        """
+        # Prepare labels with internal ZenML metadata, ensuring they are sanitized
+        metadata_dict = metadata.model_dump() if metadata else {}
+        labels = self._prepare_labels(metadata_dict)
+        if version:
+            labels["user_version"] = self._sanitize_label(version)
+
+        # Get the container image from the config if available, otherwise
+        # fall back to the metadata (and finally a default sklearn image)
+        if (
+            hasattr(self.config, "container")
+            and self.config.container
+            and self.config.container.image_uri
+        ):
+            serving_container_image_uri = self.config.container.image_uri
+        else:
+            serving_container_image_uri = metadata_dict.get(
+                "serving_container_image_uri",
+                "europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-3:latest",
+            )
+
+        # Use a consistently sanitized display name instead of the raw model name
+        model_display_name = self._sanitize_model_display_name(name)
+
+        # Build extended upload arguments for aiplatform.Model.upload,
+        # leveraging extra settings from self.config.
+        upload_arguments = {
+            "serving_container_image_uri": serving_container_image_uri,
+            "artifact_uri": model_source_uri or self.config.artifact_uri,
+            "is_default_version": self.config.is_default_version
+            if self.config.is_default_version is not None
+            else True,
+            "version_aliases": self.config.version_aliases,
+            "version_description": self.config.version_description,
+            "serving_container_predict_route": self.config.container.predict_route
+            if self.config.container
+            else None,
+            "serving_container_health_route": self.config.container.health_route
+            if self.config.container
+            else None,
+            "description": description or self.config.description,
+            "serving_container_command": self.config.container.command
+            if self.config.container
+            else None,
+            "serving_container_args": self.config.container.args
+            if self.config.container
+            else None,
+            "serving_container_environment_variables": self.config.container.env
+            if self.config.container
+            else None,
+            "serving_container_ports": self.config.container.ports
+            if self.config.container
+            else None,
+            "display_name": self.config.display_name or model_display_name,
+            "project": self.config.project_id,
+            "location": self.config.location,
+            "labels": labels,
+            "encryption_spec_key_name": self.config.encryption_spec_key_name,
+        }
+
+        # Include explanation settings if provided in the config.
+        if self.config.explanation:
+            upload_arguments["explanation_metadata"] = (
+                self.config.explanation.metadata
+            )
+            upload_arguments["explanation_parameters"] = (
+                self.config.explanation.parameters
+            )
+
+        # Remove any parameters that are None to avoid passing them to upload.
+        upload_arguments = {
+            k: v for k, v in upload_arguments.items() if v is not None
+        }
+
+        parent_model = self._init_vertex_model(name=name, version=version)
+        if parent_model is not None and parent_model.uri == model_source_uri:
+            logger.info(
+                f"Model version {version} already exists, skipping upload..."
+            )
+            return self._vertex_model_to_registry_version(parent_model)
+        # Always call Model.upload (even if a parent model already exists), since Vertex AI
+        # expects a full upload for each version.
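+        # `parent_model` controls the versioning semantics: passing the
+        # resource name of an existing model makes this upload a new version
+        # of it, while `None` creates a brand-new model resource. For
+        # example, the second upload of "my-model" becomes version "2" under
+        # the same resource name.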
+        upload_arguments["parent_model"] = (
+            parent_model.resource_name if parent_model else None
+        )
+        model = aiplatform.Model.upload(**upload_arguments)
+        logger.info(f"Uploaded new model version with labels: {model.labels}")
+
+        return self._vertex_model_to_registry_version(model)
+
+    def delete_model_version(
+        self,
+        name: str,
+        version: str,
+    ) -> None:
+        """Delete a model version from the Vertex AI model registry.
+
+        Args:
+            name: Model name
+            version: Version string
+        """
+        try:
+            model = self._init_vertex_model(name=name, version=version)
+            if model is None:
+                raise RuntimeError(f"Model '{name}' not found.")
+            model.versioning_registry.delete_version(version)
+            logger.info(f"Deleted model version: {name} version {version}")
+        except Exception as e:
+            raise RuntimeError(f"Failed to delete model version: {str(e)}")
+
+    def update_model_version(
+        self,
+        name: str,
+        version: str,
+        description: Optional[str] = None,
+        metadata: Optional[ModelRegistryModelMetadata] = None,
+        remove_metadata: Optional[List[str]] = None,
+        stage: Optional[ModelVersionStage] = None,
+    ) -> RegistryModelVersion:
+        """Update a model version in the Vertex AI model registry."""
+        try:
+            parent_model = self._init_vertex_model(name=name, version=version)
+            if parent_model is None:
+                raise RuntimeError(f"Model '{name}' not found.")
+            sanitized_version = self._sanitize_label(version)
+            target_version = None
+            # Look up the version whose `user_version` label matches, since
+            # Vertex AI assigns its own numeric version ids.
+            for version_info in parent_model.versioning_registry.list_versions():
+                candidate = aiplatform.Model(
+                    model_name=f"{parent_model.resource_name}@{version_info.version_id}"
+                )
+                if candidate.labels.get("user_version") == sanitized_version:
+                    target_version = candidate
+                    break
+            if target_version is None:
+                raise RuntimeError(
+                    f"Model version '{version}' for '{name}' not found."
+                )
+            labels = target_version.labels or {}
+            if metadata:
+                metadata_dict = metadata.model_dump()
+                for key, value in metadata_dict.items():
+                    labels[self._sanitize_label(key)] = self._sanitize_label(
+                        str(value)
+                    )
+            if remove_metadata:
+                for key in remove_metadata:
+                    labels.pop(self._sanitize_label(key), None)
+            if stage:
+                labels["stage"] = stage.value.lower()
+            target_version.update(description=description, labels=labels)
+            return self.get_model_version(name, version)
+        except Exception as e:
+            raise RuntimeError(f"Failed to update model version: {str(e)}")
+
+    def get_model_version(
+        self, name: str, version: str
+    ) -> RegistryModelVersion:
+        """Get a model version from the Vertex AI model registry using the version label."""
+        try:
+            parent_model = self._init_vertex_model(name=name, version=version)
+            if parent_model is None:
+                raise RuntimeError(
+                    f"Model version '{version}' for '{name}' not found."
+                )
+            return self._vertex_model_to_registry_version(parent_model)
+        except Exception as e:
+            raise RuntimeError(f"Failed to get model version: {str(e)}")
+
+    def list_model_versions(
+        self,
+        name: Optional[str] = None,
+        model_source_uri: Optional[str] = None,
+        metadata: Optional[ModelRegistryModelMetadata] = None,
+        stage: Optional[ModelVersionStage] = None,
+        count: Optional[int] = None,
+        created_after: Optional[datetime] = None,
+        created_before: Optional[datetime] = None,
+        order_by_date: Optional[str] = None,
+        **kwargs: Any,
+    ) -> List[RegistryModelVersion]:
+        """List model versions from the Vertex AI model registry."""
+        credentials, project_id = self._get_authentication()
+        location = self.config.location
+        filter_expr = []
+        if name:
+            filter_expr.append(
+                f"display_name={self._sanitize_model_display_name(name)}"
+            )
+        if metadata:
+            for key, value in metadata.model_dump().items():
+                filter_expr.append(
+                    f"labels.{self._sanitize_label(key)}={self._sanitize_label(str(value))}"
+                )
+        if created_after:
+            filter_expr.append(f"create_time>{created_after.isoformat()}")
+        if created_before:
+            filter_expr.append(f"create_time<{created_before.isoformat()}")
+
+        filter_str = " AND ".join(filter_expr) if filter_expr else None
+
+        try:
+            # `aiplatform.Model(...)` takes no filter argument, so list the
+            # matching models first and then expand their version registries.
+            matching_models = aiplatform.Model.list(
+                project=project_id,
+                location=location,
+                filter=filter_str,
+                credentials=credentials,
+            )
+            results = []
+            for model in matching_models:
+                for version_info in model.versioning_registry.list_versions():
+                    version_model = aiplatform.Model(
+                        model_name=f"{model.resource_name}@{version_info.version_id}",
+                        credentials=credentials,
+                    )
+                    results.append(
+                        self._vertex_model_to_registry_version(version_model)
+                    )
+            if count:
+                results = results[:count]
+            return results
+        except Exception as e:
+            raise RuntimeError(f"Failed to list model versions: {str(e)}")
+
+    def load_model_version(
+        self,
+        name: str,
+        version: str,
+        **kwargs: Any,
+    ) -> Any:
+        """Load a model version from the Vertex AI model registry using label-based lookup."""
+        try:
+            parent_model = self._init_vertex_model(name=name, version=version)
+            if parent_model is None:
+                raise RuntimeError(
+                    f"Model version '{version}' for '{name}' not found."
+                )
+            return parent_model
+        except Exception as e:
+            raise RuntimeError(f"Failed to load model version: {str(e)}")
+
+    def get_model_uri_artifact_store(
+        self,
+        model_version: RegistryModelVersion,
+    ) -> str:
+        """Get the model URI in the artifact store."""
+        return model_version.model_source_uri
+
+    def _vertex_model_to_registry_version(
+        self, model: aiplatform.Model
+    ) -> RegistryModelVersion:
+        """Convert a Vertex AI model to a ZenML RegistryModelVersion.
+
+        Args:
+            model: Vertex AI Model instance
+
+        Returns:
+            RegistryModelVersion instance
+        """
+        # Extract the stage from the labels if present. Stages are stored
+        # lowercase, so match against the enum values case-insensitively.
+        stage = ModelVersionStage.NONE
+        if model.labels and "stage" in model.labels:
+            stage_label = model.labels["stage"]
+            for candidate in ModelVersionStage:
+                if candidate.value.lower() == stage_label.lower():
+                    stage = candidate
+                    break
+
+        # Build the registered model representation for this version
+        try:
+            registered_model = RegisteredModel(
+                name=model.display_name,
+                description=model.description,
+                metadata=model.labels,
+            )
+        except Exception as e:
+            logger.warning(
+                f"Failed to get parent model for version: {model.resource_name}: {e}"
+            )
+            registered_model = RegisteredModel(
+                name=model.display_name if model.display_name else "unknown",
+                description=model.description if model.description else "",
+                metadata=model.labels if model.labels else {},
+            )
+
+        return RegistryModelVersion(
+            registered_model=registered_model,
+            version=model.version_id,
+            model_source_uri=model.uri,
+            model_format="Custom",  # Vertex AI doesn't provide format info
+            description=model.description,
+            metadata=model.labels,
+            created_at=model.create_time,
+            last_updated_at=model.update_time,
+            stage=stage,
+        )
+
+    def _sanitize_model_display_name(self, name: str) -> str:
+        """Sanitize the model display name to conform to Vertex AI limits."""
+        # Reuse the label sanitizer (lowercases and replaces invalid characters)
+        name = self._sanitize_label(name)
+        if len(name) > MAX_DISPLAY_NAME_LENGTH:
+            logger.warning(
+                f"Model name '{name}' exceeds {MAX_DISPLAY_NAME_LENGTH} characters; truncating."
+            )
+            name = name[:MAX_DISPLAY_NAME_LENGTH]
+        return name
diff --git a/src/zenml/integrations/gcp/services/__init__.py b/src/zenml/integrations/gcp/services/__init__.py
new file mode 100644
index 00000000000..392a48e9694
--- /dev/null
+++ b/src/zenml/integrations/gcp/services/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization of the Vertex Service.""" + +from zenml.integrations.gcp.services.vertex_deployment import ( # noqa + VertexDeploymentConfig, + VertexDeploymentService, +) + +__all__ = ["VertexDeploymentConfig", "VertexDeploymentService"] \ No newline at end of file diff --git a/src/zenml/integrations/gcp/services/vertex_deployment.py b/src/zenml/integrations/gcp/services/vertex_deployment.py new file mode 100644 index 00000000000..cad3f694744 --- /dev/null +++ b/src/zenml/integrations/gcp/services/vertex_deployment.py @@ -0,0 +1,449 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Implementation of the Vertex AI Deployment service.""" + +import re +from datetime import datetime +from typing import Any, Dict, Generator, List, Optional, Tuple, cast + +from google.api_core import retry +from google.cloud import aiplatform +from pydantic import Field, PrivateAttr + +from zenml.client import Client +from zenml.integrations.gcp.flavors.vertex_base_config import ( + VertexAIEndpointConfig, +) +from zenml.logger import get_logger +from zenml.services import ServiceState, ServiceStatus, ServiceType +from zenml.services.service import BaseDeploymentService, ServiceConfig +from zenml.services.service_endpoint import ( + BaseServiceEndpoint, + ServiceEndpointConfig, +) + +logger = get_logger(__name__) + +# Constants +POLLING_TIMEOUT = 1800 # 30 minutes +RETRY_DEADLINE = 600 # 10 minutes +UUID_SLICE_LENGTH: int = 8 + +# Retry configuration for transient errors +retry_config = retry.Retry( + initial=1.0, # Initial delay in seconds + maximum=60.0, # Maximum delay + multiplier=2.0, # Delay multiplier + deadline=RETRY_DEADLINE, + predicate=retry.if_transient_error, +) + + +def sanitize_vertex_label(value: str) -> str: + """Sanitize a label value to comply with Vertex AI requirements. 
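+
+    Vertex AI label values may only contain lowercase letters, digits,
+    dashes and underscores, and are limited to 63 characters; anything
+    else is replaced with a dash. For example, "My Model v2" becomes
+    "my-model-v2".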
+ + Args: + value: The label value to sanitize + + Returns: + Sanitized label value + """ + if not value: + return "" + + # Convert to lowercase + value = value.lower() + # Replace any character that's not lowercase letter, number, dash or underscore + value = re.sub(r"[^a-z0-9\-_]", "-", value) + # Ensure it starts with a letter/number by prepending 'x' if needed + if not value[0].isalnum(): + value = f"x{value}" + # Truncate to 63 chars to stay under limit + return value[:63] + + +class VertexDeploymentConfig(VertexAIEndpointConfig, ServiceConfig): + """Vertex AI service configurations.""" + + def get_vertex_deployment_labels(self) -> Dict[str, str]: + """Generate labels for the VertexAI deployment from the service configuration.""" + labels = self.labels or {} + labels["managed_by"] = "zenml" + if self.pipeline_name: + labels["pipeline-name"] = sanitize_vertex_label(self.pipeline_name) + if self.pipeline_step_name: + labels["step-name"] = sanitize_vertex_label( + self.pipeline_step_name + ) + if self.model_name: + labels["model-name"] = sanitize_vertex_label(self.model_name) + if self.service_name: + labels["service-name"] = sanitize_vertex_label(self.service_name) + if self.display_name: + labels["display-name"] = sanitize_vertex_label( + self.display_name + ) or sanitize_vertex_label(self.name) + return labels + + +class VertexPredictionServiceEndpointConfig(ServiceEndpointConfig): + """Vertex AI Prediction Service Endpoint.""" + + endpoint_name: Optional[str] = None + deployed_model_id: Optional[str] = None + endpoint_url: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + state: Optional[str] = None + + +class VertexServiceStatus(ServiceStatus): + """Vertex AI service status.""" + + +class VertexPredictionServiceEndpoint(BaseServiceEndpoint): + """Vertex AI Prediction Service Endpoint.""" + + config: VertexPredictionServiceEndpointConfig + + +class VertexDeploymentService(BaseDeploymentService): + """Vertex AI model deployment service.""" + + SERVICE_TYPE = ServiceType( + name="vertex-deployment", + type="model-serving", + flavor="vertex", + description="Vertex AI inference endpoint prediction service", + ) + config: VertexDeploymentConfig + status: VertexServiceStatus = Field( + default_factory=lambda: VertexServiceStatus() + ) + _project_id: Optional[str] = PrivateAttr(default=None) + _credentials: Optional[Any] = PrivateAttr(default=None) + + def _initialize_gcp_clients(self) -> None: + """Initialize GCP clients with consistent credentials.""" + from zenml.integrations.gcp.model_deployers.vertex_model_deployer import ( + VertexModelDeployer, + ) + + model_deployer = cast( + VertexModelDeployer, Client().active_stack.model_deployer + ) + + # Get credentials from model deployer + self._credentials, self._project_id = ( + model_deployer._get_authentication() + ) + + def __init__(self, config: VertexDeploymentConfig, **attrs: Any): + """Initialize the Vertex AI deployment service.""" + super().__init__(config=config, **attrs) + self._initialize_gcp_clients() + + @property + def prediction_url(self) -> Optional[str]: + """The prediction URI exposed by the prediction service.""" + endpoints = self.get_endpoints() + if not endpoints: + return None + endpoint = endpoints[0] + return f"https://{self.config.location}-aiplatform.googleapis.com/v1/{endpoint.resource_name}" + + def get_endpoints(self) -> List[aiplatform.Endpoint]: + """Get all endpoints for the current project and location. 
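+
+        Endpoints are matched via the ZenML-managed labels
+        (`managed_by=zenml` plus the sanitized `display-name` label) rather
+        than by resource name, so only endpoints created by this deployer
+        are returned.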
+
+        Returns:
+            List of Vertex AI endpoints
+        """
+        try:
+            display_name = self.config.name or self.config.display_name
+            if not display_name:
+                raise ValueError(
+                    "Either `name` or `display_name` must be set to look up endpoints."
+                )
+            display_name = sanitize_vertex_label(display_name)
+            return list(
+                aiplatform.Endpoint.list(
+                    filter=f"labels.managed_by=zenml AND labels.display-name={display_name}",
+                    project=self._project_id,
+                    location=self.config.location,
+                    credentials=self._credentials,
+                )
+            )
+        except Exception as e:
+            logger.error(f"Failed to list endpoints: {e}")
+            return []
+
+    def _generate_endpoint_name(self) -> str:
+        """Generate a unique name for the Vertex AI Inference Endpoint.
+
+        Returns:
+            Generated endpoint name
+        """
+        # Combine the sanitized display name with a slice of the service UUID
+        # to keep the name descriptive yet unique
+        sanitized_model_name = sanitize_vertex_label(
+            self.config.display_name or self.config.name
+        )
+        return f"{sanitized_model_name}-{str(self.uuid)[:UUID_SLICE_LENGTH]}"
+
+    def _get_model_id(self, name: str) -> str:
+        """Helper to construct a full model ID from a given model name."""
+        return f"projects/{self._project_id}/locations/{self.config.location}/models/{name}"
+
+    def _verify_model_exists(self) -> aiplatform.Model:
+        """Verify that the configured model exists and return it.
+
+        Returns:
+            Vertex AI Model instance
+
+        Raises:
+            RuntimeError: If the model cannot be found
+        """
+        if self.config.model_name.startswith("projects/"):
+            model_name = self.config.model_name
+        else:
+            model_name = self._get_model_id(self.config.model_name)
+        # Remove the version suffix if present
+        if "@" in model_name:
+            model_name = model_name.split("@")[0]
+        logger.info(f"Looking up model: {model_name}")
+        try:
+            model = aiplatform.Model(
+                model_name=model_name,
+                project=self._project_id,
+                location=self.config.location,
+                credentials=self._credentials,
+            )
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to find model '{model_name}' in Vertex AI: {e}"
+            )
+        logger.info(f"Found model to deploy: {model.resource_name}")
+        return model
+
+    def _deploy_model(
+        self, model: aiplatform.Model, endpoint: aiplatform.Endpoint
+    ) -> None:
+        """Deploy a model to a Vertex AI endpoint."""
+        # Prepare the base deployment configuration
+        deploy_kwargs = {
+            "model": model,
+            "deployed_model_display_name": self.config.display_name
+            or self.config.name,
+            "traffic_percentage": 100,
+            "sync": False,
+        }
+        # Add container configuration if specified
+        if self.config.container:
+            deploy_kwargs.update(
+                {
+                    "container_image_uri": self.config.container.image_uri,
+                    "container_ports": self.config.container.ports,
+                    "container_predict_route": self.config.container.predict_route,
+                    "container_health_route": self.config.container.health_route,
+                    "container_env": self.config.container.env,
+                }
+            )
+
+        # Add resource configuration if specified
+        if self.config.resources:
+            deploy_kwargs.update(
+                {
+                    "machine_type": self.config.resources.machine_type,
+                    "min_replica_count": self.config.resources.min_replica_count,
+                    "max_replica_count": self.config.resources.max_replica_count,
+                    "accelerator_type": self.config.resources.accelerator_type,
+                    "accelerator_count": self.config.resources.accelerator_count,
+                }
+            )
+
+        # Add explanation configuration if specified
+        if self.config.explanation:
+            deploy_kwargs.update(
+                {
+                    "explanation_metadata": self.config.explanation.metadata,
+                    "explanation_parameters": self.config.explanation.parameters,
+                }
+            )
+
+        # Add the service account if specified
+        if self.config.service_account:
+            deploy_kwargs["service_account"] = self.config.service_account
+
+        # Add network configuration if specified
+        if self.config.network:
+            deploy_kwargs["network"] = self.config.network
+
+        # Add the encryption key if specified
+        if self.config.encryption_spec_key_name:
+            deploy_kwargs["encryption_spec_key_name"] = (
+                self.config.encryption_spec_key_name
+            )
+
+        # Deploy the model with the fully assembled configuration
+        logger.info(
+            f"Deploying model to endpoint with kwargs: {deploy_kwargs}"
+        )
+        endpoint.deploy(**deploy_kwargs)
+
+    def provision(self) -> None:
+        """Provision or update the remote Vertex AI deployment instance."""
+        # First verify that the model exists (this also logs its resource name)
+        model = self._verify_model_exists()
+        # Get or create the endpoint
+        if self.config.existing_endpoint:
+            endpoint = aiplatform.Endpoint(
+                endpoint_name=self.config.existing_endpoint,
+                location=self.config.location,
+                credentials=self._credentials,
+            )
+            logger.info(f"Using existing endpoint: {endpoint.resource_name}")
+        else:
+            endpoint_name = self._generate_endpoint_name()
+            endpoint = aiplatform.Endpoint.create(
+                display_name=endpoint_name,
+                location=self.config.location,
+                encryption_spec_key_name=self.config.encryption_spec_key_name,
+                labels=self.config.get_vertex_deployment_labels(),
+                credentials=self._credentials,
+            )
+            logger.info(f"Created new endpoint: {endpoint.resource_name}")
+        # Deploy the model, recording any failure in the service status
+        try:
+            self._deploy_model(model, endpoint)
+
+            logger.info(
+                f"Model {model.resource_name} deployed to endpoint {endpoint.resource_name}"
+            )
+        except Exception as e:
+            self.status.update_state(
+                ServiceState.ERROR, f"Deployment failed: {str(e)}"
+            )
+            raise
+
+        self.status.update_state(ServiceState.ACTIVE)
+
+        logger.info(
+            f"Deployment completed successfully. "
+            f"Endpoint: {endpoint.resource_name}"
+        )
+
+    def deprovision(self, force: bool = False) -> None:
+        """Deprovision the Vertex AI deployment.
+
+        Args:
+            force: Whether to ignore errors during deprovisioning
+
+        Raises:
+            RuntimeError: If deprovisioning fails and `force` is False
+        """
+        endpoints = self.get_endpoints()
+        if not endpoints:
+            logger.warning(
+                "No endpoint found for this deployment; nothing to deprovision."
+            )
+            self.status.update_state(ServiceState.INACTIVE)
+            return
+        endpoint = endpoints[0]
+        try:
+            # Undeploy all models from the endpoint
+            endpoint.undeploy_all()
+            # Only delete the endpoint itself if this service created it
+            if not self.config.existing_endpoint:
+                endpoint.delete()
+            logger.info(f"Deprovisioned endpoint: {endpoint.resource_name}")
+            self.status.update_state(ServiceState.INACTIVE)
+        except Exception as e:
+            error_msg = f"Failed to deprovision deployment: {str(e)}"
+            if not force:
+                logger.error(error_msg)
+                self.status.update_state(ServiceState.ERROR, error_msg)
+                raise RuntimeError(error_msg)
+            logger.warning(
+                f"Error during forced deprovision (ignoring): {error_msg}"
+            )
+            self.status.update_state(ServiceState.INACTIVE)
+
+    def get_logs(
+        self,
+        follow: bool = False,
+        tail: Optional[int] = None,
+    ) -> Generator[str, bool, None]:
+        """Retrieve logs for the Vertex AI deployment (not supported).
+
+        Yields:
+            Log entries as strings, but logs are not supported for Vertex AI.
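+            Endpoint logs can instead be inspected via Cloud Logging in the
+            GCP console.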
+ """ + logger.warning("Logs are not supported for Vertex AI") + yield from () + + def check_status(self) -> Tuple[ServiceState, str]: + """Check the status of the deployment by validating if an endpoint exists and if it has deployed models. + + Returns: + A tuple containing the deployment's state and a status message. + """ + try: + endpoints = self.get_endpoints() + if not endpoints: + return ServiceState.INACTIVE, "No endpoint found." + + endpoint = endpoints[0] + try: + endpoint.reload() + except Exception as e: + logger.warning(f"Failed to reload endpoint: {e}") + + deployed_models = [] + if hasattr(endpoint, "list_models"): + try: + deployed_models = endpoint.list_models() + except Exception as e: + logger.warning(f"Failed to list models for endpoint: {e}") + elif hasattr(endpoint, "deployed_models"): + deployed_models = endpoint.deployed_models or [] + + if deployed_models and len(deployed_models) > 0: + return ServiceState.ACTIVE, "" + else: + return ( + ServiceState.PENDING_STARTUP, + "Endpoint deployment is in progress.", + ) + except Exception as e: + return ServiceState.ERROR, f"Deployment check failed: {e}" + + @property + def is_running(self) -> bool: + """Check if the service is running.""" + self.update_status() + return self.status.state == ServiceState.ACTIVE diff --git a/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py b/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py index b11f7fe7080..df8cf57f304 100644 --- a/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py +++ b/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py @@ -13,8 +13,10 @@ # permissions and limitations under the License. """Implementation of the sklearn materializer.""" +import os from typing import Any, ClassVar, Tuple, Type +import cloudpickle from sklearn.base import ( BaseEstimator, BiclusterMixin, @@ -29,13 +31,20 @@ ) from zenml.enums import ArtifactType +from zenml.environment import Environment +from zenml.logger import get_logger from zenml.materializers.cloudpickle_materializer import ( + DEFAULT_FILENAME, CloudpickleMaterializer, ) +logger = get_logger(__name__) + +SKLEARN_MODEL_FILENAME = "model.pkl" + class SklearnMaterializer(CloudpickleMaterializer): - """Materializer to read data to and from sklearn.""" + """Materializer to read data to and from sklearn with backward compatibility.""" ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( BaseEstimator, @@ -50,3 +59,63 @@ class SklearnMaterializer(CloudpickleMaterializer): TransformerMixin, ) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL + + def load(self, data_type: Type[Any]) -> Any: + """Reads a sklearn model from pickle file with backward compatibility. + + Args: + data_type: The data type of the artifact. + + Returns: + The loaded sklearn model. + """ + # First try to load from model.pkl + model_filepath = os.path.join(self.uri, SKLEARN_MODEL_FILENAME) + artifact_filepath = os.path.join(self.uri, DEFAULT_FILENAME) + + # Check which file exists and load accordingly + if self.artifact_store.exists(model_filepath): + filepath = model_filepath + elif self.artifact_store.exists(artifact_filepath): + logger.info( + f"Loading from legacy filepath {artifact_filepath}. 
Future saves " + f"will use {model_filepath}" + ) + filepath = artifact_filepath + else: + raise FileNotFoundError( + f"Neither {model_filepath} nor {artifact_filepath} found in artifact store" + ) + + # validate python version before loading + source_python_version = self._load_python_version() + current_python_version = Environment().python_version() + if ( + source_python_version != "unknown" + and source_python_version != current_python_version + ): + logger.warning( + f"Your artifact was materialized under Python version " + f"'{source_python_version}' but you are currently using " + f"'{current_python_version}'. This might cause unexpected " + "behavior since pickle is not reproducible across Python " + "versions. Attempting to load anyway..." + ) + + # Load the model + with self.artifact_store.open(filepath, "rb") as fid: + return cloudpickle.load(fid) + + def save(self, data: Any) -> None: + """Saves a sklearn model to pickle file using the new filename. + + Args: + data: The sklearn model to save. + """ + # Save python version for validation on loading + self._save_python_version() + + # Save using the new filename + filepath = os.path.join(self.uri, SKLEARN_MODEL_FILENAME) + with self.artifact_store.open(filepath, "wb") as fid: + cloudpickle.dump(data, fid) diff --git a/src/zenml/model_deployers/base_model_deployer.py b/src/zenml/model_deployers/base_model_deployer.py index 40a65128f26..814e4f28175 100644 --- a/src/zenml/model_deployers/base_model_deployer.py +++ b/src/zenml/model_deployers/base_model_deployer.py @@ -32,6 +32,7 @@ from zenml.logger import get_logger from zenml.services import BaseService, ServiceConfig from zenml.services.service import BaseDeploymentService +from zenml.services.service_status import ServiceState from zenml.services.service_type import ServiceType from zenml.stack import StackComponent from zenml.stack.flavor import Flavor @@ -180,6 +181,12 @@ def deploy_model( logger.info( f"Existing model server found for {config.name or config.model_name} with the exact same configuration. Returning the existing service named {services[0].config.service_name}." ) + status, _ = services[0].check_status() + if status != ServiceState.ACTIVE: + logger.info( + f"Service found for {config.name or config.model_name} is not active. Starting the service." + ) + services[0].start(timeout=timeout) return services[0] else: # Find existing model server diff --git a/src/zenml/model_registries/base_model_registry.py b/src/zenml/model_registries/base_model_registry.py index 578d97d396c..b2da8c358e2 100644 --- a/src/zenml/model_registries/base_model_registry.py +++ b/src/zenml/model_registries/base_model_registry.py @@ -70,6 +70,15 @@ class ModelRegistryModelMetadata(BaseModel): zenml_step_name: Optional[str] = None zenml_workspace: Optional[str] = None + @property + def managed_by(self) -> str: + """Returns the managed by attribute. + + Returns: + The managed by attribute. + """ + return "zenml" + @property def custom_attributes(self) -> Dict[str, str]: """Returns a dictionary of custom attributes. 
diff --git a/src/zenml/services/service.py b/src/zenml/services/service.py index 7b607aae611..0077a3a945f 100644 --- a/src/zenml/services/service.py +++ b/src/zenml/services/service.py @@ -35,6 +35,7 @@ from zenml.console import console from zenml.logger import get_logger +from zenml.model.model import Model from zenml.services.service_endpoint import BaseServiceEndpoint from zenml.services.service_monitor import HTTPEndpointHealthMonitor from zenml.services.service_status import ServiceState, ServiceStatus @@ -109,6 +110,7 @@ class ServiceConfig(BaseTypedModel): pipeline_name: name of the pipeline that spun up the service pipeline_step_name: name of the pipeline step that spun up the service run_name: name of the pipeline run that spun up the service. + zenml_model: the ZenML model object to be deployed. """ name: str = "" @@ -118,6 +120,7 @@ class ServiceConfig(BaseTypedModel): model_name: str = "" model_version: str = "" service_name: str = "" + zenml_model: Optional[Model] = None # TODO: In Pydantic v2, the `model_` is a protected namespaces for all # fields defined under base models. If not handled, this raises a warning.