Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/5672.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement API Layer of Model Deployment
21 changes: 11 additions & 10 deletions docs/manager/graphql-reference/supergraph.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -3009,10 +3009,10 @@ type ModelDeployment implements Node
metadata: ModelDeploymentMetadata!
networkAccess: ModelDeploymentNetworkAccess!
revision: ModelRevision
scalingRule: ScalingRule!
replicaState: ReplicaState!
defaultDeploymentStrategy: DeploymentStrategy!
createdUser: UserNode!
scalingRule: ScalingRule!
replicaState: ReplicaState!
revisionHistory(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection!
}

Expand Down Expand Up @@ -3046,10 +3046,10 @@ type ModelDeploymentMetadata
name: String!
status: DeploymentStatus!
tags: [String!]!
project: GroupNode!
domain: DomainNode!
createdAt: DateTime!
updatedAt: DateTime!
project: GroupNode!
domain: DomainNode!
}

"""Added in 25.13.0"""
Expand Down Expand Up @@ -3084,9 +3084,9 @@ input ModelDeploymentNetworkAccessInput
type ModelMountConfig
@join__type(graph: STRAWBERRY)
{
vfolder: VirtualFolderNode!
mountDestination: String!
definitionPath: String!
vfolder: VirtualFolderNode!
}

"""Added in 25.13.0"""
Expand All @@ -3105,7 +3105,6 @@ type ModelReplica implements Node
{
"""The Globally Unique ID of this object"""
id: ID!
revision: ModelRevision!

"""
This represents whether the replica has been checked and its health state.
Expand Down Expand Up @@ -3138,6 +3137,7 @@ type ModelReplica implements Node
The session ID associated with the replica. This can be null right after replica creation.
"""
session: ComputeSessionNode!
revision: ModelRevision!
}

"""Added in 25.13.0"""
Expand Down Expand Up @@ -3176,8 +3176,8 @@ type ModelRevision implements Node
modelRuntimeConfig: ModelRuntimeConfig!
modelMountConfig: ModelMountConfig!
extraMounts: ExtraVFolderMountConnection!
image: ImageNode!
createdAt: DateTime!
image: ImageNode!
}

"""Added in 25.13.0"""
Expand Down Expand Up @@ -4103,7 +4103,9 @@ type Mutation
"""Added in 25.13.0"""
addModelRevision(input: AddModelRevisionInput!): AddModelRevisionPayload! @join__field(graph: STRAWBERRY)

"""Added in 25.13.0"""
"""
Added in 25.13.0. Create model revision which is not attached to any deployment.
"""
createModelRevision(input: CreateModelRevisionInput!): CreateModelRevisionPayload! @join__field(graph: STRAWBERRY)

"""Added in 25.14.0"""
Expand Down Expand Up @@ -5066,8 +5068,6 @@ type ReservoirRegistryEdge
type ResourceConfig
@join__type(graph: STRAWBERRY)
{
resourceGroup: ScalingGroupNode!

"""
Resource Slots are a JSON string that describes the resources allocated for the deployment. Example: "resourceSlots": "{\"cpu\": \"1\", \"mem\": \"1073741824\", \"cuda.device\": \"0\"}"
"""
Expand All @@ -5077,6 +5077,7 @@ type ResourceConfig
Resource Options are a JSON string that describes additional options for the resources. This is especially used for shared memory configurations. Example: "resourceOpts": "{\"shmem\": \"64m\"}"
"""
resourceOpts: JSONString
resourceGroup: ScalingGroupNode!
}

"""Added in 25.13.0"""
Expand Down
21 changes: 11 additions & 10 deletions docs/manager/graphql-reference/v2-schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,10 @@ type ModelDeployment implements Node {
metadata: ModelDeploymentMetadata!
networkAccess: ModelDeploymentNetworkAccess!
revision: ModelRevision
scalingRule: ScalingRule!
replicaState: ReplicaState!
defaultDeploymentStrategy: DeploymentStrategy!
createdUser: UserNode!
scalingRule: ScalingRule!
replicaState: ReplicaState!
revisionHistory(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection!
}

Expand Down Expand Up @@ -741,10 +741,10 @@ type ModelDeploymentMetadata {
name: String!
status: DeploymentStatus!
tags: [String!]!
project: GroupNode!
domain: DomainNode!
createdAt: DateTime!
updatedAt: DateTime!
project: GroupNode!
domain: DomainNode!
}

"""Added in 25.13.0"""
Expand All @@ -771,9 +771,9 @@ input ModelDeploymentNetworkAccessInput {

"""Added in 25.13.0"""
type ModelMountConfig {
vfolder: VirtualFolderNode!
mountDestination: String!
definitionPath: String!
vfolder: VirtualFolderNode!
}

"""Added in 25.13.0"""
Expand All @@ -787,7 +787,6 @@ input ModelMountConfigInput {
type ModelReplica implements Node {
"""The Globally Unique ID of this object"""
id: ID!
revision: ModelRevision!

"""
This represents whether the replica has been checked and its health state.
Expand Down Expand Up @@ -820,6 +819,7 @@ type ModelReplica implements Node {
The session ID associated with the replica. This can be null right after replica creation.
"""
session: ComputeSessionNode!
revision: ModelRevision!
}

"""Added in 25.13.0"""
Expand Down Expand Up @@ -851,8 +851,8 @@ type ModelRevision implements Node {
modelRuntimeConfig: ModelRuntimeConfig!
modelMountConfig: ModelMountConfig!
extraMounts: ExtraVFolderMountConnection!
image: ImageNode!
createdAt: DateTime!
image: ImageNode!
}

"""Added in 25.13.0"""
Expand Down Expand Up @@ -961,7 +961,9 @@ type Mutation {
"""Added in 25.13.0"""
addModelRevision(input: AddModelRevisionInput!): AddModelRevisionPayload!

"""Added in 25.13.0"""
"""
Added in 25.13.0. Create model revision which is not attached to any deployment.
"""
createModelRevision(input: CreateModelRevisionInput!): CreateModelRevisionPayload!

"""Added in 25.14.0"""
Expand Down Expand Up @@ -1272,8 +1274,6 @@ type ReservoirRegistryEdge {

"""Added in 25.13.0"""
type ResourceConfig {
resourceGroup: ScalingGroupNode!

"""
Resource Slots are a JSON string that describes the resources allocated for the deployment. Example: "resourceSlots": "{\"cpu\": \"1\", \"mem\": \"1073741824\", \"cuda.device\": \"0\"}"
"""
Expand All @@ -1283,6 +1283,7 @@ type ResourceConfig {
Resource Options are a JSON string that describes additional options for the resources. This is especially used for shared memory configurations. Example: "resourceOpts": "{\"shmem\": \"64m\"}"
"""
resourceOpts: JSONString
resourceGroup: ScalingGroupNode!
}

"""Added in 25.13.0"""
Expand Down
14 changes: 14 additions & 0 deletions src/ai/backend/common/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class ErrorDomain(enum.StrEnum):
PERMISSION = "permission"
METRIC = "metric"
STORAGE_PROXY = "storage-proxy"
MODEL_DEPLOYMENT = "model-deployment"


class ErrorOperation(enum.StrEnum):
Expand Down Expand Up @@ -643,3 +644,16 @@ def error_code(cls) -> ErrorCode:
operation=ErrorOperation.READ,
error_detail=ErrorDetail.NOT_FOUND,
)


class ModelDeploymentUnavailableError(BackendAIError, web.HTTPServiceUnavailable):
error_type = "https://api.backend.ai/probs/model-deployment-unavailable"
error_title = "Model Deployment Unavailable"

@classmethod
def error_code(cls) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.MODEL_DEPLOYMENT,
operation=ErrorOperation.EXECUTE,
error_detail=ErrorDetail.UNAVAILABLE,
)
13 changes: 12 additions & 1 deletion src/ai/backend/manager/api/gql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, Type, cast

import graphene
import orjson
import strawberry
from graphql import StringValueNode
Expand Down Expand Up @@ -160,7 +161,17 @@ def from_resource_slot(resource_slot: ResourceSlot) -> JSONString:
return JSONString.serialize(resource_slot.to_json())


def to_global_id(type_: Type[Any], local_id: uuid.UUID | str) -> str:
def to_global_id(
type_: Type[Any], local_id: uuid.UUID | str, is_target_graphene_object: bool = False
) -> str:
if is_target_graphene_object:
# For compatibility with existing Graphene-based global IDs
if not issubclass(type_, graphene.ObjectType):
raise TypeError(
"type_ must be a graphene ObjectType when is_target_graphene_object is True."
)
typename = type_.__name__
return base64(f"{typename}:{local_id}")
if not has_object_definition(type_):
raise TypeError("type_ must be a Strawberry object type (Node or Edge).")
typename = get_object_definition(type_, strict=True).name
Expand Down
92 changes: 53 additions & 39 deletions src/ai/backend/manager/api/gql/model_deployment/access_token.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
from datetime import datetime, timedelta
from collections.abc import Sequence
from datetime import datetime
from typing import Self
from uuid import UUID

import strawberry
from strawberry import ID, Info
from strawberry.relay import Connection, Edge, Node, NodeID

from ai.backend.common.exception import ModelDeploymentUnavailableError
from ai.backend.manager.api.gql.types import StrawberryGQLContext
from ai.backend.manager.data.deployment.access_token import ModelDeploymentAccessTokenCreator
from ai.backend.manager.data.deployment.types import ModelDeploymentAccessTokenData
from ai.backend.manager.services.deployment.actions.access_token.create_access_token import (
CreateAccessTokenAction,
)
from ai.backend.manager.services.deployment.actions.access_token.get_access_tokens_by_deployment_id import (
GetAccessTokensByDeploymentIdAction,
)


@strawberry.type
class AccessToken(Node):
id: NodeID
id: NodeID[str]
token: str = strawberry.field(description="Added in 25.13.0: The access token.")
created_at: datetime = strawberry.field(
description="Added in 25.13.0: The creation timestamp of the access token."
Expand All @@ -19,6 +30,34 @@ class AccessToken(Node):
description="Added in 25.13.0: The expiration timestamp of the access token."
)

@classmethod
def from_dataclass(cls, data: ModelDeploymentAccessTokenData) -> Self:
return cls(
id=ID(str(data.id)),
token=data.token,
created_at=data.created_at,
valid_until=data.valid_until,
)

@classmethod
async def batch_load_by_deployment_ids(
cls, ctx: StrawberryGQLContext, deployment_ids: Sequence[UUID]
) -> list[list[ModelDeploymentAccessTokenData]]:
"""Batch load access tokens by deployment IDs."""
processor = ctx.processors.deployment
if processor is None:
raise ModelDeploymentUnavailableError(
"Model Deployment feature is unavailable. Please contact support."
)

results = []
for deployment_id in deployment_ids:
result = await processor.get_access_tokens_by_deployment_id.wait_for_complete(
GetAccessTokensByDeploymentIdAction(deployment_id=deployment_id)
)
results.append(result.data if result else [])
return results


AccessTokenEdge = Edge[AccessToken]

Expand All @@ -32,42 +71,6 @@ def __init__(self, *args, count: int, **kwargs):
self.count = count


mock_access_token_1 = AccessToken(
id=UUID("13cd8325-9307-49e4-94eb-ded2581363f8"),
token="mock-token-1",
created_at=datetime.now(),
valid_until=datetime.now() + timedelta(hours=12),
)

mock_access_token_2 = AccessToken(
id=UUID("dc1a223a-7437-4e6f-aedf-23417d0486dd"),
token="mock-token-2",
created_at=datetime.now(),
valid_until=datetime.now() + timedelta(hours=1),
)

mock_access_token_3 = AccessToken(
id=UUID("39f8b49e-0ddf-4dfb-92d6-003c771684b7"),
token="mock-token-3",
created_at=datetime.now(),
valid_until=datetime.now() + timedelta(hours=100),
)

mock_access_token_4 = AccessToken(
id=UUID("85a6ed1e-133b-4f58-9c06-f667337c6111"),
token="mock-token-4",
created_at=datetime.now(),
valid_until=datetime.now() + timedelta(hours=10),
)

mock_access_token_5 = AccessToken(
id=UUID("c42f8578-b31d-4203-b858-93f93b4b9549"),
token="mock-token-5",
created_at=datetime.now(),
valid_until=datetime.now() + timedelta(hours=3),
)


@strawberry.input
class CreateAccessTokenInput:
model_deployment_id: ID = strawberry.field(
Expand All @@ -77,6 +80,12 @@ class CreateAccessTokenInput:
description="Added in 25.13.0: The expiration timestamp of the access token."
)

def to_creator(self) -> "ModelDeploymentAccessTokenCreator":
return ModelDeploymentAccessTokenCreator(
model_deployment_id=UUID(self.model_deployment_id),
valid_until=self.valid_until,
)


@strawberry.type
class CreateAccessTokenPayload:
Expand All @@ -87,4 +96,9 @@ class CreateAccessTokenPayload:
async def create_access_token(
input: CreateAccessTokenInput, info: Info[StrawberryGQLContext]
) -> CreateAccessTokenPayload:
return CreateAccessTokenPayload(access_token=mock_access_token_1)
deployment_processor = info.context.processors.deployment
assert deployment_processor is not None
result = await deployment_processor.create_access_token.wait_for_complete(
action=CreateAccessTokenAction(input.to_creator())
)
return CreateAccessTokenPayload(access_token=AccessToken.from_dataclass(result.data))
Loading
Loading