Skip to content

Commit 0cae892

Browse files
committed
feat: Implement basic api layers for model deployment gql types
1 parent be27547 commit 0cae892

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+2834
-769
lines changed

docs/manager/graphql-reference/supergraph.graphql

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2991,11 +2991,11 @@ type ModelDeployment implements Node
29912991
metadata: ModelDeploymentMetadata!
29922992
networkAccess: ModelDeploymentNetworkAccess!
29932993
revision: ModelRevision
2994-
revisionHistory: ModelRevisionConnection!
2995-
scalingRule: ScalingRule!
2996-
replicaState: ReplicaState!
29972994
defaultDeploymentStrategy: DeploymentStrategy!
29982995
createdUser: UserNode!
2996+
scalingRule: ScalingRule!
2997+
replicaState: ReplicaState!
2998+
revisionHistory: ModelRevisionConnection!
29992999
}
30003000

30013001
"""Added in 25.13.0"""
@@ -3028,10 +3028,10 @@ type ModelDeploymentMetadata
30283028
name: String!
30293029
status: DeploymentStatus!
30303030
tags: [String!]!
3031-
project: GroupNode!
3032-
domain: DomainNode!
30333031
createdAt: DateTime!
30343032
updatedAt: DateTime!
3033+
project: GroupNode!
3034+
domain: DomainNode!
30353035
}
30363036

30373037
"""Added in 25.13.0"""
@@ -3066,9 +3066,9 @@ input ModelDeploymentNetworkAccessInput
30663066
type ModelMountConfig
30673067
@join__type(graph: STRAWBERRY)
30683068
{
3069-
vfolder: VirtualFolderNode!
30703069
mountDestination: String!
30713070
definitionPath: String!
3071+
vfolder: VirtualFolderNode!
30723072
}
30733073

30743074
"""Added in 25.13.0"""
@@ -3087,7 +3087,6 @@ type ModelReplica implements Node
30873087
{
30883088
"""The Globally Unique ID of this object"""
30893089
id: ID!
3090-
revision: ModelRevision!
30913090

30923091
"""
30933092
This represents whether the replica has been checked and its health state.
@@ -3120,6 +3119,7 @@ type ModelReplica implements Node
31203119
The session ID associated with the replica. This can be null right after replica creation.
31213120
"""
31223121
session: ComputeSessionNode!
3122+
revision: ModelRevision!
31233123
}
31243124

31253125
"""Added in 25.13.0"""
@@ -3158,8 +3158,8 @@ type ModelRevision implements Node
31583158
modelRuntimeConfig: ModelRuntimeConfig!
31593159
modelMountConfig: ModelMountConfig!
31603160
extraMounts: ExtraVFolderMountConnection!
3161-
image: ImageNode!
31623161
createdAt: DateTime!
3162+
image: ImageNode!
31633163
}
31643164

31653165
"""Added in 25.13.0"""
@@ -4084,7 +4084,9 @@ type Mutation
40844084
"""Added in 25.13.0"""
40854085
addModelRevision(input: AddModelRevisionInput!): AddModelRevisionPayload! @join__field(graph: STRAWBERRY)
40864086

4087-
"""Added in 25.13.0"""
4087+
"""
4088+
Added in 25.13.0. Creates a model revision that is not attached to any deployment.
4089+
"""
40884090
createModelRevision(input: CreateModelRevisionInput!): CreateModelRevisionPayload! @join__field(graph: STRAWBERRY)
40894091

40904092
"""Added in 25.14.0"""
@@ -4838,7 +4840,7 @@ type Query
48384840
revisions(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection! @join__field(graph: STRAWBERRY)
48394841

48404842
"""Added in 25.13.0"""
4841-
revision(id: ID!): ModelRevision @join__field(graph: STRAWBERRY)
4843+
revision(id: ID!): ModelRevision! @join__field(graph: STRAWBERRY)
48424844

48434845
"""Added in 25.13.0"""
48444846
replicas(filter: ReplicaFilter = null, orderBy: [ReplicaOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelReplicaConnection! @join__field(graph: STRAWBERRY)
@@ -5044,8 +5046,6 @@ type ReservoirRegistryEdge
50445046
type ResourceConfig
50455047
@join__type(graph: STRAWBERRY)
50465048
{
5047-
resourceGroup: ScalingGroupNode!
5048-
50495049
"""
50505050
Resource Slots are a JSON string that describes the resources allocated for the deployment. Example: "resourceSlots": "{\"cpu\": \"1\", \"mem\": \"1073741824\", \"cuda.device\": \"0\"}"
50515051
"""
@@ -5055,6 +5055,7 @@ type ResourceConfig
50555055
Resource Options are a JSON string that describes additional options for the resources. This is especially used for shared memory configurations. Example: "resourceOpts": "{\"shmem\": \"64m\"}"
50565056
"""
50575057
resourceOpts: JSONString
5058+
resourceGroup: ScalingGroupNode!
50585059
}
50595060

50605061
"""Added in 25.13.0"""

docs/manager/graphql-reference/v2-schema.graphql

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ The `JSON` scalar type represents JSON values as specified by [ECMA-404](https:/
670670
"""
671671
scalar JSON @specifiedBy(url: "https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf")
672672

673-
"""A custom scalar for JSON strings using orjson"""
673+
"""Added in 25.13.0"""
674674
scalar JSONString
675675

676676
"""
@@ -696,11 +696,11 @@ type ModelDeployment implements Node {
696696
metadata: ModelDeploymentMetadata!
697697
networkAccess: ModelDeploymentNetworkAccess!
698698
revision: ModelRevision
699-
revisionHistory: ModelRevisionConnection!
700-
scalingRule: ScalingRule!
701-
replicaState: ReplicaState!
702699
defaultDeploymentStrategy: DeploymentStrategy!
703700
createdUser: UserNode!
701+
scalingRule: ScalingRule!
702+
replicaState: ReplicaState!
703+
revisionHistory: ModelRevisionConnection!
704704
}
705705

706706
"""Added in 25.13.0"""
@@ -727,10 +727,10 @@ type ModelDeploymentMetadata {
727727
name: String!
728728
status: DeploymentStatus!
729729
tags: [String!]!
730-
project: GroupNode!
731-
domain: DomainNode!
732730
createdAt: DateTime!
733731
updatedAt: DateTime!
732+
project: GroupNode!
733+
domain: DomainNode!
734734
}
735735

736736
"""Added in 25.13.0"""
@@ -757,9 +757,9 @@ input ModelDeploymentNetworkAccessInput {
757757

758758
"""Added in 25.13.0"""
759759
type ModelMountConfig {
760-
vfolder: VirtualFolderNode!
761760
mountDestination: String!
762761
definitionPath: String!
762+
vfolder: VirtualFolderNode!
763763
}
764764

765765
"""Added in 25.13.0"""
@@ -773,7 +773,6 @@ input ModelMountConfigInput {
773773
type ModelReplica implements Node {
774774
"""The Globally Unique ID of this object"""
775775
id: ID!
776-
revision: ModelRevision!
777776

778777
"""
779778
This represents whether the replica has been checked and its health state.
@@ -806,6 +805,7 @@ type ModelReplica implements Node {
806805
The session ID associated with the replica. This can be null right after replica creation.
807806
"""
808807
session: ComputeSessionNode!
808+
revision: ModelRevision!
809809
}
810810

811811
"""Added in 25.13.0"""
@@ -837,8 +837,8 @@ type ModelRevision implements Node {
837837
modelRuntimeConfig: ModelRuntimeConfig!
838838
modelMountConfig: ModelMountConfig!
839839
extraMounts: ExtraVFolderMountConnection!
840-
image: ImageNode!
841840
createdAt: DateTime!
841+
image: ImageNode!
842842
}
843843

844844
"""Added in 25.13.0"""
@@ -946,7 +946,9 @@ type Mutation {
946946
"""Added in 25.13.0"""
947947
addModelRevision(input: AddModelRevisionInput!): AddModelRevisionPayload!
948948

949-
"""Added in 25.13.0"""
949+
"""
950+
Added in 25.13.0. Creates a model revision that is not attached to any deployment.
951+
"""
950952
createModelRevision(input: CreateModelRevisionInput!): CreateModelRevisionPayload!
951953

952954
"""Added in 25.14.0"""
@@ -1118,7 +1120,7 @@ type Query {
11181120
revisions(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection!
11191121

11201122
"""Added in 25.13.0"""
1121-
revision(id: ID!): ModelRevision
1123+
revision(id: ID!): ModelRevision!
11221124

11231125
"""Added in 25.13.0"""
11241126
replicas(filter: ReplicaFilter = null, orderBy: [ReplicaOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelReplicaConnection!
@@ -1254,8 +1256,6 @@ type ReservoirRegistryEdge {
12541256

12551257
"""Added in 25.13.0"""
12561258
type ResourceConfig {
1257-
resourceGroup: ScalingGroupNode!
1258-
12591259
"""
12601260
Resource Slots are a JSON string that describes the resources allocated for the deployment. Example: "resourceSlots": "{\"cpu\": \"1\", \"mem\": \"1073741824\", \"cuda.device\": \"0\"}"
12611261
"""
@@ -1265,6 +1265,7 @@ type ResourceConfig {
12651265
Resource Options are a JSON string that describes additional options for the resources. This is especially used for shared memory configurations. Example: "resourceOpts": "{\"shmem\": \"64m\"}"
12661266
"""
12671267
resourceOpts: JSONString
1268+
resourceGroup: ScalingGroupNode!
12681269
}
12691270

12701271
"""Added in 25.13.0"""

src/ai/backend/common/exception.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ class ErrorDomain(enum.StrEnum):
165165
PERMISSION = "permission"
166166
METRIC = "metric"
167167
STORAGE_PROXY = "storage-proxy"
168+
MODEL_DEPLOYMENT = "model-deployment"
168169

169170

170171
class ErrorOperation(enum.StrEnum):
@@ -643,3 +644,16 @@ def error_code(cls) -> ErrorCode:
643644
operation=ErrorOperation.READ,
644645
error_detail=ErrorDetail.NOT_FOUND,
645646
)
647+
648+
649+
class ModelDeploymentUnavailableError(BackendAIError, web.HTTPServiceUnavailable):
650+
error_type = "https://api.backend.ai/probs/model-deployment-unavailable"
651+
error_title = "Model Deployment Unavailable"
652+
653+
@classmethod
654+
def error_code(cls) -> ErrorCode:
655+
return ErrorCode(
656+
domain=ErrorDomain.MODEL_DEPLOYMENT,
657+
operation=ErrorOperation.EXECUTE,
658+
error_detail=ErrorDetail.UNAVAILABLE,
659+
)

src/ai/backend/manager/api/admin.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from ai.backend.common.dto.manager.response import GraphQLResponse
2424
from ai.backend.logging import BraceStyleAdapter
2525
from ai.backend.manager.api.gql.types import StrawberryGQLContext
26-
from ai.backend.manager.dto.context import ProcessorsCtx
26+
from ai.backend.manager.dto.context import (
27+
ProcessorsCtx,
28+
)
2729

2830
from ..api.gql.schema import schema as strawberry_schema
2931
from ..errors.api import GraphQLError as BackendGQLError

src/ai/backend/manager/api/gql/base.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from __future__ import annotations
22

33
import uuid
4+
from collections.abc import Mapping
45
from enum import StrEnum
5-
from typing import TYPE_CHECKING, Any, Optional, Type
6+
from typing import TYPE_CHECKING, Any, Optional, Type, cast
67

78
import orjson
89
import strawberry
910
from graphql import StringValueNode
1011
from graphql_relay.utils import base64, unbase64
1112
from strawberry.types import get_object_definition, has_object_definition
1213

14+
from ai.backend.common.types import ResourceSlot
15+
1316
if TYPE_CHECKING:
1417
from ai.backend.manager.types import (
1518
PaginationOptions,
@@ -133,33 +136,28 @@ class Ordering(StrEnum):
133136
DESC_NULLS_LAST = "DESC_NULLS_LAST"
134137

135138

136-
def serialize_json(value: Any) -> str:
137-
if isinstance(value, (dict, list)):
138-
return orjson.dumps(value).decode("utf-8")
139-
elif isinstance(value, str):
140-
return value
141-
else:
142-
return orjson.dumps(value).decode("utf-8")
143-
144-
145-
def parse_json(value: str | bytes) -> Any:
146-
if isinstance(value, str):
147-
return orjson.loads(value)
148-
elif isinstance(value, bytes):
149-
return orjson.loads(value)
150-
else:
139+
@strawberry.scalar(description="Added in 25.13.0")
140+
class JSONString:
141+
@staticmethod
142+
def parse_value(value: str | bytes) -> Mapping[str, Any]:
143+
if isinstance(value, str):
144+
return orjson.loads(value)
145+
if isinstance(value, bytes):
146+
return orjson.loads(value)
151147
return value
152148

149+
@staticmethod
150+
def serialize(value: Any) -> JSONString:
151+
if isinstance(value, (dict, list)):
152+
return cast(JSONString, orjson.dumps(value).decode("utf-8"))
153+
elif isinstance(value, str):
154+
return cast(JSONString, value)
155+
else:
156+
return cast(JSONString, orjson.dumps(value).decode("utf-8"))
153157

154-
@strawberry.scalar(
155-
name="JSONString",
156-
description="A custom scalar for JSON strings using orjson",
157-
serialize=serialize_json,
158-
parse_value=parse_json,
159-
parse_literal=lambda v: parse_json(v.value) if hasattr(v, "value") else v,
160-
)
161-
class JSONString:
162-
pass
158+
@staticmethod
159+
def from_resource_slot(resource_slot: ResourceSlot) -> JSONString:
160+
return JSONString.serialize(resource_slot.to_json())
163161

164162

165163
def to_global_id(type_: Type[Any], local_id: uuid.UUID | str) -> str:

0 commit comments

Comments
 (0)