diff --git a/changes/5672.feature.md b/changes/5672.feature.md new file mode 100644 index 00000000000..8d38f74997a --- /dev/null +++ b/changes/5672.feature.md @@ -0,0 +1 @@ +Implement API Layer of Model Deployment diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index 5c88496d32b..86df071be67 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -3009,10 +3009,10 @@ type ModelDeployment implements Node metadata: ModelDeploymentMetadata! networkAccess: ModelDeploymentNetworkAccess! revision: ModelRevision - scalingRule: ScalingRule! - replicaState: ReplicaState! defaultDeploymentStrategy: DeploymentStrategy! createdUser: UserNode! + scalingRule: ScalingRule! + replicaState: ReplicaState! revisionHistory(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection! } @@ -3046,10 +3046,10 @@ type ModelDeploymentMetadata name: String! status: DeploymentStatus! tags: [String!]! - project: GroupNode! - domain: DomainNode! createdAt: DateTime! updatedAt: DateTime! + project: GroupNode! + domain: DomainNode! } """Added in 25.13.0""" @@ -3084,9 +3084,9 @@ input ModelDeploymentNetworkAccessInput type ModelMountConfig @join__type(graph: STRAWBERRY) { - vfolder: VirtualFolderNode! mountDestination: String! definitionPath: String! + vfolder: VirtualFolderNode! } """Added in 25.13.0""" @@ -3105,7 +3105,6 @@ type ModelReplica implements Node { """The Globally Unique ID of this object""" id: ID! - revision: ModelRevision! """ This represents whether the replica has been checked and its health state. @@ -3138,6 +3137,7 @@ type ModelReplica implements Node The session ID associated with the replica. This can be null right after replica creation. """ session: ComputeSessionNode! 
+ revision: ModelRevision! } """Added in 25.13.0""" @@ -3176,8 +3176,8 @@ type ModelRevision implements Node modelRuntimeConfig: ModelRuntimeConfig! modelMountConfig: ModelMountConfig! extraMounts: ExtraVFolderMountConnection! - image: ImageNode! createdAt: DateTime! + image: ImageNode! } """Added in 25.13.0""" @@ -4103,7 +4103,9 @@ type Mutation """Added in 25.13.0""" addModelRevision(input: AddModelRevisionInput!): AddModelRevisionPayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """ + Added in 25.13.0. Create model revision which is not attached to any deployment. + """ createModelRevision(input: CreateModelRevisionInput!): CreateModelRevisionPayload! @join__field(graph: STRAWBERRY) """Added in 25.14.0""" @@ -5066,8 +5068,6 @@ type ReservoirRegistryEdge type ResourceConfig @join__type(graph: STRAWBERRY) { - resourceGroup: ScalingGroupNode! - """ Resource Slots are a JSON string that describes the resources allocated for the deployment. Example: "resourceSlots": "{\"cpu\": \"1\", \"mem\": \"1073741824\", \"cuda.device\": \"0\"}" """ @@ -5077,6 +5077,7 @@ type ResourceConfig Resource Options are a JSON string that describes additional options for the resources. This is especially used for shared memory configurations. Example: "resourceOpts": "{\"shmem\": \"64m\"}" """ resourceOpts: JSONString + resourceGroup: ScalingGroupNode! } """Added in 25.13.0""" diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index 64c9ae6dc1f..2d44f5acab9 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -710,10 +710,10 @@ type ModelDeployment implements Node { metadata: ModelDeploymentMetadata! networkAccess: ModelDeploymentNetworkAccess! revision: ModelRevision - scalingRule: ScalingRule! - replicaState: ReplicaState! defaultDeploymentStrategy: DeploymentStrategy! createdUser: UserNode! + scalingRule: ScalingRule! 
+ replicaState: ReplicaState! revisionHistory(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection! } @@ -741,10 +741,10 @@ type ModelDeploymentMetadata { name: String! status: DeploymentStatus! tags: [String!]! - project: GroupNode! - domain: DomainNode! createdAt: DateTime! updatedAt: DateTime! + project: GroupNode! + domain: DomainNode! } """Added in 25.13.0""" @@ -771,9 +771,9 @@ input ModelDeploymentNetworkAccessInput { """Added in 25.13.0""" type ModelMountConfig { - vfolder: VirtualFolderNode! mountDestination: String! definitionPath: String! + vfolder: VirtualFolderNode! } """Added in 25.13.0""" @@ -787,7 +787,6 @@ input ModelMountConfigInput { type ModelReplica implements Node { """The Globally Unique ID of this object""" id: ID! - revision: ModelRevision! """ This represents whether the replica has been checked and its health state. @@ -820,6 +819,7 @@ type ModelReplica implements Node { The session ID associated with the replica. This can be null right after replica creation. """ session: ComputeSessionNode! + revision: ModelRevision! } """Added in 25.13.0""" @@ -851,8 +851,8 @@ type ModelRevision implements Node { modelRuntimeConfig: ModelRuntimeConfig! modelMountConfig: ModelMountConfig! extraMounts: ExtraVFolderMountConnection! - image: ImageNode! createdAt: DateTime! + image: ImageNode! } """Added in 25.13.0""" @@ -961,7 +961,9 @@ type Mutation { """Added in 25.13.0""" addModelRevision(input: AddModelRevisionInput!): AddModelRevisionPayload! - """Added in 25.13.0""" + """ + Added in 25.13.0. Create model revision which is not attached to any deployment. + """ createModelRevision(input: CreateModelRevisionInput!): CreateModelRevisionPayload! 
"""Added in 25.14.0""" @@ -1272,8 +1274,6 @@ type ReservoirRegistryEdge { """Added in 25.13.0""" type ResourceConfig { - resourceGroup: ScalingGroupNode! - """ Resource Slots are a JSON string that describes the resources allocated for the deployment. Example: "resourceSlots": "{\"cpu\": \"1\", \"mem\": \"1073741824\", \"cuda.device\": \"0\"}" """ @@ -1283,6 +1283,7 @@ type ResourceConfig { Resource Options are a JSON string that describes additional options for the resources. This is especially used for shared memory configurations. Example: "resourceOpts": "{\"shmem\": \"64m\"}" """ resourceOpts: JSONString + resourceGroup: ScalingGroupNode! } """Added in 25.13.0""" diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index cc0a6581c5b..01ab69f2d0a 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -165,6 +165,7 @@ class ErrorDomain(enum.StrEnum): PERMISSION = "permission" METRIC = "metric" STORAGE_PROXY = "storage-proxy" + MODEL_DEPLOYMENT = "model-deployment" class ErrorOperation(enum.StrEnum): @@ -643,3 +644,16 @@ def error_code(cls) -> ErrorCode: operation=ErrorOperation.READ, error_detail=ErrorDetail.NOT_FOUND, ) + + +class ModelDeploymentUnavailableError(BackendAIError, web.HTTPServiceUnavailable): + error_type = "https://api.backend.ai/probs/model-deployment-unavailable" + error_title = "Model Deployment Unavailable" + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.MODEL_DEPLOYMENT, + operation=ErrorOperation.EXECUTE, + error_detail=ErrorDetail.UNAVAILABLE, + ) diff --git a/src/ai/backend/manager/api/gql/base.py b/src/ai/backend/manager/api/gql/base.py index 6e58e063d36..a2d0778546a 100644 --- a/src/ai/backend/manager/api/gql/base.py +++ b/src/ai/backend/manager/api/gql/base.py @@ -5,6 +5,7 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any, Optional, Type, cast +import graphene import orjson import strawberry from graphql 
import StringValueNode @@ -160,7 +161,17 @@ def from_resource_slot(resource_slot: ResourceSlot) -> JSONString: return JSONString.serialize(resource_slot.to_json()) -def to_global_id(type_: Type[Any], local_id: uuid.UUID | str) -> str: +def to_global_id( + type_: Type[Any], local_id: uuid.UUID | str, is_target_graphene_object: bool = False +) -> str: + if is_target_graphene_object: + # For compatibility with existing Graphene-based global IDs + if not issubclass(type_, graphene.ObjectType): + raise TypeError( + "type_ must be a graphene ObjectType when is_target_graphene_object is True." + ) + typename = type_.__name__ + return base64(f"{typename}:{local_id}") if not has_object_definition(type_): raise TypeError("type_ must be a Strawberry object type (Node or Edge).") typename = get_object_definition(type_, strict=True).name diff --git a/src/ai/backend/manager/api/gql/model_deployment/access_token.py b/src/ai/backend/manager/api/gql/model_deployment/access_token.py index 92429f17b99..80941222843 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/access_token.py +++ b/src/ai/backend/manager/api/gql/model_deployment/access_token.py @@ -1,16 +1,27 @@ -from datetime import datetime, timedelta +from collections.abc import Sequence +from datetime import datetime +from typing import Self from uuid import UUID import strawberry from strawberry import ID, Info from strawberry.relay import Connection, Edge, Node, NodeID +from ai.backend.common.exception import ModelDeploymentUnavailableError from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.data.deployment.access_token import ModelDeploymentAccessTokenCreator +from ai.backend.manager.data.deployment.types import ModelDeploymentAccessTokenData +from ai.backend.manager.services.deployment.actions.access_token.create_access_token import ( + CreateAccessTokenAction, +) +from ai.backend.manager.services.deployment.actions.access_token.get_access_tokens_by_deployment_id import ( + 
GetAccessTokensByDeploymentIdAction, +) @strawberry.type class AccessToken(Node): - id: NodeID + id: NodeID[str] token: str = strawberry.field(description="Added in 25.13.0: The access token.") created_at: datetime = strawberry.field( description="Added in 25.13.0: The creation timestamp of the access token." @@ -19,6 +30,34 @@ class AccessToken(Node): description="Added in 25.13.0: The expiration timestamp of the access token." ) + @classmethod + def from_dataclass(cls, data: ModelDeploymentAccessTokenData) -> Self: + return cls( + id=ID(str(data.id)), + token=data.token, + created_at=data.created_at, + valid_until=data.valid_until, + ) + + @classmethod + async def batch_load_by_deployment_ids( + cls, ctx: StrawberryGQLContext, deployment_ids: Sequence[UUID] + ) -> list[list[ModelDeploymentAccessTokenData]]: + """Batch load access tokens by deployment IDs.""" + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + results = [] + for deployment_id in deployment_ids: + result = await processor.get_access_tokens_by_deployment_id.wait_for_complete( + GetAccessTokensByDeploymentIdAction(deployment_id=deployment_id) + ) + results.append(result.data if result else []) + return results + AccessTokenEdge = Edge[AccessToken] @@ -32,42 +71,6 @@ def __init__(self, *args, count: int, **kwargs): self.count = count -mock_access_token_1 = AccessToken( - id=UUID("13cd8325-9307-49e4-94eb-ded2581363f8"), - token="mock-token-1", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=12), -) - -mock_access_token_2 = AccessToken( - id=UUID("dc1a223a-7437-4e6f-aedf-23417d0486dd"), - token="mock-token-2", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=1), -) - -mock_access_token_3 = AccessToken( - id=UUID("39f8b49e-0ddf-4dfb-92d6-003c771684b7"), - token="mock-token-3", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=100), -) - -mock_access_token_4 = AccessToken( - id=UUID("85a6ed1e-133b-4f58-9c06-f667337c6111"), - token="mock-token-4", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=10), -) - -mock_access_token_5 = AccessToken( - id=UUID("c42f8578-b31d-4203-b858-93f93b4b9549"), - token="mock-token-5", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=3), -) - - @strawberry.input class CreateAccessTokenInput: model_deployment_id: ID = strawberry.field( @@ -77,6 +80,12 @@ class CreateAccessTokenInput: description="Added in 25.13.0: The expiration timestamp of the access token." 
) + def to_creator(self) -> "ModelDeploymentAccessTokenCreator": + return ModelDeploymentAccessTokenCreator( + model_deployment_id=UUID(self.model_deployment_id), + valid_until=self.valid_until, + ) + @strawberry.type class CreateAccessTokenPayload: @@ -87,4 +96,9 @@ class CreateAccessTokenPayload: async def create_access_token( input: CreateAccessTokenInput, info: Info[StrawberryGQLContext] ) -> CreateAccessTokenPayload: - return CreateAccessTokenPayload(access_token=mock_access_token_1) + deployment_processor = info.context.processors.deployment + assert deployment_processor is not None + result = await deployment_processor.create_access_token.wait_for_complete( + action=CreateAccessTokenAction(input.to_creator()) + ) + return CreateAccessTokenPayload(access_token=AccessToken.from_dataclass(result.data)) diff --git a/src/ai/backend/manager/api/gql/model_deployment/auto_scaling_rule.py b/src/ai/backend/manager/api/gql/model_deployment/auto_scaling_rule.py index c133ff32d91..a8c629af7cc 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/auto_scaling_rule.py +++ b/src/ai/backend/manager/api/gql/model_deployment/auto_scaling_rule.py @@ -1,14 +1,28 @@ -from datetime import datetime, timedelta +from datetime import datetime from decimal import Decimal from enum import StrEnum -from typing import Optional +from typing import Optional, Self from uuid import UUID import strawberry from strawberry import ID, Info from strawberry.relay import Node, NodeID +from ai.backend.common.types import AutoScalingMetricSource as CommonAutoScalingMetricSource from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.data.deployment.scale import ModelDeploymentAutoScalingRuleCreator +from ai.backend.manager.data.deployment.scale_modifier import ModelDeploymentAutoScalingRuleModifier +from ai.backend.manager.data.deployment.types import ModelDeploymentAutoScalingRuleData +from 
ai.backend.manager.services.deployment.actions.auto_scaling_rule.create_auto_scaling_rule import ( + CreateAutoScalingRuleAction, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.delete_auto_scaling_rule import ( + DeleteAutoScalingRuleAction, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.update_auto_scaling_rule import ( + UpdateAutoScalingRuleAction, +) +from ai.backend.manager.types import OptionalState @strawberry.enum(description="Added in 25.1.0") @@ -19,7 +33,7 @@ class AutoScalingMetricSource(StrEnum): @strawberry.type class AutoScalingRule(Node): - id: NodeID + id: NodeID[str] metric_source: AutoScalingMetricSource = strawberry.field( description="Added in 25.13.0 (e.g. KERNEL, INFERENCE_FRAMEWORK)" @@ -50,6 +64,22 @@ class AutoScalingRule(Node): created_at: datetime last_triggered_at: datetime + @classmethod + def from_dataclass(cls, data: ModelDeploymentAutoScalingRuleData) -> Self: + return cls( + id=ID(str(data.id)), + metric_source=AutoScalingMetricSource(data.metric_source.name), + metric_name=data.metric_name, + min_threshold=data.min_threshold, + max_threshold=data.max_threshold, + step_size=data.step_size, + time_window=data.time_window, + min_replicas=data.min_replicas, + max_replicas=data.max_replicas, + created_at=data.created_at, + last_triggered_at=data.last_triggered_at, + ) + # Input Types @strawberry.input @@ -64,6 +94,19 @@ class CreateAutoScalingRuleInput: min_replicas: Optional[int] max_replicas: Optional[int] + def to_creator(self) -> ModelDeploymentAutoScalingRuleCreator: + return ModelDeploymentAutoScalingRuleCreator( + model_deployment_id=UUID(self.model_deployment_id), + metric_source=CommonAutoScalingMetricSource(self.metric_source.lower()), + metric_name=self.metric_name, + min_threshold=self.min_threshold, + max_threshold=self.max_threshold, + step_size=self.step_size, + time_window=self.time_window, + min_replicas=self.min_replicas, + max_replicas=self.max_replicas, + ) + 
@strawberry.input class UpdateAutoScalingRuleInput: @@ -77,6 +120,26 @@ class UpdateAutoScalingRuleInput: min_replicas: Optional[int] max_replicas: Optional[int] + def to_action(self) -> UpdateAutoScalingRuleAction: + optional_state_metric_source = OptionalState[CommonAutoScalingMetricSource].nop() + if isinstance(self.metric_source, AutoScalingMetricSource): + optional_state_metric_source = OptionalState[CommonAutoScalingMetricSource].update( + CommonAutoScalingMetricSource(self.metric_source) + ) + return UpdateAutoScalingRuleAction( + auto_scaling_rule_id=UUID(self.id), + modifier=ModelDeploymentAutoScalingRuleModifier( + metric_source=optional_state_metric_source, + metric_name=OptionalState[str].from_graphql(self.metric_name), + min_threshold=OptionalState[Decimal].from_graphql(self.min_threshold), + max_threshold=OptionalState[Decimal].from_graphql(self.max_threshold), + step_size=OptionalState[int].from_graphql(self.step_size), + time_window=OptionalState[int].from_graphql(self.time_window), + min_replicas=OptionalState[int].from_graphql(self.min_replicas), + max_replicas=OptionalState[int].from_graphql(self.max_replicas), + ), + ) + @strawberry.input class DeleteAutoScalingRuleInput: @@ -99,84 +162,31 @@ class DeleteAutoScalingRulePayload: id: ID -mock_scaling_rule_0 = AutoScalingRule( - id=UUID("77117a41-87f3-43b7-ba24-40dd5e978720"), - metric_source=AutoScalingMetricSource.KERNEL, - metric_name="memory_usage", - min_threshold=None, - max_threshold=Decimal("90"), - step_size=1, - time_window=120, - min_replicas=1, - max_replicas=3, - created_at=datetime.now() - timedelta(days=15), - last_triggered_at=datetime.now() - timedelta(hours=6), -) - -mock_scaling_rule_1 = AutoScalingRule( - id=UUID("7ff8c1f5-cf8c-4ea2-911c-24ca0f4c2efb"), - metric_source=AutoScalingMetricSource.KERNEL, - metric_name="cpu_usage", - min_threshold=None, - max_threshold=Decimal("80"), - step_size=1, - time_window=300, - min_replicas=1, - max_replicas=5, - created_at=datetime.now() - 
timedelta(days=10), - last_triggered_at=datetime.now() - timedelta(hours=2), -) - -mock_scaling_rule_2 = AutoScalingRule( - id=UUID("483e2158-e089-482b-8cef-260805649cf1"), - metric_source=AutoScalingMetricSource.INFERENCE_FRAMEWORK, - metric_name="requests_per_second", - min_threshold=None, - max_threshold=Decimal("1000"), - step_size=2, - time_window=600, - min_replicas=2, - max_replicas=10, - created_at=datetime.now() - timedelta(days=5), - last_triggered_at=datetime.now() - timedelta(hours=12), -) - - @strawberry.mutation(description="Added in 25.13.0") async def create_auto_scaling_rule( input: CreateAutoScalingRuleInput, info: Info[StrawberryGQLContext] ) -> CreateAutoScalingRulePayload: - return CreateAutoScalingRulePayload(auto_scaling_rule=mock_scaling_rule_0) + deployment_processor = info.context.processors.deployment + assert deployment_processor is not None + result = await deployment_processor.create_auto_scaling_rule.wait_for_complete( + action=CreateAutoScalingRuleAction(input.to_creator()) + ) + return CreateAutoScalingRulePayload( + auto_scaling_rule=AutoScalingRule.from_dataclass(result.data) + ) @strawberry.mutation(description="Added in 25.13.0") async def update_auto_scaling_rule( input: UpdateAutoScalingRuleInput, info: Info[StrawberryGQLContext] ) -> UpdateAutoScalingRulePayload: + deployment_processor = info.context.processors.deployment + assert deployment_processor is not None + action_result = await deployment_processor.update_auto_scaling_rule.wait_for_complete( + input.to_action() + ) return UpdateAutoScalingRulePayload( - auto_scaling_rule=AutoScalingRule( - id=input.id, - metric_source=input.metric_source - if input.metric_source - else mock_scaling_rule_1.metric_source, - metric_name=input.metric_name if input.metric_name else mock_scaling_rule_1.metric_name, - min_threshold=input.min_threshold - if input.min_threshold - else mock_scaling_rule_1.min_threshold, - max_threshold=input.max_threshold - if input.max_threshold - else 
mock_scaling_rule_1.max_threshold, - step_size=input.step_size if input.step_size else mock_scaling_rule_1.step_size, - time_window=input.time_window if input.time_window else mock_scaling_rule_1.time_window, - min_replicas=input.min_replicas - if input.min_replicas - else mock_scaling_rule_1.min_replicas, - max_replicas=input.max_replicas - if input.max_replicas - else mock_scaling_rule_1.max_replicas, - created_at=datetime.now(), - last_triggered_at=datetime.now(), - ) + auto_scaling_rule=AutoScalingRule.from_dataclass(action_result.data) ) @@ -184,4 +194,9 @@ async def update_auto_scaling_rule( async def delete_auto_scaling_rule( input: DeleteAutoScalingRuleInput, info: Info[StrawberryGQLContext] ) -> DeleteAutoScalingRulePayload: - return DeleteAutoScalingRulePayload(id=input.id) + deployment_processor = info.context.processors.deployment + assert deployment_processor is not None + _ = await deployment_processor.delete_auto_scaling_rule.wait_for_complete( + DeleteAutoScalingRuleAction(auto_scaling_rule_id=UUID(input.id)) + ) + return DeleteAutoScalingRulePayload(id=ID(input.id)) diff --git a/src/ai/backend/manager/api/gql/model_deployment/model_deployment.py b/src/ai/backend/manager/api/gql/model_deployment/model_deployment.py index 34f42693a5b..510eb8e604c 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/model_deployment.py +++ b/src/ai/backend/manager/api/gql/model_deployment/model_deployment.py @@ -1,10 +1,13 @@ -from datetime import datetime, timedelta +from collections.abc import Sequence +from datetime import datetime from enum import StrEnum from typing import AsyncGenerator, Optional from uuid import UUID, uuid4 import strawberry +from aiotools import apartial from strawberry import ID, Info +from strawberry.dataloader import DataLoader from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo from ai.backend.common.data.model_deployment.types import ( @@ -13,34 +16,61 @@ from ai.backend.common.data.model_deployment.types import 
( ModelDeploymentStatus as CommonDeploymentStatus, ) -from ai.backend.manager.api.gql.base import OrderDirection, StringFilter -from ai.backend.manager.api.gql.domain import Domain, mock_domain +from ai.backend.common.exception import ModelDeploymentUnavailableError +from ai.backend.manager.api.gql.base import ( + OrderDirection, + StringFilter, + resolve_global_id, + to_global_id, +) +from ai.backend.manager.api.gql.domain import Domain from ai.backend.manager.api.gql.model_deployment.access_token import ( + AccessToken, AccessTokenConnection, AccessTokenEdge, - mock_access_token_1, - mock_access_token_2, - mock_access_token_3, - mock_access_token_4, - mock_access_token_5, ) from ai.backend.manager.api.gql.model_deployment.auto_scaling_rule import ( AutoScalingRule, - mock_scaling_rule_1, - mock_scaling_rule_2, ) from ai.backend.manager.api.gql.model_deployment.model_replica import ( + ModelReplica, ModelReplicaConnection, - ModelReplicaEdge, ReplicaFilter, ReplicaOrderBy, - mock_model_replica_1, - mock_model_replica_2, - mock_model_replica_3, ) -from ai.backend.manager.api.gql.project import Project, mock_project +from ai.backend.manager.api.gql.project import Project from ai.backend.manager.api.gql.types import StrawberryGQLContext -from ai.backend.manager.api.gql.user import User, mock_user_id +from ai.backend.manager.api.gql.user import User +from ai.backend.manager.data.deployment.creator import NewDeploymentCreator +from ai.backend.manager.data.deployment.modifier import NewDeploymentModifier +from ai.backend.manager.data.deployment.types import ( + DeploymentMetadata, + DeploymentNetworkSpec, + ModelDeploymentData, + ModelDeploymentMetadataInfo, + ReplicaSpec, + ReplicaStateData, +) +from ai.backend.manager.models.gql_models.domain import DomainNode +from ai.backend.manager.models.gql_models.group import GroupNode +from ai.backend.manager.models.gql_models.user import UserNode +from 
ai.backend.manager.services.deployment.actions.auto_scaling_rule.get_auto_scaling_rule_by_deployment_id import ( + GetAutoScalingRulesByDeploymentIdAction, +) +from ai.backend.manager.services.deployment.actions.create_deployment import ( + CreateDeploymentAction, +) +from ai.backend.manager.services.deployment.actions.destroy_deployment import ( + DestroyDeploymentAction, +) +from ai.backend.manager.services.deployment.actions.get_deployment import GetDeploymentAction +from ai.backend.manager.services.deployment.actions.get_replicas_by_deployment_id import ( + GetReplicasByDeploymentIdAction, +) +from ai.backend.manager.services.deployment.actions.list_deployments import ListDeploymentsAction +from ai.backend.manager.services.deployment.actions.sync_replicas import SyncReplicaAction +from ai.backend.manager.services.deployment.actions.update_deployment import UpdateDeploymentAction +from ai.backend.manager.types import OptionalState, PaginationOptions, TriState from .model_revision import ( CreateModelRevisionInput, @@ -49,9 +79,6 @@ ModelRevisionEdge, ModelRevisionFilter, ModelRevisionOrderBy, - mock_model_revision_1, - mock_model_revision_2, - mock_model_revision_3, ) DeploymentStatus = strawberry.enum( @@ -81,8 +108,9 @@ class DeploymentStrategy: @strawberry.type(description="Added in 25.13.0") class ReplicaState: - desired_replica_count: int + _deployment_id: strawberry.Private[UUID] _replica_ids: strawberry.Private[list[UUID]] + desired_replica_count: int @strawberry.field async def replicas( @@ -97,37 +125,141 @@ async def replicas( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelReplicaConnection: - return ModelReplicaConnection( - count=2, - edges=[ - ModelReplicaEdge(node=mock_model_replica_1, cursor="replica-cursor-1"), - ModelReplicaEdge(node=mock_model_replica_2, cursor="replica-cursor-2"), - ], + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment 
feature is unavailable. Please contact support." + ) + + result = await processor.get_replicas_by_deployment_id.wait_for_complete( + GetReplicasByDeploymentIdAction( + deployment_id=self._deployment_id, + ) + ) + + nodes = [ModelReplica.from_dataclass(data) for data in result.data] + + edges = [Edge(node=node, cursor=str(node.id)) for node in nodes] + + page_info = PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, ) + return ModelReplicaConnection(count=len(nodes), edges=edges, page_info=page_info) + @strawberry.type(description="Added in 25.13.0") class ScalingRule: - auto_scaling_rules: list[AutoScalingRule] + _deployment_id: strawberry.Private[UUID] + _scaling_rule_ids: strawberry.Private[list[UUID]] + + @strawberry.field + async def auto_scaling_rules( + parent: strawberry.Parent["ScalingRule"], info: Info[StrawberryGQLContext] + ) -> list[AutoScalingRule]: + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + result = await processor.get_auto_scaling_rules_by_deployment_id.wait_for_complete( + GetAutoScalingRulesByDeploymentIdAction(deployment_id=UUID(str(parent._deployment_id))) + ) + + return [AutoScalingRule.from_dataclass(rule) for rule in result.data] + + @classmethod + def from_dataclass(cls, deployment_id: NodeID, scaling_rule_ids: list[UUID]) -> "ScalingRule": + return cls( + _deployment_id=deployment_id, + _scaling_rule_ids=scaling_rule_ids, + ) @strawberry.type(description="Added in 25.13.0") class ModelDeploymentMetadata: + _project_id: strawberry.Private[UUID] + _domain_name: strawberry.Private[str] name: str status: DeploymentStatus tags: list[str] - project: Project - domain: Domain created_at: datetime updated_at: datetime + @strawberry.field + async def project(self, info: Info[StrawberryGQLContext]) -> Project: + project_global_id = to_global_id( + GroupNode, self._project_id, is_target_graphene_object=True + ) + return Project(id=ID(project_global_id)) + + @strawberry.field + async def domain(self, info: Info[StrawberryGQLContext]) -> Domain: + domain_global_id = to_global_id( + DomainNode, self._domain_name, is_target_graphene_object=True + ) + return Domain(id=ID(domain_global_id)) + + @classmethod + def from_dataclass(cls, data: ModelDeploymentMetadataInfo) -> "ModelDeploymentMetadata": + return cls( + name=data.name, + status=DeploymentStatus(data.status), + tags=data.tags, + _project_id=data.project_id, + _domain_name=data.domain_name, + created_at=data.created_at, + updated_at=data.updated_at, + ) + @strawberry.type(description="Added in 25.13.0") class ModelDeploymentNetworkAccess: + _deployment_id: strawberry.Private[UUID] + _access_token_ids: strawberry.Private[Optional[list[UUID]]] endpoint_url: Optional[str] = None preferred_domain_name: Optional[str] = None open_to_public: bool = False - access_tokens: AccessTokenConnection + + @strawberry.field + async def access_tokens(self, info: Info[StrawberryGQLContext]) -> 
AccessTokenConnection: + """Resolve access tokens using dataloader.""" + access_token_loader = DataLoader( + apartial(AccessToken.batch_load_by_deployment_ids, info.context) + ) + token_nodes: list[AccessToken] = await access_token_loader.load(self._deployment_id) + + edges = [ + AccessTokenEdge(node=token_node, cursor=str(token_node.id)) + for token_node in token_nodes + ] + + return AccessTokenConnection( + count=len(edges), + edges=edges, + page_info=PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ), + ) + + @classmethod + def from_dataclass( + cls, data: DeploymentNetworkSpec, deployment_id: NodeID + ) -> "ModelDeploymentNetworkAccess": + return cls( + _deployment_id=deployment_id, + _access_token_ids=data.access_token_ids, + endpoint_url=data.url, + preferred_domain_name=data.preferred_domain_name, + open_to_public=data.open_to_public, + ) # Main ModelDeployment Type @@ -137,10 +269,35 @@ class ModelDeployment(Node): metadata: ModelDeploymentMetadata network_access: ModelDeploymentNetworkAccess revision: Optional[ModelRevision] = None - scaling_rule: ScalingRule - replica_state: ReplicaState default_deployment_strategy: DeploymentStrategy - created_user: User + _revision_history_ids: strawberry.Private[list[UUID]] + _replica_state_data: strawberry.Private[ReplicaStateData] + _created_user_id: strawberry.Private[UUID] + _scaling_rule_ids: strawberry.Private[list[UUID]] + + @strawberry.field + async def created_user(self, info: Info[StrawberryGQLContext]) -> User: + user_global_id = to_global_id( + UserNode, self._created_user_id, is_target_graphene_object=True + ) + return User(id=strawberry.ID(user_global_id)) + + @strawberry.field + async def scaling_rule(self, info: Info[StrawberryGQLContext]) -> ScalingRule: + return ScalingRule( + _deployment_id=self.id, + _scaling_rule_ids=self._scaling_rule_ids, + ) + + @strawberry.field + async def 
replica_state(self, info: Info[StrawberryGQLContext]) -> ReplicaState: + _, deployment_id = resolve_global_id(self.id) + + return ReplicaState( + _deployment_id=UUID(deployment_id), + desired_replica_count=self._replica_state_data.desired_replica_count, + _replica_ids=self._replica_state_data.replica_ids, + ) @strawberry.field async def revision_history( @@ -155,12 +312,75 @@ async def revision_history( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelRevisionConnection: + """Resolve revision history with dataloader.""" + replica_loader = DataLoader(apartial(ModelRevision.batch_load_by_ids, info.context)) + revisions: list[ModelRevision] = await replica_loader.load(self._revision_history_ids) + + edges = [ + ModelRevisionEdge(node=revision, cursor=str(revision.id)) for revision in revisions + ] + return ModelRevisionConnection( - count=2, - edges=[ - ModelRevisionEdge(node=mock_model_revision_1, cursor="rev-cursor-1"), - ModelRevisionEdge(node=mock_model_revision_2, cursor="rev-cursor-2"), - ], + count=len(edges), + edges=edges, + page_info=PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ), + ) + + @classmethod + async def batch_load_by_ids( + cls, ctx: StrawberryGQLContext, deployment_ids: Sequence[UUID] + ) -> list["ModelDeployment"]: + """Batch load deployments by their IDs.""" + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + model_deployments = [] + + for deployment_id in deployment_ids: + action_result = await processor.get_deployment.wait_for_complete( + GetDeploymentAction(deployment_id=deployment_id) + ) + model_deployments.append(action_result.data) + + return [cls.from_dataclass(deployment) for deployment in model_deployments if deployment] + + @classmethod + def from_dataclass( + cls, + data: ModelDeploymentData, + ) -> "ModelDeployment": + metadata = ModelDeploymentMetadata( + name=data.metadata.name, + status=DeploymentStatus(data.metadata.status), + tags=data.metadata.tags, + _project_id=data.metadata.project_id, + _domain_name=data.metadata.domain_name, + created_at=data.metadata.created_at, + updated_at=data.metadata.updated_at, + ) + + return cls( + id=ID(str(data.id)), + metadata=metadata, + network_access=ModelDeploymentNetworkAccess.from_dataclass( + data.network_access, ID(str(data.id)) + ), + revision=ModelRevision.from_dataclass(data.revision) if data.revision else None, + default_deployment_strategy=DeploymentStrategy( + type=DeploymentStrategyType(data.default_deployment_strategy) + ), + _created_user_id=uuid4(), + _revision_history_ids=data.revision_history_ids, + _scaling_rule_ids=data.scaling_rule_ids, + _replica_state_data=data.replica_state, ) @@ -226,6 +446,12 @@ class ModelDeploymentNetworkAccessInput: preferred_domain_name: Optional[str] = None open_to_public: bool = False + def to_network_spec(self) -> DeploymentNetworkSpec: + return DeploymentNetworkSpec( + open_to_public=self.open_to_public, + preferred_domain_name=self.preferred_domain_name, + ) + @strawberry.input(description="Added in 25.13.0") class DeploymentStrategyInput: @@ -240,6 +466,27 @@ class CreateModelDeploymentInput: desired_replica_count: int initial_revision: CreateModelRevisionInput + def to_creator(self) -> NewDeploymentCreator: + # TODO: Need to check the name generation logic + name = self.metadata.name or f"deployment-{uuid4().hex[:8]}" + tag = ",".join(self.metadata.tags) 
if self.metadata.tags else None + metadata_for_creator = DeploymentMetadata( + name=name, + domain=self.metadata.domain_name, + project=UUID(str(self.metadata.project_id)), + resource_group=self.initial_revision.resource_config.resource_group.name, + created_user=uuid4(), + session_owner=uuid4(), + created_at=None, + tag=tag, + ) + return NewDeploymentCreator( + metadata=metadata_for_creator, + replica_spec=ReplicaSpec(replica_count=self.desired_replica_count), + network=self.network_access.to_network_spec(), + model_revision=self.initial_revision.to_model_revision_creator(), + ) + @strawberry.input(description="Added in 25.13.0") class UpdateModelDeploymentInput: @@ -252,143 +499,28 @@ class UpdateModelDeploymentInput: name: Optional[str] = None preferred_domain_name: Optional[str] = None + def to_modifier(self) -> NewDeploymentModifier: + strategy_type = None + if self.default_deployment_strategy is not None: + strategy_type = CommonDeploymentStrategy(self.default_deployment_strategy.type) + return NewDeploymentModifier( + open_to_public=OptionalState[bool].from_graphql(self.open_to_public), + tags=OptionalState[list[str]].from_graphql(self.tags), + default_deployment_strategy=OptionalState[CommonDeploymentStrategy].from_graphql( + strategy_type + ), + active_revision_id=OptionalState[UUID].from_graphql(UUID(self.active_revision_id)), + desired_replica_count=OptionalState[int].from_graphql(self.desired_replica_count), + name=OptionalState[str].from_graphql(self.name), + preferred_domain_name=TriState[str].from_graphql(self.preferred_domain_name), + ) + @strawberry.input(description="Added in 25.13.0") class DeleteModelDeploymentInput: id: ID -# TODO: After implementing the actual logic, remove these mock objects -# Mock Model Deployments -mock_model_deployment_id_1 = "8c3105c3-3a02-42e3-aa00-6923cdcd114c" -mock_created_user_id_1 = "9a41b189-72fa-4265-afe8-04172ec5d26b" -mock_model_deployment_1 = ModelDeployment( - id=UUID(mock_model_deployment_id_1), - 
metadata=ModelDeploymentMetadata( - name="Llama 3.8B Instruct", - status=DeploymentStatus.READY, - tags=["production", "llm", "chat", "instruct"], - created_at=datetime.now() - timedelta(days=30), - updated_at=datetime.now() - timedelta(hours=2), - project=mock_project, - domain=mock_domain, - ), - network_access=ModelDeploymentNetworkAccess( - endpoint_url="https://api.backend.ai/models/dep-001", - preferred_domain_name="llama-3-8b.models.backend.ai", - open_to_public=True, - access_tokens=AccessTokenConnection( - count=5, - edges=[ - AccessTokenEdge(node=mock_access_token_1, cursor="token-cursor-1"), - AccessTokenEdge(node=mock_access_token_2, cursor="token-cursor-2"), - AccessTokenEdge(node=mock_access_token_3, cursor="token-cursor-3"), - AccessTokenEdge(node=mock_access_token_4, cursor="token-cursor-4"), - AccessTokenEdge(node=mock_access_token_5, cursor="token-cursor-5"), - ], - page_info=PageInfo( - has_next_page=False, - has_previous_page=False, - start_cursor="token-cursor-1", - end_cursor="token-cursor-5", - ), - ), - ), - revision=mock_model_revision_1, - scaling_rule=ScalingRule(auto_scaling_rules=[mock_scaling_rule_1, mock_scaling_rule_2]), - replica_state=ReplicaState( - desired_replica_count=3, - _replica_ids=[mock_model_replica_1.id, mock_model_replica_2.id, mock_model_replica_3.id], - ), - default_deployment_strategy=DeploymentStrategy(type=DeploymentStrategyType.ROLLING), - created_user=User(id=mock_user_id), -) - -mock_model_deployment_id_2 = "5f839a95-17bd-43b0-a029-a132aa60ae71" -mock_created_user_id_2 = "75994553-fa63-4464-9398-67b6b96c8d11" -mock_model_deployment_2 = ModelDeployment( - id=UUID(mock_model_deployment_id_2), - metadata=ModelDeploymentMetadata( - name="Mistral 7B v0.3", - status=DeploymentStatus.READY, - tags=["staging", "llm", "experimental"], - created_at=datetime.now() - timedelta(days=20), - updated_at=datetime.now() - timedelta(days=1), - project=mock_project, - domain=mock_domain, - ), - 
network_access=ModelDeploymentNetworkAccess( - endpoint_url="https://api.backend.ai/models/dep-002", - preferred_domain_name="mistral-7b.models.backend.ai", - open_to_public=False, - access_tokens=AccessTokenConnection( - count=2, - edges=[ - AccessTokenEdge(node=mock_access_token_1, cursor="token-cursor-1"), - AccessTokenEdge(node=mock_access_token_2, cursor="token-cursor-2"), - ], - page_info=PageInfo( - has_next_page=False, - has_previous_page=False, - start_cursor="token-cursor-1", - end_cursor="token-cursor-5", - ), - ), - ), - revision=mock_model_revision_3, - scaling_rule=ScalingRule(auto_scaling_rules=[]), - replica_state=ReplicaState( - desired_replica_count=1, - _replica_ids=[mock_model_replica_3.id], - ), - default_deployment_strategy=DeploymentStrategy(type=DeploymentStrategyType.BLUE_GREEN), - created_user=User(id=mock_user_id), -) - -mock_model_deployment_id_3 = "d040c413-a5df-4292-a5f4-0e0d85f7a1d4" -mock_created_user_id_3 = "640b0af8-8140-4e58-8ca4-96daba325be8" -mock_model_deployment_3 = ModelDeployment( - id=UUID(mock_model_deployment_id_3), - metadata=ModelDeploymentMetadata( - name="Gemma 2.9B", - status=DeploymentStatus.STOPPED, - project=mock_project, - domain=mock_domain, - tags=["development", "llm", "testing"], - created_at=datetime.now() - timedelta(days=15), - updated_at=datetime.now() - timedelta(days=7), - ), - network_access=ModelDeploymentNetworkAccess( - endpoint_url=None, - preferred_domain_name=None, - open_to_public=False, - access_tokens=AccessTokenConnection( - count=4, - edges=[ - AccessTokenEdge(node=mock_access_token_1, cursor="token-cursor-1"), - AccessTokenEdge(node=mock_access_token_2, cursor="token-cursor-2"), - AccessTokenEdge(node=mock_access_token_3, cursor="token-cursor-3"), - AccessTokenEdge(node=mock_access_token_4, cursor="token-cursor-4"), - ], - page_info=PageInfo( - has_next_page=False, - has_previous_page=False, - start_cursor="token-cursor-1", - end_cursor="token-cursor-4", - ), - ), - ), - revision=None, - 
scaling_rule=ScalingRule(auto_scaling_rules=[]), - replica_state=ReplicaState( - desired_replica_count=0, - _replica_ids=[], - ), - default_deployment_strategy=DeploymentStrategy(type=DeploymentStrategyType.BLUE_GREEN), - created_user=User(id=mock_user_id), -) - - ModelDeploymentEdge = Edge[ModelDeployment] @@ -413,13 +545,26 @@ async def resolve_deployments( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelDeploymentConnection: - return ModelDeploymentConnection( - count=3, - edges=[ - ModelDeploymentEdge(node=mock_model_deployment_1, cursor="deployment-cursor-1"), - ModelDeploymentEdge(node=mock_model_deployment_2, cursor="deployment-cursor-2"), - ModelDeploymentEdge(node=mock_model_deployment_3, cursor="deployment-cursor-3"), - ], + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." + ) + action_result = await processor.list_deployments.wait_for_complete( + ListDeploymentsAction(pagination=PaginationOptions()) + ) + edges = [] + for deployment in action_result.data: + edges.append( + ModelDeploymentEdge( + node=ModelDeployment.from_dataclass(deployment), cursor=str(deployment.id) + ) + ) + + # Mock pagination info for demonstration purposes + connection = ModelDeploymentConnection( + count=action_result.total_count, + edges=edges, page_info=PageInfo( has_next_page=False, has_previous_page=False, @@ -427,6 +572,7 @@ async def resolve_deployments( end_cursor="deployment-cursor-3", ), ) + return connection # Resolvers @@ -458,18 +604,32 @@ async def deployments( @strawberry.field(description="Added in 25.13.0") -async def deployment(id: ID) -> Optional[ModelDeployment]: +async def deployment(id: ID, info: Info[StrawberryGQLContext]) -> Optional[ModelDeployment]: """Get a specific deployment by ID.""" - return mock_model_deployment_1 + _, deployment_id = resolve_global_id(id) + dataloader = 
DataLoader(apartial(ModelDeployment.batch_load_by_ids, info.context)) + deployment: list[ModelDeployment] = await dataloader.load(deployment_id) + + return deployment[0] @strawberry.mutation(description="Added in 25.13.0") async def create_model_deployment( input: CreateModelDeploymentInput, info: Info[StrawberryGQLContext] -) -> CreateModelDeploymentPayload: +) -> "CreateModelDeploymentPayload": """Create a new model deployment.""" - # Create a dummy deployment for placeholder - return CreateModelDeploymentPayload(deployment=mock_model_deployment_1) + + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." + ) + + result = await processor.create_deployment.wait_for_complete( + CreateDeploymentAction(creator=input.to_creator()) + ) + + return CreateModelDeploymentPayload(deployment=ModelDeployment.from_dataclass(result.data)) @strawberry.mutation(description="Added in 25.13.0") @@ -477,8 +637,18 @@ async def update_model_deployment( input: UpdateModelDeploymentInput, info: Info[StrawberryGQLContext] ) -> UpdateModelDeploymentPayload: """Update an existing model deployment.""" - # Create a dummy deployment for placeholder - return UpdateModelDeploymentPayload(deployment=mock_model_deployment_1) + _, deployment_id = resolve_global_id(input.id) + deployment_processor = info.context.processors.deployment + if deployment_processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + action_result = await deployment_processor.update_deployment.wait_for_complete( + UpdateDeploymentAction(deployment_id=UUID(deployment_id), modifier=input.to_modifier()) + ) + return UpdateModelDeploymentPayload( + deployment=ModelDeployment.from_dataclass(action_result.data) + ) @strawberry.mutation(description="Added in 25.13.0") @@ -486,7 +656,16 @@ async def delete_model_deployment( input: DeleteModelDeploymentInput, info: Info[StrawberryGQLContext] ) -> DeleteModelDeploymentPayload: """Delete a model deployment.""" - return DeleteModelDeploymentPayload(id=ID(str(uuid4()))) + _, deployment_id = resolve_global_id(input.id) + deployment_processor = info.context.processors.deployment + if deployment_processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." + ) + _ = await deployment_processor.destroy_deployment.wait_for_complete( + DestroyDeploymentAction(endpoint_id=UUID(deployment_id)) + ) + return DeleteModelDeploymentPayload(id=input.id) @strawberry.subscription(description="Added in 25.13.0") @@ -494,10 +673,10 @@ async def deployment_status_changed( deployment_id: ID, info: Info[StrawberryGQLContext] ) -> AsyncGenerator[DeploymentStatusChangedPayload, None]: """Subscribe to deployment status changes.""" - deployment = [mock_model_deployment_1, mock_model_deployment_2, mock_model_deployment_3] - - for dep in deployment: - yield DeploymentStatusChangedPayload(deployment=dep) + # Mock implementation + # In real implementation, this would yield artifacts when status changes + if False: # Placeholder to make this a generator + yield DeploymentStatusChangedPayload(deployment_id=deployment_id) @strawberry.input(description="Added in 25.13.0") @@ -516,4 +695,13 @@ class SyncReplicaPayload: async def sync_replicas( input: SyncReplicaInput, info: Info[StrawberryGQLContext] ) -> SyncReplicaPayload: + _, deployment_id = resolve_global_id(input.model_deployment_id) + deployment_processor = 
info.context.processors.deployment + if deployment_processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." + ) + await deployment_processor.sync_replicas.wait_for_complete( + SyncReplicaAction(deployment_id=UUID(deployment_id)) + ) return SyncReplicaPayload(success=True) diff --git a/src/ai/backend/manager/api/gql/model_deployment/model_replica.py b/src/ai/backend/manager/api/gql/model_deployment/model_replica.py index f68a6263cee..6fa6261e554 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/model_replica.py +++ b/src/ai/backend/manager/api/gql/model_deployment/model_replica.py @@ -1,23 +1,40 @@ -from datetime import datetime, timedelta +from collections.abc import Sequence +from datetime import datetime from enum import StrEnum -from typing import AsyncGenerator, Optional, cast -from uuid import UUID, uuid4 +from typing import AsyncGenerator, Optional +from uuid import UUID import strawberry +from aiotools import apartial from strawberry import ID, Info +from strawberry.dataloader import DataLoader from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo from ai.backend.common.data.model_deployment.types import ActivenessStatus as CommonActivenessStatus from ai.backend.common.data.model_deployment.types import LivenessStatus as CommonLivenessStatus from ai.backend.common.data.model_deployment.types import ReadinessStatus as CommonReadinessStatus -from ai.backend.manager.api.gql.base import JSONString, OrderDirection +from ai.backend.common.exception import ModelDeploymentUnavailableError +from ai.backend.manager.api.gql.base import ( + JSONString, + OrderDirection, + resolve_global_id, + to_global_id, +) from ai.backend.manager.api.gql.session import Session from ai.backend.manager.api.gql.types import StrawberryGQLContext -from ai.backend.manager.models.gql_relay import AsyncNode +from ai.backend.manager.data.deployment.types import ModelReplicaData +from 
ai.backend.manager.models.gql_models.session import ComputeSessionNode +from ai.backend.manager.services.deployment.actions.get_replicas_by_deployment_id import ( + GetReplicasByDeploymentIdAction, +) +from ai.backend.manager.services.deployment.actions.get_replicas_by_revision_id import ( + GetReplicasByRevisionIdAction, +) +from ai.backend.manager.services.deployment.actions.list_replicas import ListReplicasAction +from ai.backend.manager.types import PaginationOptions from .model_revision import ( ModelRevision, - mock_model_revision_1, ) ReadinessStatus = strawberry.enum( @@ -83,16 +100,8 @@ class ReplicaOrderBy: @strawberry.type(description="Added in 25.13.0") class ModelReplica(Node): id: NodeID - revision: ModelRevision _session_id: strawberry.Private[UUID] - - @strawberry.field( - description="The session ID associated with the replica. This can be null right after replica creation." - ) - async def session(self, info: Info[StrawberryGQLContext]) -> "Session": - session_global_id = AsyncNode.to_global_id("ComputeSessionNode", self._session_id) - return Session(id=ID(session_global_id)) - + _revision_id: strawberry.Private[UUID] readiness_status: ReadinessStatus = strawberry.field( description="This represents whether the replica has been checked and its health state.", ) @@ -111,6 +120,79 @@ async def session(self, info: Info[StrawberryGQLContext]) -> "Session": description='live statistics of the routing node. e.g. "live_stat": "{\\"cpu_util\\": {\\"current\\": \\"7.472\\", \\"capacity\\": \\"1000\\", \\"pct\\": \\"0.75\\", \\"unit_hint\\": \\"percent\\"}}"' ) + @strawberry.field( + description="The session ID associated with the replica. This can be null right after replica creation." 
+ ) + async def session(self, info: Info[StrawberryGQLContext]) -> "Session": + session_global_id = to_global_id( + ComputeSessionNode, self._session_id, is_target_graphene_object=True + ) + return Session(id=ID(session_global_id)) + + @strawberry.field + async def revision(self, info: Info[StrawberryGQLContext]) -> ModelRevision: + """Resolve revision using dataloader.""" + revision_loader = DataLoader(apartial(ModelRevision.batch_load_by_ids, info.context)) + revision: list[ModelRevision] = await revision_loader.load(self._revision_id) + return revision[0] + + @classmethod + def from_dataclass(cls, data: ModelReplicaData) -> "ModelReplica": + return cls( + id=ID(str(data.id)), + _revision_id=data.revision_id, + _session_id=data.session_id, + readiness_status=ReadinessStatus(data.readiness_status), + liveness_status=LivenessStatus(data.liveness_status), + activeness_status=ActivenessStatus(data.activeness_status), + weight=data.weight, + detail=JSONString.serialize(data.detail), + created_at=data.created_at, + live_stat=JSONString.serialize(data.live_stat), + ) + + @classmethod + async def batch_load_by_deployment_ids( + cls, ctx: StrawberryGQLContext, deployment_ids: Sequence[UUID] + ) -> list["ModelReplica"]: + """Batch load replicas by their IDs.""" + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + replicas = [] + + for deployment_id in deployment_ids: + action_result = await processor.get_replicas_by_deployment_id.wait_for_complete( + GetReplicasByDeploymentIdAction(deployment_id=deployment_id) + ) + replicas.extend(action_result.data) + + return [cls.from_dataclass(data) for data in replicas] + + @classmethod + async def batch_load_by_revision_ids( + cls, ctx: StrawberryGQLContext, revision_ids: Sequence[UUID] + ) -> list["ModelReplica"]: + """Batch load replicas by their revision IDs.""" + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." + ) + + replicas = [] + + for revision_id in revision_ids: + action_result = await processor.get_replicas_by_revision_id.wait_for_complete( + GetReplicasByRevisionIdAction(revision_id=revision_id) + ) + replicas.extend(action_result.data) + + return [cls.from_dataclass(data) for data in replicas] + ModelReplicaEdge = Edge[ModelReplica] @@ -123,62 +205,19 @@ def __init__(self, *args, count: int, **kwargs): super().__init__(*args, **kwargs) self.count = count + @classmethod + def from_dataclass(cls, replicas_data: list[ModelReplicaData]) -> "ModelReplicaConnection": + nodes = [ModelReplica.from_dataclass(data) for data in replicas_data] + edges = [ModelReplicaEdge(node=node, cursor=str(node.id)) for node in nodes] -# Mock Model Replicas -mock_model_replica_1 = ModelReplica( - id=UUID("b62f9890-228a-40c9-a614-63387805b9a7"), - revision=mock_model_revision_1, - _session_id=uuid4(), - readiness_status=CommonReadinessStatus.HEALTHY, - liveness_status=CommonLivenessStatus.HEALTHY, - activeness_status=CommonActivenessStatus.ACTIVE, - weight=1, - detail=cast( - JSONString, - '{"type": "creation_success", "message": "Model replica created successfully", "status": "operational"}', - ), - created_at=datetime.now() - timedelta(days=5), - live_stat=cast( - JSONString, - '{"requests": 1523, "latency_ms": 187, 
"tokens_per_second": 42.5}', - ), -) - - -mock_model_replica_2 = ModelReplica( - id=UUID("7562e9d4-a368-4e28-9092-65eb91534bac"), - revision=mock_model_revision_1, - _session_id=uuid4(), - readiness_status=CommonReadinessStatus.HEALTHY, - liveness_status=CommonLivenessStatus.HEALTHY, - activeness_status=CommonActivenessStatus.ACTIVE, - weight=2, - detail=cast( - JSONString, - '{"type": "creation_success", "message": "Model replica created successfully", "status": "operational"}', - ), - created_at=datetime.now() - timedelta(days=5), - live_stat=cast( - JSONString, - '{"requests": 1456, "latency_ms": 195, "tokens_per_second": 41.2}', - ), -) + page_info = PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ) -mock_model_replica_3 = ModelReplica( - id=UUID("2a2388ea-a312-422a-b77e-0e0b61c48145"), - revision=mock_model_revision_1, - _session_id=uuid4(), - readiness_status=CommonReadinessStatus.UNHEALTHY, - liveness_status=CommonLivenessStatus.HEALTHY, - activeness_status=CommonActivenessStatus.INACTIVE, - weight=0, - detail=cast( - JSONString, - '{"type": "creation_failed", "errors": [{"src": "", "name": "InvalidAPIParameters", "repr": ""}]}', - ), - created_at=datetime.now() - timedelta(days=2), - live_stat=cast(JSONString, '{"requests": 0, "latency_ms": 0, "tokens_per_second": 0}'), -) + return cls(count=len(nodes), edges=edges, page_info=page_info) @strawberry.type(description="Added in 25.13.0") @@ -189,19 +228,10 @@ class ReplicaStatusChangedPayload: @strawberry.field(description="Added in 25.13.0") async def replica(id: ID, info: Info[StrawberryGQLContext]) -> Optional[ModelReplica]: """Get a specific replica by ID.""" - - return ModelReplica( - id=id, - revision=mock_model_revision_1, - _session_id=uuid4(), - readiness_status=CommonReadinessStatus.NOT_CHECKED, - liveness_status=CommonLivenessStatus.HEALTHY, - 
activeness_status=CommonActivenessStatus.ACTIVE, - weight=1, - detail=cast(JSONString, "{}"), - created_at=datetime.now() - timedelta(days=2), - live_stat=cast(JSONString, '{"requests": 0, "latency_ms": 0, "tokens_per_second": 0}'), - ) + _, replica_id = resolve_global_id(id) + replica_loader = DataLoader(apartial(ModelReplica.batch_load_by_revision_ids, info.context)) + replicas: list[ModelReplica] = await replica_loader.load(UUID(replica_id)) + return replicas[0] async def resolve_replicas( @@ -215,18 +245,29 @@ async def resolve_replicas( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelReplicaConnection: + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." + ) + action_result = await processor.list_replicas.wait_for_complete( + ListReplicasAction(pagination=PaginationOptions()) + ) + edges = [] + for replica_data in action_result.data: + node = ModelReplica.from_dataclass(replica_data) + edge = ModelReplicaEdge(node=node, cursor=str(node.id)) + edges.append(edge) + + # Mock pagination info for demonstration purposes return ModelReplicaConnection( - count=3, - edges=[ - ModelReplicaEdge(node=mock_model_replica_1, cursor="replica-cursor-1"), - ModelReplicaEdge(node=mock_model_replica_2, cursor="replica-cursor-2"), - ModelReplicaEdge(node=mock_model_replica_3, cursor="replica-cursor-3"), - ], + count=len(edges), + edges=edges, page_info=PageInfo( has_next_page=False, has_previous_page=False, - start_cursor="replica-cursor-1", - end_cursor="replica-cursor-3", + start_cursor=str(edges[0].node.id) if edges else None, + end_cursor=str(edges[-1].node.id) if edges else None, ), ) @@ -261,7 +302,5 @@ async def replica_status_changed( revision_id: ID, ) -> AsyncGenerator[ReplicaStatusChangedPayload, None]: """Subscribe to replica status changes.""" - replicas = [mock_model_replica_1, mock_model_replica_2, 
mock_model_replica_3] - - for replica in replicas: + if False: # Replace with actual subscription logic yield ReplicaStatusChangedPayload(replica=replica) diff --git a/src/ai/backend/manager/api/gql/model_deployment/model_revision.py b/src/ai/backend/manager/api/gql/model_deployment/model_revision.py index 5de9d72feee..faba8330581 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/model_revision.py +++ b/src/ai/backend/manager/api/gql/model_deployment/model_revision.py @@ -1,15 +1,29 @@ +from collections.abc import Mapping, Sequence from datetime import datetime, timedelta from decimal import Decimal from enum import Enum, StrEnum +from pathlib import PurePosixPath from typing import Any, Optional, cast from uuid import UUID, uuid4 import strawberry +from aiotools import apartial from strawberry import ID, Info +from strawberry.dataloader import DataLoader from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo from strawberry.scalars import JSON -from ai.backend.manager.api.gql.base import JSONString, OrderDirection, StringFilter +from ai.backend.common.exception import ModelDeploymentUnavailableError +from ai.backend.common.types import ClusterMode as CommonClusterMode +from ai.backend.common.types import MountPermission as CommonMountPermission +from ai.backend.common.types import RuntimeVariant +from ai.backend.manager.api.gql.base import ( + JSONString, + OrderDirection, + StringFilter, + resolve_global_id, + to_global_id, +) from ai.backend.manager.api.gql.image import ( Image, ) @@ -23,14 +37,47 @@ VFolder, mock_extra_mount_1, mock_extra_mount_2, - mock_vfolder_id, ) -from ai.backend.manager.data.model_deployment.inference_runtime_config import ( +from ai.backend.manager.data.deployment.creator import ModelRevisionCreator, VFolderMountsCreator +from ai.backend.manager.data.deployment.inference_runtime_config import ( MOJORuntimeConfig, NVDIANIMRuntimeConfig, SGLangRuntimeConfig, VLLMRuntimeConfig, ) +from 
ai.backend.manager.data.deployment.types import ( + ClusterConfigData, + ExecutionSpec, + ModelMountConfigData, + ModelRevisionData, + ModelRuntimeConfigData, + MountInfo, + ResourceConfigData, + ResourceSpec, +) +from ai.backend.manager.data.image.types import ImageIdentifier +from ai.backend.manager.models.gql_models.image import ImageNode +from ai.backend.manager.models.gql_models.scaling_group import ScalingGroupNode +from ai.backend.manager.models.gql_models.vfolder import VirtualFolderNode +from ai.backend.manager.services.deployment.actions.model_revision.add_model_revision import ( + AddModelRevisionAction, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_id import ( + GetRevisionByIdAction, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revisions_by_deployment_id import ( + GetRevisionsByDeploymentIdAction, +) +from ai.backend.manager.services.deployment.actions.model_revision.list_revisions import ( + ListRevisionsAction, +) +from ai.backend.manager.types import PaginationOptions + +MountPermission = strawberry.enum( + CommonMountPermission, + name="MountPermission", + description="Added in 25.13.0. This enum represents the permission level for a mounted volume. 
It can be ro, rw, wd", +) @strawberry.enum(description="Added in 25.13.0") @@ -41,10 +88,25 @@ class ClusterMode(StrEnum): @strawberry.type(description="Added in 25.13.0") class ModelMountConfig: - vfolder: VFolder + _vfolder_id: strawberry.Private[UUID] mount_destination: str definition_path: str + @strawberry.field + async def vfolder(self, info: Info[StrawberryGQLContext]) -> VFolder: + vfolder_global_id = to_global_id( + VirtualFolderNode, self._vfolder_id, is_target_graphene_object=True + ) + return VFolder(id=ID(vfolder_global_id)) + + @classmethod + def from_dataclass(cls, data: ModelMountConfigData) -> "ModelMountConfig": + return cls( + _vfolder_id=data.vfolder_id, + mount_destination=data.mount_destination, + definition_path=data.definition_path, + ) + @strawberry.type(description="Added in 25.13.0") class ModelRuntimeConfig: @@ -55,10 +117,18 @@ class ModelRuntimeConfig: default=None, ) + @classmethod + def from_dataclass(cls, data: ModelRuntimeConfigData) -> "ModelRuntimeConfig": + return cls( + runtime_variant=data.runtime_variant, + inference_runtime_config=data.inference_runtime_config, + environ=JSONString.serialize(data.environ) if data.environ else None, + ) + @strawberry.type(description="Added in 25.13.0") class ResourceConfig: - resource_group: ResourceGroup + _resource_group_name: strawberry.Private[str] resource_slots: JSONString = strawberry.field( description='Resource Slots are a JSON string that describes the resources allocated for the deployment. 
Example: "resourceSlots": "{\\"cpu\\": \\"1\\", \\"mem\\": \\"1073741824\\", \\"cuda.device\\": \\"0\\"}"' ) @@ -67,29 +137,108 @@ class ResourceConfig: default=None, ) + @strawberry.field + def resource_group(self) -> "ResourceGroup": + """Resolves the federated ResourceGroup.""" + global_id = to_global_id( + ScalingGroupNode, self._resource_group_name, is_target_graphene_object=True + ) + return ResourceGroup(id=ID(global_id)) + + @classmethod + def from_dataclass(cls, data: ResourceConfigData) -> "ResourceConfig": + return cls( + _resource_group_name=data.resource_group_name, + resource_slots=JSONString.from_resource_slot(data.resource_slot), + resource_opts=JSONString.serialize(data.resource_opts), + ) + @strawberry.type(description="Added in 25.13.0") class ClusterConfig: mode: ClusterMode size: int + @classmethod + def from_dataclass(cls, data: ClusterConfigData) -> "ClusterConfig": + return cls( + mode=ClusterMode(data.mode.name), + size=data.size, + ) + @strawberry.type(description="Added in 25.13.0") class ModelRevision(Node): + _image_id: strawberry.Private[UUID] id: NodeID name: str - cluster_config: ClusterConfig resource_config: ResourceConfig - model_runtime_config: ModelRuntimeConfig model_mount_config: ModelMountConfig extra_mounts: ExtraVFolderMountConnection - - image: Image - created_at: datetime + @strawberry.field + async def image(self, info: Info[StrawberryGQLContext]) -> Image: + image_global_id = to_global_id(ImageNode, self._image_id, is_target_graphene_object=True) + return Image(id=ID(image_global_id)) + + @classmethod + def from_dataclass(cls, data: ModelRevisionData) -> "ModelRevision": + return cls( + id=ID(str(data.id)), + name=data.name, + cluster_config=ClusterConfig.from_dataclass(data.cluster_config), + resource_config=ResourceConfig.from_dataclass(data.resource_config), + model_runtime_config=ModelRuntimeConfig.from_dataclass(data.model_runtime_config), + 
model_mount_config=ModelMountConfig.from_dataclass(data.model_mount_config), + extra_mounts=ExtraVFolderMountConnection.from_dataclass(data.extra_vfolder_mounts), + _image_id=data.image_id, + created_at=data.created_at, + ) + + @classmethod + async def batch_load_by_ids( + cls, ctx: StrawberryGQLContext, revision_ids: Sequence[UUID] + ) -> list["ModelRevision"]: + """Batch load revisions by their IDs.""" + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." + ) + + revisions = [] + + for revision_id in revision_ids: + action_result = await processor.get_revision_by_id.wait_for_complete( + GetRevisionByIdAction(revision_id=revision_id) + ) + revisions.append(action_result.data) + + return [cls.from_dataclass(revision) for revision in revisions if revision] + + @classmethod + async def batch_load_by_deployment_ids( + cls, ctx: StrawberryGQLContext, deployment_ids: Sequence[UUID] + ) -> list["ModelRevision"]: + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + revisions = [] + + for deployment_id in deployment_ids: + action_result = await processor.get_revisions_by_deployment_id.wait_for_complete( + GetRevisionsByDeploymentIdAction(deployment_id=deployment_id) + ) + revisions.extend(action_result.data) + + return [cls.from_dataclass(revision) for revision in revisions if revision] + # Filter and Order Types @strawberry.input(description="Added in 25.13.0") @@ -118,12 +267,6 @@ class ModelRevisionOrderBy: # TODO: After implementing the actual logic, remove these mock objects # Mock Model Revisions - - -def _generate_random_name() -> str: - return f"revision-{uuid4()}" - - mock_inference_runtime_config = { "tp_size": 2, "pp_size": 4, @@ -137,124 +280,6 @@ def _generate_random_name() -> str: "tool_call_parser": "granite", "reasoning_parser": "deepseek_r1", } -mock_image_global_id = ID("SW1hZ2VOb2RlOjQwMWZjYjM4LTkwMWYtNDdjYS05YmJjLWQyMjUzYjk4YTZhMA==") -mock_revision_id_1 = "d19f8f78-f308-45a9-ab7b-1c63346024fd" -mock_model_revision_1 = ModelRevision( - id=UUID(mock_revision_id_1), - name="llama-3-8b-instruct-v1.0", - cluster_config=ClusterConfig(mode=ClusterMode.SINGLE_NODE, size=1), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant="custom", - inference_runtime_config=mock_inference_runtime_config, - environ=cast(JSONString, '{"CUDA_VISIBLE_DEVICES": "0"}'), - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="models/llama-3-8b/config.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=2, - edges=[ - ExtraVFolderMountEdge(node=mock_extra_mount_1, cursor="extra-mount-cursor-1"), - 
ExtraVFolderMountEdge(node=mock_extra_mount_2, cursor="extra-mount-cursor-2"), - ], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now() - timedelta(days=10), -) - -mock_revision_id_2 = "3c81bc63-24c1-4a8f-9ad2-8a19899690c3" -mock_model_revision_2 = ModelRevision( - id=UUID(mock_revision_id_2), - name="llama-3-8b-instruct-v1.1", - cluster_config=ClusterConfig(mode=ClusterMode.SINGLE_NODE, size=1), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant="vllm", - inference_runtime_config=mock_inference_runtime_config, - environ=cast(JSONString, '{"CUDA_VISIBLE_DEVICES": "0,1"}'), - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="models/llama-3-8b/config.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=2, - edges=[ - ExtraVFolderMountEdge(node=mock_extra_mount_1, cursor="extra-mount-cursor-1"), - ExtraVFolderMountEdge(node=mock_extra_mount_2, cursor="extra-mount-cursor-2"), - ], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now() - timedelta(days=5), -) - - -mock_revision_id_3 = "86d1a714-b177-4851-897f-da36f306fe30" -mock_model_revision_3 = ModelRevision( - id=UUID(mock_revision_id_3), - name="mistral-7b-v0.3-initial", - cluster_config=ClusterConfig(mode=ClusterMode.SINGLE_NODE, size=1), - resource_config=ResourceConfig( - 
resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant="vllm", - inference_runtime_config=mock_inference_runtime_config, - environ=cast(JSONString, '{"CUDA_VISIBLE_DEVICES": "2"}'), - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="models/mistral-7b/config.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=0, - edges=[], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now() - timedelta(days=20), -) # Payload Types @@ -331,6 +356,55 @@ class CreateModelRevisionInput: model_mount_config: ModelMountConfigInput extra_mounts: Optional[list[ExtraVFolderMountInput]] + def to_model_revision_creator(self) -> ModelRevisionCreator: + image_identifier = ImageIdentifier( + canonical=self.image.name, + architecture=self.image.architecture, + ) + + resource_spec = ResourceSpec( + cluster_mode=CommonClusterMode(self.cluster_config.mode), + cluster_size=self.cluster_config.size, + resource_slots=cast(Mapping[str, Any], self.resource_config.resource_slots), + resource_opts=cast(Mapping[str, Any] | None, self.resource_config.resource_opts), + ) + + extra_mounts = [] + if self.extra_mounts is not None: + extra_mounts = [ + MountInfo( + vfolder_id=UUID(str(extra_mount.vfolder_id)), + kernel_path=PurePosixPath( + extra_mount.mount_destination + if extra_mount.mount_destination is not None + else "" + ), + ) + for extra_mount in self.extra_mounts + ] + + mounts = VFolderMountsCreator( + model_vfolder_id=UUID(str(self.model_mount_config.vfolder_id)), + 
model_definition_path=self.model_mount_config.definition_path, + model_mount_destination=self.model_mount_config.mount_destination, + extra_mounts=extra_mounts, + ) + + execution_spec = ExecutionSpec( + environ=cast(Optional[dict[str, str]], self.model_runtime_config.environ), + runtime_variant=RuntimeVariant(self.model_runtime_config.runtime_variant), + inference_runtime_config=cast( + Optional[dict[str, Any]], self.model_runtime_config.inference_runtime_config + ), + ) + + return ModelRevisionCreator( + image_identifier=image_identifier, + resource_spec=resource_spec, + mounts=mounts, + execution=execution_spec, + ) + @strawberry.input(description="Added in 25.13.0") class AddModelRevisionInput: @@ -343,6 +417,55 @@ class AddModelRevisionInput: model_mount_config: ModelMountConfigInput extra_mounts: Optional[list[ExtraVFolderMountInput]] + def to_model_revision_creator(self) -> ModelRevisionCreator: + image_identifier = ImageIdentifier( + canonical=self.image.name, + architecture=self.image.architecture, + ) + + resource_spec = ResourceSpec( + cluster_mode=CommonClusterMode(self.cluster_config.mode), + cluster_size=self.cluster_config.size, + resource_slots=cast(Mapping[str, Any], self.resource_config.resource_slots), + resource_opts=cast(Mapping[str, Any] | None, self.resource_config.resource_opts), + ) + + extra_mounts = [] + if self.extra_mounts is not None: + extra_mounts = [ + MountInfo( + vfolder_id=UUID(str(extra_mount.vfolder_id)), + kernel_path=PurePosixPath( + extra_mount.mount_destination + if extra_mount.mount_destination is not None + else "" + ), + ) + for extra_mount in self.extra_mounts + ] + + mounts = VFolderMountsCreator( + model_vfolder_id=UUID(str(self.model_mount_config.vfolder_id)), + model_definition_path=self.model_mount_config.definition_path, + model_mount_destination=self.model_mount_config.mount_destination, + extra_mounts=extra_mounts, + ) + + execution_spec = ExecutionSpec( + environ=cast(Optional[dict[str, str]], 
self.model_runtime_config.environ), + runtime_variant=RuntimeVariant(self.model_runtime_config.runtime_variant), + inference_runtime_config=cast( + Optional[dict[str, Any]], self.model_runtime_config.inference_runtime_config + ), + ) + + return ModelRevisionCreator( + image_identifier=image_identifier, + resource_spec=resource_spec, + mounts=mounts, + execution=execution_spec, + ) + ModelRevisionEdge = Edge[ModelRevision] @@ -355,6 +478,20 @@ def __init__(self, *args, count: int, **kwargs: Any): super().__init__(*args, **kwargs) self.count = count + @classmethod + def from_dataclass(cls, revisions_data: list[ModelRevisionData]) -> "ModelRevisionConnection": + nodes = [ModelRevision.from_dataclass(data) for data in revisions_data] + edges = [ModelRevisionEdge(node=node, cursor=str(node.id)) for node in nodes] + + page_info = PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ) + + return cls(count=len(nodes), edges=edges, page_info=page_info) + @strawberry.field( description="Added in 25.13.0. Get JSON Schema for inference runtime configuration" @@ -400,18 +537,32 @@ async def resolve_revisions( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelRevisionConnection: - # Implement the logic to resolve the revisions based on the provided filters and pagination - return ModelRevisionConnection( - count=3, - edges=[ - ModelRevisionEdge(node=mock_model_revision_1, cursor="revision-cursor-1"), - ModelRevisionEdge(node=mock_model_revision_2, cursor="revision-cursor-2"), - ModelRevisionEdge(node=mock_model_revision_3, cursor="revision-cursor-3"), - ], + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + action_result = await processor.list_revisions.wait_for_complete( + ListRevisionsAction(pagination=PaginationOptions()) + ) + edges = [] + for revision in action_result.data: + edges.append( + ModelRevisionEdge(node=ModelRevision.from_dataclass(revision), cursor=str(revision.id)) + ) + + # Mock pagination info for demonstration purposes + connection = ModelRevisionConnection( + count=action_result.total_count, + edges=edges, page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None + has_next_page=False, + has_previous_page=False, + start_cursor="revision-cursor-1", + end_cursor="revision-cursor-3", ), ) + return connection @strawberry.field(description="Added in 25.13.0") @@ -443,96 +594,75 @@ async def revisions( @strawberry.field(description="Added in 25.13.0") async def revision(id: ID, info: Info[StrawberryGQLContext]) -> ModelRevision: """Get a specific revision by ID.""" - return mock_model_revision_1 + _, revision_id = resolve_global_id(id) + revision_loader = DataLoader(apartial(ModelRevision.batch_load_by_ids, info.context)) + revision: list[ModelRevision] = await revision_loader.load(revision_id) + return revision[0] @strawberry.mutation(description="Added in 25.13.0") +async def add_model_revision( + input: AddModelRevisionInput, info: Info[StrawberryGQLContext] +) -> AddModelRevisionPayload: + """Add a model revision to a deployment.""" + + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailableError( + "Model Deployment feature is unavailable. Please contact support." + ) + + result = await processor.add_model_revision.wait_for_complete( + AddModelRevisionAction(input.to_model_revision_creator()) + ) + + return AddModelRevisionPayload(revision=ModelRevision.from_dataclass(result.revision)) + + +@strawberry.mutation( + description="Added in 25.13.0. Create model revision which is not attached to any deployment." 
+) async def create_model_revision( input: CreateModelRevisionInput, info: Info[StrawberryGQLContext] ) -> CreateModelRevisionPayload: """Create a new model revision.""" - revision = ModelRevision( - id=UUID("4cc91efb-7297-47ec-80c4-6e9c4378ae8b"), - name=_generate_random_name(), - cluster_config=ClusterConfig( - mode=ClusterMode.SINGLE_NODE, - size=1, - ), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant=input.model_runtime_config.runtime_variant, - inference_runtime_config=input.model_runtime_config.inference_runtime_config, - environ=None, - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="model.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=0, - edges=[], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None + return CreateModelRevisionPayload( + revision=ModelRevision( + id=UUID("d19f8f78-f308-45a9-ab7b-1c63346024fd"), + name="llama-3-8b-instruct-v1.0", + cluster_config=ClusterConfig(mode=ClusterMode.SINGLE_NODE, size=1), + resource_config=ResourceConfig( + _resource_group_name="default", + resource_slots=cast( + JSONString, + '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', + ), + resource_opts=cast( + JSONString, + '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', + ), ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now(), - ) - return CreateModelRevisionPayload(revision=revision) - - -@strawberry.mutation(description="Added in 25.13.0") -async def add_model_revision( - input: AddModelRevisionInput, info: 
Info[StrawberryGQLContext] -) -> AddModelRevisionPayload: - """Add a model revision to a deployment.""" - revision = ModelRevision( - id=UUID("dda405f0-6463-45c4-a5ca-3721cc8d730c"), - name=_generate_random_name(), - cluster_config=ClusterConfig( - mode=ClusterMode.SINGLE_NODE, - size=1, - ), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', + model_runtime_config=ModelRuntimeConfig( + runtime_variant="custom", + inference_runtime_config=mock_inference_runtime_config, + environ=cast(JSONString, '{"CUDA_VISIBLE_DEVICES": "0"}'), ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', + model_mount_config=ModelMountConfig( + _vfolder_id=uuid4(), + mount_destination="/models", + definition_path="models/llama-3-8b/config.yaml", ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant=input.model_runtime_config.runtime_variant, - inference_runtime_config=input.model_runtime_config.inference_runtime_config, - environ=None, - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="model.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=0, - edges=[], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None + extra_mounts=ExtraVFolderMountConnection( + count=2, + edges=[ + ExtraVFolderMountEdge(node=mock_extra_mount_1, cursor="extra-mount-cursor-1"), + ExtraVFolderMountEdge(node=mock_extra_mount_2, cursor="extra-mount-cursor-2"), + ], + page_info=PageInfo( + has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None + ), ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now(), + _image_id=uuid4(), + created_at=datetime.now() - timedelta(days=10), + ) ) - return 
AddModelRevisionPayload(revision=revision) diff --git a/src/ai/backend/manager/api/gql/model_deployment/routing.py b/src/ai/backend/manager/api/gql/model_deployment/routing.py new file mode 100644 index 00000000000..5417f905e0d --- /dev/null +++ b/src/ai/backend/manager/api/gql/model_deployment/routing.py @@ -0,0 +1,50 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +import strawberry +from strawberry.relay import Connection, Node, NodeID, PageInfo +from strawberry.relay.types import NodeIterableType +from strawberry.types import Info + +from ai.backend.manager.api.gql.base import JSONString + + +@strawberry.type(description="Added in 25.13.0") +class RoutingNode(Node): + id: NodeID + routing_id: UUID + endpoint_url: str + session_id: UUID + status: str + traffic_ratio: float + created_at: datetime + live_stat: JSONString = strawberry.field( + description='live statistics of the routing node. e.g. "live_stat": "{\\"cpu_util\\": {\\"current\\": \\"7.472\\", \\"capacity\\": \\"1000\\", \\"pct\\": \\"0.75\\", \\"unit_hint\\": \\"percent\\"}}"' + ) + + +@strawberry.type(description="Added in 25.13.0") +class RoutingNodeConnection(Connection[RoutingNode]): + @classmethod + def resolve_connection( + cls, + nodes: NodeIterableType[RoutingNode], + *, + info: Info, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + max_results: Optional[int] = None, + **kwargs, + ) -> "RoutingNodeConnection": + return cls( + edges=[], + page_info=PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=None, + end_cursor=None, + ), + ) diff --git a/src/ai/backend/manager/api/gql/vfolder.py b/src/ai/backend/manager/api/gql/vfolder.py index 56b2411a23b..ebe61b56670 100644 --- a/src/ai/backend/manager/api/gql/vfolder.py +++ b/src/ai/backend/manager/api/gql/vfolder.py @@ -1,9 +1,13 @@ from typing import Any -from uuid import uuid4 +from uuid import UUID, uuid4 
import strawberry from strawberry import ID, Info -from strawberry.relay import Connection, Edge, Node, NodeID +from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo + +from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.data.deployment.types import ExtraVFolderMountData +from ai.backend.manager.models.gql_relay import AsyncNode @strawberry.federation.type(keys=["id"], name="VirtualFolderNode", extend=True) @@ -15,14 +19,25 @@ def resolve_reference(cls, id: ID, info: Info) -> "VFolder": return cls(id=id) -mock_vfolder_id = ID("VmlydHVhbEZvbGRlck5vZGU6YmEzMzE5ZGQtMTFmZC00Yjk4LTkzNGMtNjUxYTQ4YTVmMzM0") - - @strawberry.type class ExtraVFolderMount(Node): - id: NodeID + id: NodeID[str] mount_destination: str - vfolder: VFolder + _vfolder_id: strawberry.Private[UUID] + + @strawberry.field + async def vfolder(self, info: Info[StrawberryGQLContext]) -> VFolder: + vfolder_global_id = AsyncNode.to_global_id("VirtualFolderNode", self._vfolder_id) + return VFolder(id=ID(vfolder_global_id)) + + @classmethod + def from_dataclass(cls, data: ExtraVFolderMountData) -> "ExtraVFolderMount": + return cls( + # TODO: fix id generation logic + id=ID(f"{data.vfolder_id}:{data.mount_destination}"), + mount_destination=data.mount_destination, + _vfolder_id=data.vfolder_id, + ) ExtraVFolderMountEdge = Edge[ExtraVFolderMount] @@ -36,15 +51,30 @@ def __init__(self, *args, count: int, **kwargs: Any): super().__init__(*args, **kwargs) self.count = count + @classmethod + def from_dataclass( + cls, mounts_data: list[ExtraVFolderMountData] + ) -> "ExtraVFolderMountConnection": + nodes = [ExtraVFolderMount.from_dataclass(data) for data in mounts_data] + edges = [Edge(node=node, cursor=str(node.id)) for node in nodes] + page_info = PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ) + + return cls(count=len(nodes), edges=edges, 
page_info=page_info) + mock_extra_mount_1 = ExtraVFolderMount( - id=uuid4(), - vfolder=VFolder(id=mock_vfolder_id), + id=str(uuid4()), + _vfolder_id=uuid4(), mount_destination="/extra_models/model1", ) mock_extra_mount_2 = ExtraVFolderMount( - id=uuid4(), - vfolder=VFolder(id=mock_vfolder_id), + id=str(uuid4()), + _vfolder_id=uuid4(), mount_destination="/extra_models/model2", ) diff --git a/src/ai/backend/manager/api/service.py b/src/ai/backend/manager/api/service.py index e4917b2fe9c..cd1d8789192 100644 --- a/src/ai/backend/manager/api/service.py +++ b/src/ai/backend/manager/api/service.py @@ -48,9 +48,9 @@ ResourceSpec, ) from ai.backend.manager.data.image.types import ImageIdentifier -from ai.backend.manager.services.deployment.actions.create_deployment import ( - CreateDeploymentAction, - CreateDeploymentActionResult, +from ai.backend.manager.services.deployment.actions.create_legacy_deployment import ( + CreateLegacyDeploymentAction, + CreateLegacyDeploymentActionResult, ) from ai.backend.manager.services.deployment.actions.destroy_deployment import ( DestroyDeploymentAction, @@ -735,7 +735,7 @@ async def create(request: web.Request, params: NewServiceRequestModel) -> ServeI and root_ctx.processors.deployment is not None ): # Create deployment using the new deployment controller - deployment_action = CreateDeploymentAction( + deployment_action = CreateLegacyDeploymentAction( creator=DeploymentCreator( metadata=DeploymentMetadata( name=params.service_name, @@ -756,8 +756,8 @@ async def create(request: web.Request, params: NewServiceRequestModel) -> ServeI ), ) ) - deployment_result: CreateDeploymentActionResult = ( - await root_ctx.processors.deployment.create_deployment.wait_for_complete( + deployment_result: CreateLegacyDeploymentActionResult = ( + await root_ctx.processors.deployment.create_legacy_deployment.wait_for_complete( deployment_action ) ) diff --git a/src/ai/backend/manager/data/deployment/access_token.py 
b/src/ai/backend/manager/data/deployment/access_token.py new file mode 100644 index 00000000000..d0545e3ebba --- /dev/null +++ b/src/ai/backend/manager/data/deployment/access_token.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass +from datetime import datetime +from uuid import UUID + + +@dataclass +class ModelDeploymentAccessTokenCreator: + model_deployment_id: UUID + valid_until: datetime diff --git a/src/ai/backend/manager/data/deployment/creator.py b/src/ai/backend/manager/data/deployment/creator.py index ee50a0c8eaa..72d97ec3cf7 100644 --- a/src/ai/backend/manager/data/deployment/creator.py +++ b/src/ai/backend/manager/data/deployment/creator.py @@ -1,15 +1,44 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional from uuid import UUID from ai.backend.manager.data.deployment.types import ( DeploymentMetadata, DeploymentNetworkSpec, + ExecutionSpec, ModelRevisionSpec, + MountInfo, + MountMetadata, ReplicaSpec, + ResourceSpec, ) from ai.backend.manager.data.image.types import ImageIdentifier +@dataclass +class VFolderMountsCreator: + model_vfolder_id: UUID + model_definition_path: Optional[str] = None + model_mount_destination: str = "/models" + extra_mounts: list[MountInfo] = field(default_factory=list) + + +@dataclass +class ModelRevisionCreator: + image_identifier: ImageIdentifier + resource_spec: ResourceSpec + mounts: VFolderMountsCreator + execution: ExecutionSpec + + def to_revision_spec(self, mount_metadata: MountMetadata) -> ModelRevisionSpec: + return ModelRevisionSpec( + image_identifier=self.image_identifier, + resource_spec=self.resource_spec, + mounts=mount_metadata, + execution=self.execution, + ) + + @dataclass class DeploymentCreator: metadata: DeploymentMetadata @@ -37,3 +66,11 @@ def project(self) -> UUID: def name(self) -> str: """Get the deployment name from metadata.""" return self.metadata.name + + +@dataclass +class NewDeploymentCreator: + metadata: DeploymentMetadata + 
replica_spec: ReplicaSpec + network: DeploymentNetworkSpec + model_revision: ModelRevisionCreator diff --git a/src/ai/backend/manager/data/model_deployment/inference_runtime_config.py b/src/ai/backend/manager/data/deployment/inference_runtime_config.py similarity index 100% rename from src/ai/backend/manager/data/model_deployment/inference_runtime_config.py rename to src/ai/backend/manager/data/deployment/inference_runtime_config.py diff --git a/src/ai/backend/manager/data/deployment/modifier.py b/src/ai/backend/manager/data/deployment/modifier.py index 48de852dd6b..a6b4d8ca7ea 100644 --- a/src/ai/backend/manager/data/deployment/modifier.py +++ b/src/ai/backend/manager/data/deployment/modifier.py @@ -2,6 +2,7 @@ from typing import Any, Optional, override from uuid import UUID +from ai.backend.common.data.model_deployment.types import DeploymentStrategy from ai.backend.manager.types import OptionalState, PartialModifier, TriState @@ -89,3 +90,32 @@ def fields_to_update(self) -> dict[str, Any]: if self.model_revision: to_update.update(self.model_revision.fields_to_update()) return to_update + + +@dataclass +class NewDeploymentModifier(PartialModifier): + name: OptionalState[str] = field(default_factory=OptionalState[str].nop) + tags: OptionalState[list[str]] = field(default_factory=OptionalState[list[str]].nop) + desired_replica_count: OptionalState[int] = field(default_factory=OptionalState[int].nop) + open_to_public: OptionalState[bool] = field(default_factory=OptionalState[bool].nop) + preferred_domain_name: TriState[str] = field(default_factory=TriState[str].nop) + default_deployment_strategy: OptionalState[DeploymentStrategy] = field( + default_factory=OptionalState[DeploymentStrategy].nop + ) + active_revision_id: OptionalState[UUID] = field( + default_factory=OptionalState[UUID].nop + ) # TODO: Check if TriState is more appropriate + + @override + def fields_to_update(self) -> dict[str, Any]: + to_update: dict[str, Any] = {} + self.name.update_dict(to_update, 
"name") + tag = self.tags.optional_value() + if tag is not None: + to_update["tags"] = ",".join(tag) + self.desired_replica_count.update_dict(to_update, "desired_replica_count") + self.open_to_public.update_dict(to_update, "open_to_public") + self.preferred_domain_name.update_dict(to_update, "preferred_domain_name") + self.default_deployment_strategy.update_dict(to_update, "default_deployment_strategy") + self.active_revision_id.update_dict(to_update, "current_revision_id") + return to_update diff --git a/src/ai/backend/manager/data/deployment/scale.py b/src/ai/backend/manager/data/deployment/scale.py index b255163f6da..40b082835dc 100644 --- a/src/ai/backend/manager/data/deployment/scale.py +++ b/src/ai/backend/manager/data/deployment/scale.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from datetime import datetime +from decimal import Decimal from typing import Optional from uuid import UUID from ai.backend.common.types import AutoScalingMetricComparator, AutoScalingMetricSource +# Dataclasses for auto scaling rules used in Model Service (legacy) @dataclass class AutoScalingCondition: metric_source: AutoScalingMetricSource @@ -35,3 +37,31 @@ class AutoScalingRule: action: AutoScalingAction created_at: datetime last_triggered_at: Optional[datetime] + + +# Dataclasses for auto scaling rules used in Model Deployment +@dataclass +class ModelDeploymentAutoScalingRuleCreator: + model_deployment_id: UUID + metric_source: AutoScalingMetricSource + metric_name: str + min_threshold: Optional[Decimal] + max_threshold: Optional[Decimal] + step_size: int + time_window: int + min_replicas: Optional[int] + max_replicas: Optional[int] + + +@dataclass +class ModelDeploymentAutoScalingRule: + id: UUID + model_deployment_id: UUID + metric_source: AutoScalingMetricSource + metric_name: str + min_threshold: Optional[Decimal] + max_threshold: Optional[Decimal] + step_size: int + time_window: int + min_replicas: Optional[int] + max_replicas: Optional[int] diff --git 
a/src/ai/backend/manager/data/deployment/scale_modifier.py b/src/ai/backend/manager/data/deployment/scale_modifier.py index b8cf20db48a..feefc3da662 100644 --- a/src/ai/backend/manager/data/deployment/scale_modifier.py +++ b/src/ai/backend/manager/data/deployment/scale_modifier.py @@ -1,43 +1,12 @@ from dataclasses import dataclass, field -from datetime import datetime +from decimal import Decimal from typing import Any, Optional, override -from uuid import UUID from ai.backend.common.types import AutoScalingMetricComparator, AutoScalingMetricSource from ai.backend.manager.types import OptionalState, PartialModifier -@dataclass -class AutoScalingCondition: - metric_source: AutoScalingMetricSource - metric_name: str - threshold: str - comparator: AutoScalingMetricComparator - - -@dataclass -class AutoScalingAction: - step_size: int - cooldown_seconds: int - min_replicas: Optional[int] = None - max_replicas: Optional[int] = None - - -@dataclass -class AutoScalingRuleCreator: - condition: AutoScalingCondition - action: AutoScalingAction - - -@dataclass -class AutoScalingRule: - id: UUID - condition: AutoScalingCondition - action: AutoScalingAction - created_at: datetime - last_triggered_at: Optional[datetime] - - +# Dataclasses for auto scaling rules used in Model Service (legacy) @dataclass class AutoScalingConditionModifier(PartialModifier): metric_source: OptionalState[AutoScalingMetricSource] = field( @@ -91,3 +60,31 @@ def fields_to_update(self) -> dict[str, Any]: to_update.update(self.condition_modifier.fields_to_update()) to_update.update(self.action_modifier.fields_to_update()) return to_update + + +# Dataclasses for auto scaling rules used in Model Deployment +@dataclass +class ModelDeploymentAutoScalingRuleModifier(PartialModifier): + metric_source: OptionalState[AutoScalingMetricSource] = field( + default_factory=OptionalState[AutoScalingMetricSource].nop + ) + metric_name: OptionalState[str] = field(default_factory=OptionalState[str].nop) + min_threshold: 
OptionalState[Decimal] = field(default_factory=OptionalState[Decimal].nop) + max_threshold: OptionalState[Decimal] = field(default_factory=OptionalState[Decimal].nop) + step_size: OptionalState[int] = field(default_factory=OptionalState[int].nop) + time_window: OptionalState[int] = field(default_factory=OptionalState[int].nop) + min_replicas: OptionalState[int] = field(default_factory=OptionalState[int].nop) + max_replicas: OptionalState[int] = field(default_factory=OptionalState[int].nop) + + @override + def fields_to_update(self) -> dict[str, Any]: + to_update: dict[str, Any] = {} + self.metric_source.update_dict(to_update, "metric_source") + self.metric_name.update_dict(to_update, "metric_name") + self.min_threshold.update_dict(to_update, "min_threshold") + self.max_threshold.update_dict(to_update, "max_threshold") + self.step_size.update_dict(to_update, "step_size") + self.time_window.update_dict(to_update, "time_window") + self.min_replicas.update_dict(to_update, "min_replicas") + self.max_replicas.update_dict(to_update, "max_replicas") + return to_update diff --git a/src/ai/backend/manager/data/deployment/types.py b/src/ai/backend/manager/data/deployment/types.py index b1fcac603d6..a4aa742bd3f 100644 --- a/src/ai/backend/manager/data/deployment/types.py +++ b/src/ai/backend/manager/data/deployment/types.py @@ -2,13 +2,29 @@ from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime +from decimal import Decimal from functools import lru_cache +from pathlib import PurePosixPath from typing import Any, Optional -from uuid import UUID +from uuid import UUID, uuid4 import yarl -from ai.backend.common.types import ClusterMode, RuntimeVariant, SessionId, VFolderMount +from ai.backend.common.data.model_deployment.types import ( + ActivenessStatus, + DeploymentStrategy, + LivenessStatus, + ModelDeploymentStatus, + ReadinessStatus, +) +from ai.backend.common.types import ( + AutoScalingMetricSource, + ClusterMode, + 
ResourceSlot, + RuntimeVariant, + SessionId, + VFolderMount, +) from ai.backend.manager.data.deployment.scale import AutoScalingRule from ai.backend.manager.data.image.types import ImageIdentifier @@ -95,6 +111,12 @@ class MountSpec: mount_options: Mapping[UUID, dict[str, Any]] +@dataclass +class MountInfo: + vfolder_id: UUID + kernel_path: PurePosixPath + + @dataclass class MountMetadata: model_vfolder_id: UUID @@ -142,6 +164,7 @@ class ExecutionSpec: environ: Optional[dict[str, str]] = None runtime_variant: RuntimeVariant = RuntimeVariant.CUSTOM callback_url: Optional[yarl.URL] = None + inference_runtime_config: Optional[Mapping[str, Any]] = None @dataclass @@ -155,7 +178,9 @@ class ModelRevisionSpec: @dataclass class DeploymentNetworkSpec: open_to_public: bool + access_token_ids: Optional[list[UUID]] = None url: Optional[str] = None + preferred_domain_name: Optional[str] = None @dataclass @@ -218,3 +243,177 @@ class DeploymentInfoWithAutoScalingRules: deployment_info: DeploymentInfo rules: list[AutoScalingRule] = field(default_factory=list) + + +@dataclass +class ModelDeploymentAutoScalingRuleData: + id: UUID + model_deployment_id: UUID + metric_source: AutoScalingMetricSource + metric_name: str + min_threshold: Optional[Decimal] + max_threshold: Optional[Decimal] + step_size: int + time_window: int + min_replicas: Optional[int] + max_replicas: Optional[int] + created_at: datetime + last_triggered_at: datetime + + +@dataclass +class ModelDeploymentAccessTokenData: + id: UUID + token: str + valid_until: datetime + created_at: datetime + + +@dataclass +class ModelReplicaData: + id: UUID + revision_id: UUID + session_id: UUID + readiness_status: ReadinessStatus + liveness_status: LivenessStatus + activeness_status: ActivenessStatus + weight: int + detail: dict[str, Any] + created_at: datetime + live_stat: dict[str, Any] + + +@dataclass +class ClusterConfigData: + mode: ClusterMode + size: int + + +@dataclass +class ResourceConfigData: + resource_group_name: str + 
resource_slot: ResourceSlot + resource_opts: Mapping[str, Any] = field(default_factory=dict) + + +@dataclass +class ModelRuntimeConfigData: + runtime_variant: RuntimeVariant + inference_runtime_config: Optional[Mapping[str, Any]] = None + environ: Optional[dict[str, Any]] = None + + +@dataclass +class ModelMountConfigData: + vfolder_id: UUID + mount_destination: str + definition_path: str + + +@dataclass +class ExtraVFolderMountData: + vfolder_id: UUID + mount_destination: str + + +@dataclass +class ModelRevisionData: + id: UUID + name: str + cluster_config: ClusterConfigData + resource_config: ResourceConfigData + model_runtime_config: ModelRuntimeConfigData + model_mount_config: ModelMountConfigData + created_at: datetime + image_id: UUID + extra_vfolder_mounts: list[ExtraVFolderMountData] = field(default_factory=list) + + +@dataclass +class ModelDeploymentMetadataInfo: + name: str + status: ModelDeploymentStatus + tags: list[str] + project_id: UUID + domain_name: str + created_at: datetime + updated_at: datetime + + +@dataclass +class ReplicaStateData: + desired_replica_count: int + replica_ids: list[UUID] + + +@dataclass +class ModelDeploymentData: + id: UUID + metadata: ModelDeploymentMetadataInfo + network_access: DeploymentNetworkSpec + revision: Optional[ModelRevisionData] + revision_history_ids: list[UUID] + scaling_rule_ids: list[UUID] + replica_state: ReplicaStateData + default_deployment_strategy: DeploymentStrategy + created_user_id: UUID + access_token_ids: Optional[list[UUID]] = None + + +mock_revision_data_1 = ModelRevisionData( + id=uuid4(), + name="test-revision", + cluster_config=ClusterConfigData( + mode=ClusterMode.SINGLE_NODE, + size=1, + ), + resource_config=ResourceConfigData( + resource_group_name="default", + resource_slot=ResourceSlot.from_json({"cpu": 1, "memory": 1024}), + ), + model_mount_config=ModelMountConfigData( + vfolder_id=uuid4(), + mount_destination="/model", + definition_path="model-definition.yaml", + ), + 
model_runtime_config=ModelRuntimeConfigData( + runtime_variant=RuntimeVariant.VLLM, + inference_runtime_config={"tp_size": 2, "max_length": 1024}, + ), + extra_vfolder_mounts=[ + ExtraVFolderMountData( + vfolder_id=uuid4(), + mount_destination="/var", + ), + ExtraVFolderMountData( + vfolder_id=uuid4(), + mount_destination="/example", + ), + ], + image_id=uuid4(), + created_at=datetime.now(), +) + +mock_revision_data_2 = ModelRevisionData( + id=uuid4(), + name="test-revision-2", + cluster_config=ClusterConfigData( + mode=ClusterMode.MULTI_NODE, + size=1, + ), + resource_config=ResourceConfigData( + resource_group_name="default", + resource_slot=ResourceSlot.from_json({"cpu": 1, "memory": 1024}), + ), + model_mount_config=ModelMountConfigData( + vfolder_id=uuid4(), + mount_destination="/model", + definition_path="model-definition.yaml", + ), + model_runtime_config=ModelRuntimeConfigData( + runtime_variant=RuntimeVariant.NIM, + inference_runtime_config={"tp_size": 2, "max_length": 1024}, + ), + image_id=uuid4(), + created_at=datetime.now(), +) diff --git a/src/ai/backend/manager/data/model_deployment/BUILD b/src/ai/backend/manager/data/model_deployment/BUILD deleted file mode 100644 index 73574424040..00000000000 --- a/src/ai/backend/manager/data/model_deployment/BUILD +++ /dev/null @@ -1 +0,0 @@ -python_sources(name="src") diff --git a/src/ai/backend/manager/data/scaling_group/types.py b/src/ai/backend/manager/data/scaling_group/types.py new file mode 100644 index 00000000000..a99fb3781d1 --- /dev/null +++ b/src/ai/backend/manager/data/scaling_group/types.py @@ -0,0 +1,20 @@ +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime +from typing import Any + + +@dataclass +class ScalingGroupData: + name: str + description: str + is_active: bool + is_public: bool + created_at: datetime + wsproxy_addr: str + wsproxy_api_token: str + driver: str + driver_opts: Mapping[str, Any] + scheduler: str + scheduler_opts: Mapping[str, 
Any] + use_host_network: bool diff --git a/src/ai/backend/manager/dto/context.py b/src/ai/backend/manager/dto/context.py index ae9ea07df10..db90edb2a81 100644 --- a/src/ai/backend/manager/dto/context.py +++ b/src/ai/backend/manager/dto/context.py @@ -18,3 +18,14 @@ class ProcessorsCtx(MiddlewareParam): async def from_request(cls, request: web.Request) -> Self: root_ctx: RootContext = request.app["_root.context"] return cls(processors=root_ctx.processors) + + +class RequestCtx(MiddlewareParam): + request: web.Request + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @override + @classmethod + async def from_request(cls, request: web.Request) -> Self: + return cls(request=request) diff --git a/src/ai/backend/manager/data/model_deployment/__init__.py b/src/ai/backend/manager/services/deployment/actions/access_token/__init__.py similarity index 100% rename from src/ai/backend/manager/data/model_deployment/__init__.py rename to src/ai/backend/manager/services/deployment/actions/access_token/__init__.py diff --git a/src/ai/backend/manager/services/deployment/actions/access_token/base.py b/src/ai/backend/manager/services/deployment/actions/access_token/base.py new file mode 100644 index 00000000000..ab9c7607c69 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/access_token/base.py @@ -0,0 +1,10 @@ +from typing import override + +from ai.backend.manager.actions.action import BaseAction + + +class DeploymentAccessTokenBaseAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> str: + return "deployment_access_token" diff --git a/src/ai/backend/manager/services/deployment/actions/access_token/create_access_token.py b/src/ai/backend/manager/services/deployment/actions/access_token/create_access_token.py new file mode 100644 index 00000000000..9b596a04039 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/access_token/create_access_token.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing 
import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.access_token import ModelDeploymentAccessTokenCreator +from ai.backend.manager.data.deployment.types import ModelDeploymentAccessTokenData +from ai.backend.manager.services.deployment.actions.access_token.base import DeploymentAccessTokenBaseAction + + +@dataclass +class CreateAccessTokenAction(DeploymentAccessTokenBaseAction): + creator: ModelDeploymentAccessTokenCreator + + @override + def entity_id(self) -> Optional[str]: + return str(self.creator.model_deployment_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class CreateAccessTokenActionResult(BaseActionResult): + data: ModelDeploymentAccessTokenData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/access_token/get_access_tokens_by_deployment_id.py b/src/ai/backend/manager/services/deployment/actions/access_token/get_access_tokens_by_deployment_id.py new file mode 100644 index 00000000000..fe401662ee7 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/access_token/get_access_tokens_by_deployment_id.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelDeploymentAccessTokenData +from ai.backend.manager.services.deployment.actions.access_token.base import DeploymentAccessTokenBaseAction + + +@dataclass +class GetAccessTokensByDeploymentIdAction(DeploymentAccessTokenBaseAction): + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "read" + + +@dataclass +class GetAccessTokensByDeploymentIdActionResult(BaseActionResult): + data: list[ModelDeploymentAccessTokenData] + + @override + def 
entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/__init__.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/base.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/base.py new file mode 100644 index 00000000000..33c4781a994 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/base.py @@ -0,0 +1,10 @@ +from typing import override + +from ai.backend.manager.actions.action import BaseAction + + +class AutoScalingRuleBaseAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> str: + return "auto_scaling_rule" diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/create_auto_scaling_rule.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/create_auto_scaling_rule.py new file mode 100644 index 00000000000..0e68936e500 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/create_auto_scaling_rule.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.scale import ModelDeploymentAutoScalingRuleCreator +from ai.backend.manager.data.deployment.types import ModelDeploymentAutoScalingRuleData +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.base import ( + AutoScalingRuleBaseAction, +) + + +@dataclass +class CreateAutoScalingRuleAction(AutoScalingRuleBaseAction): + creator: ModelDeploymentAutoScalingRuleCreator + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class 
CreateAutoScalingRuleActionResult(BaseActionResult): + data: ModelDeploymentAutoScalingRuleData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/delete_auto_scaling_rule.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/delete_auto_scaling_rule.py new file mode 100644 index 00000000000..4c0c5d3f51c --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/delete_auto_scaling_rule.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.base import ( + AutoScalingRuleBaseAction, +) + + +@dataclass +class DeleteAutoScalingRuleAction(AutoScalingRuleBaseAction): + auto_scaling_rule_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return str(self.auto_scaling_rule_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "delete" + + +@dataclass +class DeleteAutoScalingRuleActionResult(BaseActionResult): + success: bool + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/get_auto_scaling_rule_by_deployment_id.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/get_auto_scaling_rule_by_deployment_id.py new file mode 100644 index 00000000000..b60a429ed91 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/get_auto_scaling_rule_by_deployment_id.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelDeploymentAutoScalingRuleData, +) +from 
ai.backend.manager.services.deployment.actions.auto_scaling_rule.base import ( + AutoScalingRuleBaseAction, +) + + +@dataclass +class GetAutoScalingRulesByDeploymentIdAction(AutoScalingRuleBaseAction): + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetAutoScalingRulesByDeploymentIdActionResult(BaseActionResult): + data: list[ModelDeploymentAutoScalingRuleData] + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/update_auto_scaling_rule.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/update_auto_scaling_rule.py new file mode 100644 index 00000000000..534a147d64a --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/update_auto_scaling_rule.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.scale_modifier import ModelDeploymentAutoScalingRuleModifier +from ai.backend.manager.data.deployment.types import ( + ModelDeploymentAutoScalingRuleData, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.base import ( + AutoScalingRuleBaseAction, +) + + +@dataclass +class UpdateAutoScalingRuleAction(AutoScalingRuleBaseAction): + auto_scaling_rule_id: UUID + modifier: ModelDeploymentAutoScalingRuleModifier + + @override + def entity_id(self) -> Optional[str]: + return str(self.auto_scaling_rule_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "update" + + +@dataclass +class UpdateAutoScalingRuleActionResult(BaseActionResult): + data: ModelDeploymentAutoScalingRuleData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git 
a/src/ai/backend/manager/services/deployment/actions/base.py b/src/ai/backend/manager/services/deployment/actions/base.py index 4cbfb6ff613..4b73f7009ee 100644 --- a/src/ai/backend/manager/services/deployment/actions/base.py +++ b/src/ai/backend/manager/services/deployment/actions/base.py @@ -1,13 +1,9 @@ -"""Base action for deployment service.""" - from typing import override from ai.backend.manager.actions.action import BaseAction class DeploymentBaseAction(BaseAction): - """Base action for deployment operations.""" - @override @classmethod def entity_type(cls) -> str: diff --git a/src/ai/backend/manager/services/deployment/actions/create_deployment.py b/src/ai/backend/manager/services/deployment/actions/create_deployment.py index bdfbbf2537f..3c491b471ba 100644 --- a/src/ai/backend/manager/services/deployment/actions/create_deployment.py +++ b/src/ai/backend/manager/services/deployment/actions/create_deployment.py @@ -4,16 +4,16 @@ from typing import Optional, override from ai.backend.manager.actions.action import BaseActionResult -from ai.backend.manager.data.deployment.creator import DeploymentCreator -from ai.backend.manager.data.deployment.types import DeploymentInfo +from ai.backend.manager.data.deployment.creator import NewDeploymentCreator +from ai.backend.manager.data.deployment.types import ModelDeploymentData from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction @dataclass class CreateDeploymentAction(DeploymentBaseAction): - """Action to create a new deployment.""" + """Action to create a new deployment(Model Service).""" - creator: DeploymentCreator + creator: NewDeploymentCreator @override def entity_id(self) -> Optional[str]: @@ -27,7 +27,7 @@ def operation_type(cls) -> str: @dataclass class CreateDeploymentActionResult(BaseActionResult): - data: DeploymentInfo + data: ModelDeploymentData @override def entity_id(self) -> Optional[str]: diff --git 
a/src/ai/backend/manager/services/deployment/actions/create_legacy_deployment.py b/src/ai/backend/manager/services/deployment/actions/create_legacy_deployment.py new file mode 100644 index 00000000000..3632ce8730c --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/create_legacy_deployment.py @@ -0,0 +1,34 @@ +"""Action for creating legacy deployments(Model Service).""" + +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.creator import DeploymentCreator +from ai.backend.manager.data.deployment.types import DeploymentInfo +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class CreateLegacyDeploymentAction(DeploymentBaseAction): + """Action to create a new legacy deployment(Model Service).""" + + creator: DeploymentCreator + + @override + def entity_id(self) -> Optional[str]: + return None # New deployment doesn't have an ID yet + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class CreateLegacyDeploymentActionResult(BaseActionResult): + data: DeploymentInfo + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/get_deployment.py b/src/ai/backend/manager/services/deployment/actions/get_deployment.py new file mode 100644 index 00000000000..9644c59ddaf --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/get_deployment.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelDeploymentData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class 
GetDeploymentAction(DeploymentBaseAction): + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return str(self.deployment_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "get_deployment" + + +@dataclass +class GetDeploymentActionResult(BaseActionResult): + data: ModelDeploymentData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/get_replicas_by_deployment_id.py b/src/ai/backend/manager/services/deployment/actions/get_replicas_by_deployment_id.py new file mode 100644 index 00000000000..8782a4dee9b --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/get_replicas_by_deployment_id.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelReplicaData, +) +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class GetReplicasByDeploymentIdAction(DeploymentBaseAction): + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetReplicasByDeploymentIdActionResult(BaseActionResult): + data: list[ModelReplicaData] + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/get_replicas_by_revision_id.py b/src/ai/backend/manager/services/deployment/actions/get_replicas_by_revision_id.py new file mode 100644 index 00000000000..9a6f182f672 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/get_replicas_by_revision_id.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, override +from uuid 
import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelReplicaData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class GetReplicasByRevisionIdAction(DeploymentBaseAction): + revision_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return str(self.revision_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "get_replicas_by_revision_id" + + +@dataclass +class GetReplicasByRevisionIdActionResult(BaseActionResult): + data: list[ModelReplicaData] + + @override + def entity_id(self) -> Optional[str]: + return None # This is a list operation for replicas diff --git a/src/ai/backend/manager/services/deployment/actions/list_deployments.py b/src/ai/backend/manager/services/deployment/actions/list_deployments.py new file mode 100644 index 00000000000..6c186ca3b15 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/list_deployments.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelDeploymentData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction +from ai.backend.manager.types import PaginationOptions + + +@dataclass +class ListDeploymentsAction(DeploymentBaseAction): + pagination: PaginationOptions + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "list_deployments" + + +@dataclass +class ListDeploymentsActionResult(BaseActionResult): + data: list[ModelDeploymentData] + # Note: Total number of deployments, this is not equals to len(data) + total_count: int + + @override + def entity_id(self) -> Optional[str]: + return None diff --git 
a/src/ai/backend/manager/services/deployment/actions/list_replicas.py b/src/ai/backend/manager/services/deployment/actions/list_replicas.py new file mode 100644 index 00000000000..43a8adc3a3f --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/list_replicas.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelReplicaData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction +from ai.backend.manager.types import PaginationOptions + + +@dataclass +class ListReplicasAction(DeploymentBaseAction): + pagination: PaginationOptions + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "list_replicas" + + +@dataclass +class ListReplicasActionResult(BaseActionResult): + data: list[ModelReplicaData] + # Note: Total number of replicas, this is not equal to len(data) + total_count: int + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/__init__.py b/src/ai/backend/manager/services/deployment/actions/model_revision/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/add_model_revision.py b/src/ai/backend/manager/services/deployment/actions/model_revision/add_model_revision.py new file mode 100644 index 00000000000..9f9d090000e --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/add_model_revision.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.creator import ModelRevisionCreator +from 
ai.backend.manager.data.deployment.types import ModelRevisionData +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class AddModelRevisionAction(ModelRevisionBaseAction): + adder: ModelRevisionCreator + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class AddModelRevisionActionResult(BaseActionResult): + revision: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.revision.id) diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/base.py b/src/ai/backend/manager/services/deployment/actions/model_revision/base.py new file mode 100644 index 00000000000..b60ebea80f5 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/base.py @@ -0,0 +1,10 @@ +from typing import override + +from ai.backend.manager.actions.action import BaseAction + + +class ModelRevisionBaseAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> str: + return "model_revision" diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_deployment_id.py b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_deployment_id.py new file mode 100644 index 00000000000..254de284e94 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_deployment_id.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelRevisionData, +) +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class GetRevisionByDeploymentIdAction(ModelRevisionBaseAction): 
+ deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetRevisionByDeploymentIdActionResult(BaseActionResult): + data: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_id.py b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_id.py new file mode 100644 index 00000000000..9b6b7db51ed --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_id.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelRevisionData, +) +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class GetRevisionByIdAction(ModelRevisionBaseAction): + revision_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return str(self.revision_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetRevisionByIdActionResult(BaseActionResult): + data: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_replica_id.py b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_replica_id.py new file mode 100644 index 00000000000..2045e012b24 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_replica_id.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + 
+from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelRevisionData, +) +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class GetRevisionByReplicaIdAction(ModelRevisionBaseAction): + replica_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetRevisionByReplicaIdActionResult(BaseActionResult): + data: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/get_revisions_by_deployment_id.py b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revisions_by_deployment_id.py new file mode 100644 index 00000000000..00ea0244b69 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revisions_by_deployment_id.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelRevisionData +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class GetRevisionsByDeploymentIdAction(ModelRevisionBaseAction): + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetRevisionsByDeploymentIdActionResult(BaseActionResult): + data: list[ModelRevisionData] + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/list_revisions.py 
b/src/ai/backend/manager/services/deployment/actions/model_revision/list_revisions.py new file mode 100644 index 00000000000..9e191320287 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/list_revisions.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelRevisionData +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) +from ai.backend.manager.types import PaginationOptions + + +@dataclass +class ListRevisionsAction(ModelRevisionBaseAction): + pagination: PaginationOptions + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "list_revisions" + + +@dataclass +class ListRevisionsActionResult(BaseActionResult): + data: list[ModelRevisionData] + total_count: int + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/sync_replicas.py b/src/ai/backend/manager/services/deployment/actions/sync_replicas.py new file mode 100644 index 00000000000..bfd5e20207b --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/sync_replicas.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class SyncReplicaAction(DeploymentBaseAction): + """Action to sync replicas for an existing deployment.""" + + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return str(self.deployment_id) + + @override + @classmethod + def operation_type(cls) -> str: + return 
"sync_replicas" + + +@dataclass +class SyncReplicaActionResult(BaseActionResult): + success: bool + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/update_deployment.py b/src/ai/backend/manager/services/deployment/actions/update_deployment.py new file mode 100644 index 00000000000..58c0b0dbd8d --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/update_deployment.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.modifier import NewDeploymentModifier +from ai.backend.manager.data.deployment.types import ModelDeploymentData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class UpdateDeploymentAction(DeploymentBaseAction): + """Action to update an existing deployment.""" + + deployment_id: UUID + modifier: NewDeploymentModifier + + @override + def entity_id(self) -> Optional[str]: + return str(self.deployment_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "update" + + +@dataclass +class UpdateDeploymentActionResult(BaseActionResult): + data: ModelDeploymentData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/processors.py b/src/ai/backend/manager/services/deployment/processors.py index 6d82cfd1163..22cb49a0ff8 100644 --- a/src/ai/backend/manager/services/deployment/processors.py +++ b/src/ai/backend/manager/services/deployment/processors.py @@ -1,36 +1,291 @@ """Deployment service processors for GraphQL API.""" -from typing import TYPE_CHECKING, override +from typing import Protocol, override from ai.backend.manager.actions.monitors.monitor import ActionMonitor from ai.backend.manager.actions.processor import ActionProcessor from 
ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.services.deployment.actions.access_token.create_access_token import ( + CreateAccessTokenAction, + CreateAccessTokenActionResult, +) +from ai.backend.manager.services.deployment.actions.access_token.get_access_tokens_by_deployment_id import ( + GetAccessTokensByDeploymentIdAction, + GetAccessTokensByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.create_auto_scaling_rule import ( + CreateAutoScalingRuleAction, + CreateAutoScalingRuleActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.delete_auto_scaling_rule import ( + DeleteAutoScalingRuleAction, + DeleteAutoScalingRuleActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.get_auto_scaling_rule_by_deployment_id import ( + GetAutoScalingRulesByDeploymentIdAction, + GetAutoScalingRulesByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.update_auto_scaling_rule import ( + UpdateAutoScalingRuleAction, + UpdateAutoScalingRuleActionResult, +) from ai.backend.manager.services.deployment.actions.create_deployment import ( CreateDeploymentAction, CreateDeploymentActionResult, ) +from ai.backend.manager.services.deployment.actions.create_legacy_deployment import ( + CreateLegacyDeploymentAction, + CreateLegacyDeploymentActionResult, +) from ai.backend.manager.services.deployment.actions.destroy_deployment import ( DestroyDeploymentAction, DestroyDeploymentActionResult, ) +from ai.backend.manager.services.deployment.actions.get_deployment import ( + GetDeploymentAction, + GetDeploymentActionResult, +) +from ai.backend.manager.services.deployment.actions.get_replicas_by_deployment_id import ( + GetReplicasByDeploymentIdAction, + GetReplicasByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.get_replicas_by_revision_id import ( 
+ GetReplicasByRevisionIdAction, + GetReplicasByRevisionIdActionResult, +) +from ai.backend.manager.services.deployment.actions.list_deployments import ( + ListDeploymentsAction, + ListDeploymentsActionResult, +) +from ai.backend.manager.services.deployment.actions.list_replicas import ( + ListReplicasAction, + ListReplicasActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.add_model_revision import ( + AddModelRevisionAction, + AddModelRevisionActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_deployment_id import ( + GetRevisionByDeploymentIdAction, + GetRevisionByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_id import ( + GetRevisionByIdAction, + GetRevisionByIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_replica_id import ( + GetRevisionByReplicaIdAction, + GetRevisionByReplicaIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revisions_by_deployment_id import ( + GetRevisionsByDeploymentIdAction, + GetRevisionsByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.list_revisions import ( + ListRevisionsAction, + ListRevisionsActionResult, +) +from ai.backend.manager.services.deployment.actions.sync_replicas import ( + SyncReplicaAction, + SyncReplicaActionResult, +) +from ai.backend.manager.services.deployment.actions.update_deployment import ( + UpdateDeploymentAction, + UpdateDeploymentActionResult, +) + + +class DeploymentServiceProtocol(Protocol): + async def create_deployment( + self, action: CreateDeploymentAction + ) -> CreateDeploymentActionResult: ... + + async def create_legacy_deployment( + self, action: CreateLegacyDeploymentAction + ) -> CreateLegacyDeploymentActionResult: ... 
+ + async def update_deployment( + self, action: UpdateDeploymentAction + ) -> UpdateDeploymentActionResult: ... + + async def destroy_deployment( + self, action: DestroyDeploymentAction + ) -> DestroyDeploymentActionResult: ... + + async def create_auto_scaling_rule( + self, action: CreateAutoScalingRuleAction + ) -> CreateAutoScalingRuleActionResult: ... + + async def update_auto_scaling_rule( + self, action: UpdateAutoScalingRuleAction + ) -> UpdateAutoScalingRuleActionResult: ... + + async def delete_auto_scaling_rule( + self, action: DeleteAutoScalingRuleAction + ) -> DeleteAutoScalingRuleActionResult: ... + + async def create_access_token( + self, action: CreateAccessTokenAction + ) -> CreateAccessTokenActionResult: ... + + async def get_access_tokens_by_deployment_id( + self, action: GetAccessTokensByDeploymentIdAction + ) -> GetAccessTokensByDeploymentIdActionResult: ... + + async def sync_replicas(self, action: SyncReplicaAction) -> SyncReplicaActionResult: ... + + async def add_model_revision( + self, action: AddModelRevisionAction + ) -> AddModelRevisionActionResult: ... + + async def get_auto_scaling_rules_by_deployment_id( + self, action: GetAutoScalingRulesByDeploymentIdAction + ) -> GetAutoScalingRulesByDeploymentIdActionResult: ... + + async def get_replicas_by_deployment_id( + self, action: GetReplicasByDeploymentIdAction + ) -> GetReplicasByDeploymentIdActionResult: ... + + async def get_revision_by_deployment_id( + self, action: GetRevisionByDeploymentIdAction + ) -> GetRevisionByDeploymentIdActionResult: ... + + async def get_revision_by_replica_id( + self, action: GetRevisionByReplicaIdAction + ) -> GetRevisionByReplicaIdActionResult: ... + + async def get_revision_by_id( + self, action: GetRevisionByIdAction + ) -> GetRevisionByIdActionResult: ... + + async def get_deployment(self, action: GetDeploymentAction) -> GetDeploymentActionResult: ... 
+ + async def get_revisions_by_deployment_id( + self, action: GetRevisionsByDeploymentIdAction + ) -> GetRevisionsByDeploymentIdActionResult: ... + + async def get_replicas_by_revision_id( + self, action: GetReplicasByRevisionIdAction + ) -> GetReplicasByRevisionIdActionResult: ... -if TYPE_CHECKING: - from ai.backend.manager.services.deployment.service import DeploymentService + async def list_replicas(self, action: ListReplicasAction) -> ListReplicasActionResult: ... + async def list_revisions(self, action: ListRevisionsAction) -> ListRevisionsActionResult: ... class DeploymentProcessors(AbstractProcessorPackage): """Processors for deployment operations.""" create_deployment: ActionProcessor[CreateDeploymentAction, CreateDeploymentActionResult] + update_deployment: ActionProcessor[UpdateDeploymentAction, UpdateDeploymentActionResult] destroy_deployment: ActionProcessor[DestroyDeploymentAction, DestroyDeploymentActionResult] + create_legacy_deployment: ActionProcessor[ + CreateLegacyDeploymentAction, CreateLegacyDeploymentActionResult + ] + create_auto_scaling_rule: ActionProcessor[ + CreateAutoScalingRuleAction, CreateAutoScalingRuleActionResult + ] + update_auto_scaling_rule: ActionProcessor[ + UpdateAutoScalingRuleAction, UpdateAutoScalingRuleActionResult + ] + delete_auto_scaling_rule: ActionProcessor[ + DeleteAutoScalingRuleAction, DeleteAutoScalingRuleActionResult + ] + create_access_token: ActionProcessor[CreateAccessTokenAction, CreateAccessTokenActionResult] + get_access_tokens_by_deployment_id: ActionProcessor[ + GetAccessTokensByDeploymentIdAction, GetAccessTokensByDeploymentIdActionResult + ] + sync_replicas: ActionProcessor[SyncReplicaAction, SyncReplicaActionResult] + add_model_revision: ActionProcessor[AddModelRevisionAction, AddModelRevisionActionResult] + get_auto_scaling_rules_by_deployment_id: ActionProcessor[ + GetAutoScalingRulesByDeploymentIdAction, GetAutoScalingRulesByDeploymentIdActionResult + ] + get_revision_by_id: 
ActionProcessor[GetRevisionByIdAction, GetRevisionByIdActionResult] + get_replicas_by_deployment_id: ActionProcessor[ + GetReplicasByDeploymentIdAction, GetReplicasByDeploymentIdActionResult + ] + get_revision_by_deployment_id: ActionProcessor[ + GetRevisionByDeploymentIdAction, GetRevisionByDeploymentIdActionResult + ] + get_revision_by_replica_id: ActionProcessor[ + GetRevisionByReplicaIdAction, GetRevisionByReplicaIdActionResult + ] + get_deployment: ActionProcessor[GetDeploymentAction, GetDeploymentActionResult] + list_deployments: ActionProcessor[ListDeploymentsAction, ListDeploymentsActionResult] + get_revisions_by_deployment_id: ActionProcessor[ + GetRevisionsByDeploymentIdAction, GetRevisionsByDeploymentIdActionResult + ] + get_replicas_by_revision_id: ActionProcessor[ + GetReplicasByRevisionIdAction, GetReplicasByRevisionIdActionResult + ] + list_replicas: ActionProcessor[ListReplicasAction, ListReplicasActionResult] + list_revisions: ActionProcessor[ListRevisionsAction, ListRevisionsActionResult] - def __init__(self, service: "DeploymentService", action_monitors: list[ActionMonitor]) -> None: - self.create_deployment = ActionProcessor(service.create, action_monitors) - self.destroy_deployment = ActionProcessor(service.destroy, action_monitors) + def __init__( + self, service: DeploymentServiceProtocol, action_monitors: list[ActionMonitor] + ) -> None: + self.create_auto_scaling_rule = ActionProcessor( + service.create_auto_scaling_rule, action_monitors + ) + self.update_auto_scaling_rule = ActionProcessor( + service.update_auto_scaling_rule, action_monitors + ) + self.delete_auto_scaling_rule = ActionProcessor( + service.delete_auto_scaling_rule, action_monitors + ) + self.create_deployment = ActionProcessor(service.create_deployment, action_monitors) + self.destroy_deployment = ActionProcessor(service.destroy_deployment, action_monitors) + self.update_deployment = ActionProcessor(service.update_deployment, action_monitors) + self.create_legacy_deployment 
= ActionProcessor( + service.create_legacy_deployment, action_monitors + ) + self.create_access_token = ActionProcessor(service.create_access_token, action_monitors) + self.get_access_tokens_by_deployment_id = ActionProcessor( + service.get_access_tokens_by_deployment_id, action_monitors + ) + self.sync_replicas = ActionProcessor(service.sync_replicas, action_monitors) + self.add_model_revision = ActionProcessor(service.add_model_revision, action_monitors) + self.get_auto_scaling_rules_by_deployment_id = ActionProcessor( + service.get_auto_scaling_rules_by_deployment_id, action_monitors + ) + self.get_replicas_by_deployment_id = ActionProcessor( + service.get_replicas_by_deployment_id, action_monitors + ) + self.get_revision_by_replica_id = ActionProcessor( + service.get_revision_by_replica_id, action_monitors + ) + self.get_revision_by_id = ActionProcessor(service.get_revision_by_id, action_monitors) + self.get_deployment = ActionProcessor(service.get_deployment, action_monitors) + self.get_revisions_by_deployment_id = ActionProcessor( + service.get_revisions_by_deployment_id, action_monitors + ) + self.get_replicas_by_revision_id = ActionProcessor( + service.get_replicas_by_revision_id, action_monitors + ) + self.list_replicas = ActionProcessor(service.list_replicas, action_monitors) + self.list_revisions = ActionProcessor(service.list_revisions, action_monitors) @override def supported_actions(self) -> list[ActionSpec]: return [ CreateDeploymentAction.spec(), DestroyDeploymentAction.spec(), + CreateAutoScalingRuleAction.spec(), + UpdateAutoScalingRuleAction.spec(), + UpdateDeploymentAction.spec(), + DeleteAutoScalingRuleAction.spec(), + CreateAccessTokenAction.spec(), + GetAccessTokensByDeploymentIdAction.spec(), + SyncReplicaAction.spec(), + AddModelRevisionAction.spec(), + GetAutoScalingRulesByDeploymentIdAction.spec(), + GetReplicasByDeploymentIdAction.spec(), + GetRevisionByDeploymentIdAction.spec(), + GetRevisionByReplicaIdAction.spec(), + 
GetRevisionByIdAction.spec(), + GetDeploymentAction.spec(), + GetRevisionsByDeploymentIdAction.spec(), + GetReplicasByRevisionIdAction.spec(), + ListRevisionsAction.spec(), + ListReplicasAction.spec(), + CreateLegacyDeploymentAction.spec(), ] diff --git a/src/ai/backend/manager/services/deployment/service.py b/src/ai/backend/manager/services/deployment/service.py index ee9f2745b35..94c40842815 100644 --- a/src/ai/backend/manager/services/deployment/service.py +++ b/src/ai/backend/manager/services/deployment/service.py @@ -1,16 +1,112 @@ """Deployment service for managing model deployments.""" import logging +from datetime import datetime, timedelta +from decimal import Decimal +from uuid import uuid4 +from ai.backend.common.data.model_deployment.types import ( + DeploymentStrategy, + ModelDeploymentStatus, +) +from ai.backend.common.types import ( + AutoScalingMetricSource, +) from ai.backend.logging.utils import BraceStyleAdapter +from ai.backend.manager.data.deployment.types import ( + DeploymentNetworkSpec, + ModelDeploymentAccessTokenData, + ModelDeploymentAutoScalingRuleData, + ModelDeploymentData, + ModelDeploymentMetadataInfo, + ReplicaStateData, + mock_revision_data_1, + mock_revision_data_2, +) +from ai.backend.manager.services.deployment.actions.access_token.create_access_token import ( + CreateAccessTokenAction, + CreateAccessTokenActionResult, +) +from ai.backend.manager.services.deployment.actions.access_token.get_access_tokens_by_deployment_id import ( + GetAccessTokensByDeploymentIdAction, + GetAccessTokensByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.create_auto_scaling_rule import ( + CreateAutoScalingRuleAction, + CreateAutoScalingRuleActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.delete_auto_scaling_rule import ( + DeleteAutoScalingRuleAction, + DeleteAutoScalingRuleActionResult, +) +from 
ai.backend.manager.services.deployment.actions.auto_scaling_rule.get_auto_scaling_rule_by_deployment_id import ( + GetAutoScalingRulesByDeploymentIdAction, + GetAutoScalingRulesByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.update_auto_scaling_rule import ( + UpdateAutoScalingRuleAction, + UpdateAutoScalingRuleActionResult, +) from ai.backend.manager.services.deployment.actions.create_deployment import ( CreateDeploymentAction, CreateDeploymentActionResult, ) +from ai.backend.manager.services.deployment.actions.create_legacy_deployment import ( + CreateLegacyDeploymentAction, + CreateLegacyDeploymentActionResult, +) from ai.backend.manager.services.deployment.actions.destroy_deployment import ( DestroyDeploymentAction, DestroyDeploymentActionResult, ) +from ai.backend.manager.services.deployment.actions.get_deployment import ( + GetDeploymentAction, + GetDeploymentActionResult, +) +from ai.backend.manager.services.deployment.actions.get_replicas_by_deployment_id import ( + GetReplicasByDeploymentIdAction, + GetReplicasByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.get_replicas_by_revision_id import ( + GetReplicasByRevisionIdAction, + GetReplicasByRevisionIdActionResult, +) +from ai.backend.manager.services.deployment.actions.list_replicas import ( + ListReplicasAction, + ListReplicasActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.add_model_revision import ( + AddModelRevisionAction, + AddModelRevisionActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_deployment_id import ( + GetRevisionByDeploymentIdAction, + GetRevisionByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_id import ( + GetRevisionByIdAction, + GetRevisionByIdActionResult, +) +from 
ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_replica_id import ( + GetRevisionByReplicaIdAction, + GetRevisionByReplicaIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revisions_by_deployment_id import ( + GetRevisionsByDeploymentIdAction, + GetRevisionsByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.list_revisions import ( + ListRevisionsAction, + ListRevisionsActionResult, +) +from ai.backend.manager.services.deployment.actions.sync_replicas import ( + SyncReplicaAction, + SyncReplicaActionResult, +) +from ai.backend.manager.services.deployment.actions.update_deployment import ( + UpdateDeploymentAction, + UpdateDeploymentActionResult, +) from ai.backend.manager.sokovan.deployment import DeploymentController from ai.backend.manager.sokovan.deployment.types import DeploymentLifecycleType @@ -26,23 +122,96 @@ def __init__(self, deployment_controller: DeploymentController) -> None: """Initialize deployment service with controller.""" self._deployment_controller = deployment_controller - async def create(self, action: CreateDeploymentAction) -> CreateDeploymentActionResult: - """Create a new deployment. 
+ async def create_deployment( + self, action: CreateDeploymentAction + ) -> CreateDeploymentActionResult: + return CreateDeploymentActionResult( + data=ModelDeploymentData( + id=uuid4(), + metadata=ModelDeploymentMetadataInfo( + name="test-deployment", + status=ModelDeploymentStatus.READY, + tags=["tag1", "tag2"], + project_id=uuid4(), + domain_name="default", + created_at=datetime.now(), + updated_at=datetime.now(), + ), + network_access=DeploymentNetworkSpec( + open_to_public=True, + url="http://example.com", + preferred_domain_name="example.com", + access_token_ids=[uuid4()], + ), + revision_history_ids=[uuid4(), uuid4()], + revision=mock_revision_data_1, + scaling_rule_ids=[uuid4(), uuid4()], + replica_state=ReplicaStateData( + desired_replica_count=3, + replica_ids=[uuid4(), uuid4(), uuid4()], + ), + default_deployment_strategy=DeploymentStrategy.ROLLING, + created_user_id=uuid4(), + ) + ) + + async def create_legacy_deployment( + self, action: CreateLegacyDeploymentAction + ) -> CreateLegacyDeploymentActionResult: + """Create a new legacy deployment(Model Serving). 
Args: - action: Create deployment action containing the creator specification + action: Create legacy deployment action containing the creator specification Returns: - CreateDeploymentActionResult: Result containing the created deployment info + CreateLegacyDeploymentActionResult: Result containing the created deployment info """ log.info("Creating deployment with name: {}", action.creator.name) deployment_info = await self._deployment_controller.create_deployment(action.creator) await self._deployment_controller.mark_lifecycle_needed( DeploymentLifecycleType.CHECK_PENDING ) - return CreateDeploymentActionResult(data=deployment_info) + return CreateLegacyDeploymentActionResult(data=deployment_info) + + async def update_deployment( + self, action: UpdateDeploymentAction + ) -> UpdateDeploymentActionResult: + await self._deployment_controller.mark_lifecycle_needed( + DeploymentLifecycleType.CHECK_REPLICA + ) + return UpdateDeploymentActionResult( + data=ModelDeploymentData( + id=action.deployment_id, + metadata=ModelDeploymentMetadataInfo( + name="test-deployment", + status=ModelDeploymentStatus.READY, + tags=["tag1", "tag2"], + project_id=uuid4(), + domain_name="default", + created_at=datetime.now(), + updated_at=datetime.now(), + ), + network_access=DeploymentNetworkSpec( + open_to_public=True, + url="http://example.com", + preferred_domain_name="example.com", + access_token_ids=[uuid4()], + ), + revision_history_ids=[uuid4(), uuid4()], + revision=mock_revision_data_1, + scaling_rule_ids=[uuid4(), uuid4()], + replica_state=ReplicaStateData( + desired_replica_count=3, + replica_ids=[uuid4(), uuid4(), uuid4()], + ), + default_deployment_strategy=DeploymentStrategy.ROLLING, + created_user_id=uuid4(), + ) + ) - async def destroy(self, action: DestroyDeploymentAction) -> DestroyDeploymentActionResult: + async def destroy_deployment( + self, action: DestroyDeploymentAction + ) -> DestroyDeploymentActionResult: """Destroy an existing deployment. 
Args: @@ -55,3 +224,194 @@ async def destroy(self, action: DestroyDeploymentAction) -> DestroyDeploymentAct success = await self._deployment_controller.destroy_deployment(action.endpoint_id) await self._deployment_controller.mark_lifecycle_needed(DeploymentLifecycleType.DESTROYING) return DestroyDeploymentActionResult(success=success) + + async def create_auto_scaling_rule( + self, action: CreateAutoScalingRuleAction + ) -> CreateAutoScalingRuleActionResult: + return CreateAutoScalingRuleActionResult( + data=ModelDeploymentAutoScalingRuleData( + id=uuid4(), + model_deployment_id=action.creator.model_deployment_id, + metric_source=action.creator.metric_source, + metric_name=action.creator.metric_name, + min_threshold=action.creator.min_threshold, + max_threshold=action.creator.max_threshold, + step_size=action.creator.step_size, + time_window=action.creator.time_window, + min_replicas=action.creator.min_replicas, + max_replicas=action.creator.max_replicas, + created_at=datetime.now(), + last_triggered_at=datetime.now(), + ) + ) + + async def update_auto_scaling_rule( + self, action: UpdateAutoScalingRuleAction + ) -> UpdateAutoScalingRuleActionResult: + return UpdateAutoScalingRuleActionResult( + data=ModelDeploymentAutoScalingRuleData( + id=uuid4(), + model_deployment_id=uuid4(), + metric_source=AutoScalingMetricSource.KERNEL, + metric_name="test-metric", + min_threshold=Decimal("0.5"), + max_threshold=Decimal("21.0"), + step_size=1, + time_window=60, + min_replicas=1, + max_replicas=10, + created_at=datetime.now(), + last_triggered_at=datetime.now(), + ) + ) + + async def delete_auto_scaling_rule( + self, action: DeleteAutoScalingRuleAction + ) -> DeleteAutoScalingRuleActionResult: + return DeleteAutoScalingRuleActionResult(success=True) + + async def create_access_token( + self, action: CreateAccessTokenAction + ) -> CreateAccessTokenActionResult: + return CreateAccessTokenActionResult( + data=ModelDeploymentAccessTokenData( + id=uuid4(), + token="test_token", + 
valid_until=datetime.now() + timedelta(hours=1), + created_at=datetime.now(), + ) + ) + + async def get_access_tokens_by_deployment_id( + self, action: GetAccessTokensByDeploymentIdAction + ) -> GetAccessTokensByDeploymentIdActionResult: + mock_tokens = [] + for i in range(3): + mock_tokens.append( + ModelDeploymentAccessTokenData( + id=uuid4(), + token=f"test_token_{i}", + valid_until=datetime.now() + timedelta(hours=24 * (i + 1)), + created_at=datetime.now() - timedelta(hours=i), + ) + ) + return GetAccessTokensByDeploymentIdActionResult(data=mock_tokens) + + async def sync_replicas(self, action: SyncReplicaAction) -> SyncReplicaActionResult: + return SyncReplicaActionResult(success=True) + + async def add_model_revision( + self, action: AddModelRevisionAction + ) -> AddModelRevisionActionResult: + return AddModelRevisionActionResult(revision=mock_revision_data_2) + + async def get_auto_scaling_rules_by_deployment_id( + self, action: GetAutoScalingRulesByDeploymentIdAction + ) -> GetAutoScalingRulesByDeploymentIdActionResult: + return GetAutoScalingRulesByDeploymentIdActionResult( + data=[ + ModelDeploymentAutoScalingRuleData( + id=uuid4(), + model_deployment_id=action.deployment_id, + metric_source=AutoScalingMetricSource.KERNEL, + metric_name="test-metric", + min_threshold=Decimal("0.5"), + max_threshold=Decimal("21.0"), + step_size=1, + time_window=60, + min_replicas=1, + max_replicas=10, + created_at=datetime.now(), + last_triggered_at=datetime.now(), + ), + ModelDeploymentAutoScalingRuleData( + id=uuid4(), + model_deployment_id=action.deployment_id, + metric_source=AutoScalingMetricSource.KERNEL, + metric_name="test-metric", + min_threshold=Decimal("0.0"), + max_threshold=Decimal("10.0"), + step_size=2, + time_window=200, + min_replicas=1, + max_replicas=5, + created_at=datetime.now(), + last_triggered_at=datetime.now(), + ), + ] + ) + + async def list_replicas(self, action: ListReplicasAction) -> ListReplicasActionResult: + return ListReplicasActionResult( 
+ data=[], + total_count=0, + ) + + async def list_revisions(self, action: ListRevisionsAction) -> ListRevisionsActionResult: + return ListRevisionsActionResult(data=[], total_count=0) + + async def get_replicas_by_deployment_id( + self, action: GetReplicasByDeploymentIdAction + ) -> GetReplicasByDeploymentIdActionResult: + return GetReplicasByDeploymentIdActionResult(data=[]) + + async def get_revision_by_deployment_id( + self, action: GetRevisionByDeploymentIdAction + ) -> GetRevisionByDeploymentIdActionResult: + return GetRevisionByDeploymentIdActionResult(data=mock_revision_data_1) + + async def get_revision_by_id( + self, action: GetRevisionByIdAction + ) -> GetRevisionByIdActionResult: + return GetRevisionByIdActionResult(data=mock_revision_data_1) + + async def get_revision_by_replica_id( + self, action: GetRevisionByReplicaIdAction + ) -> GetRevisionByReplicaIdActionResult: + return GetRevisionByReplicaIdActionResult(data=mock_revision_data_1) + + async def get_deployment(self, action: GetDeploymentAction) -> GetDeploymentActionResult: + # For now, return mock deployment info + return GetDeploymentActionResult( + data=ModelDeploymentData( + id=action.deployment_id, + metadata=ModelDeploymentMetadataInfo( + name="test-deployment", + status=ModelDeploymentStatus.READY, + tags=["tag1", "tag2"], + project_id=uuid4(), + domain_name="default", + created_at=datetime.now(), + updated_at=datetime.now(), + ), + network_access=DeploymentNetworkSpec( + open_to_public=True, + url="http://example.com", + preferred_domain_name="example.com", + access_token_ids=[uuid4()], + ), + revision_history_ids=[uuid4(), uuid4()], + revision=mock_revision_data_1, + scaling_rule_ids=[uuid4(), uuid4()], + replica_state=ReplicaStateData( + desired_replica_count=3, + replica_ids=[uuid4(), uuid4(), uuid4()], + ), + default_deployment_strategy=DeploymentStrategy.ROLLING, + created_user_id=uuid4(), + ) + ) + + async def get_revisions_by_deployment_id( + self, action: 
GetRevisionsByDeploymentIdAction + ) -> GetRevisionsByDeploymentIdActionResult: + # For now, return mock revision data list + return GetRevisionsByDeploymentIdActionResult( + data=[mock_revision_data_1, mock_revision_data_2] + ) + + async def get_replicas_by_revision_id( + self, action: GetReplicasByRevisionIdAction + ) -> GetReplicasByRevisionIdActionResult: + # For now, return empty replica list + return GetReplicasByRevisionIdActionResult(data=[])