Skip to content

Commit 87b5dd4

Browse files
committed
feat: Implement basic model deployment api layer
1 parent 055eade commit 87b5dd4

File tree

21 files changed

+719
-60
lines changed

21 files changed

+719
-60
lines changed

docs/manager/graphql-reference/schema.graphql

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ type Query {
128128
"""Added in 24.03.1"""
129129
id: String
130130
reference: String
131-
architecture: String = "x86_64"
131+
architecture: String = "aarch64"
132132
): Image
133133
images(
134134
"""
@@ -2331,7 +2331,7 @@ type Mutation {
23312331
): RescanImages
23322332
preload_image(references: [String]!, target_agents: [String]!): PreloadImage
23332333
unload_image(references: [String]!, target_agents: [String]!): UnloadImage
2334-
modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage
2334+
modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage
23352335

23362336
"""Added in 25.6.0"""
23372337
clear_image_custom_resource_limit(key: ClearImageCustomResourceLimitKey!): ClearImageCustomResourceLimitPayload
@@ -2340,7 +2340,7 @@ type Mutation {
23402340
forget_image_by_id(image_id: String!): ForgetImageById
23412341

23422342
"""Deprecated since 25.4.0. Use `forget_image_by_id` instead."""
2343-
forget_image(architecture: String = "x86_64", reference: String!): ForgetImage @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.")
2343+
forget_image(architecture: String = "aarch64", reference: String!): ForgetImage @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.")
23442344

23452345
"""Added in 25.4.0"""
23462346
purge_image_by_id(
@@ -2352,7 +2352,7 @@ type Mutation {
23522352

23532353
"""Added in 24.03.1"""
23542354
untag_image_from_registry(image_id: String!): UntagImageFromRegistry
2355-
alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage
2355+
alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage
23562356
dealias_image(alias: String!): DealiasImage
23572357
clear_images(registry: String): ClearImages
23582358

@@ -2924,7 +2924,7 @@ type ClearImageCustomResourceLimitPayload {
29242924
"""Added in 25.6.0."""
29252925
input ClearImageCustomResourceLimitKey {
29262926
image_canonical: String!
2927-
architecture: String! = "x86_64"
2927+
architecture: String! = "aarch64"
29282928
}
29292929

29302930
"""Added in 24.03.0."""

docs/manager/graphql-reference/supergraph.graphql

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ input ClearImageCustomResourceLimitKey
718718
@join__type(graph: GRAPHENE)
719719
{
720720
image_canonical: String!
721-
architecture: String! = "x86_64"
721+
architecture: String! = "aarch64"
722722
}
723723

724724
"""Added in 25.6.0."""
@@ -3745,7 +3745,7 @@ type Mutation
37453745
): RescanImages @join__field(graph: GRAPHENE)
37463746
preload_image(references: [String]!, target_agents: [String]!): PreloadImage @join__field(graph: GRAPHENE)
37473747
unload_image(references: [String]!, target_agents: [String]!): UnloadImage @join__field(graph: GRAPHENE)
3748-
modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage @join__field(graph: GRAPHENE)
3748+
modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage @join__field(graph: GRAPHENE)
37493749

37503750
"""Added in 25.6.0"""
37513751
clear_image_custom_resource_limit(key: ClearImageCustomResourceLimitKey!): ClearImageCustomResourceLimitPayload @join__field(graph: GRAPHENE)
@@ -3754,7 +3754,7 @@ type Mutation
37543754
forget_image_by_id(image_id: String!): ForgetImageById @join__field(graph: GRAPHENE)
37553755

37563756
"""Deprecated since 25.4.0. Use `forget_image_by_id` instead."""
3757-
forget_image(architecture: String = "x86_64", reference: String!): ForgetImage @join__field(graph: GRAPHENE) @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.")
3757+
forget_image(architecture: String = "aarch64", reference: String!): ForgetImage @join__field(graph: GRAPHENE) @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.")
37583758

37593759
"""Added in 25.4.0"""
37603760
purge_image_by_id(
@@ -3766,7 +3766,7 @@ type Mutation
37663766

37673767
"""Added in 24.03.1"""
37683768
untag_image_from_registry(image_id: String!): UntagImageFromRegistry @join__field(graph: GRAPHENE)
3769-
alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage @join__field(graph: GRAPHENE)
3769+
alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage @join__field(graph: GRAPHENE)
37703770
dealias_image(alias: String!): DealiasImage @join__field(graph: GRAPHENE)
37713771
clear_images(registry: String): ClearImages @join__field(graph: GRAPHENE)
37723772

@@ -4446,7 +4446,7 @@ type Query
44464446
"""Added in 24.03.1"""
44474447
id: String
44484448
reference: String
4449-
architecture: String = "x86_64"
4449+
architecture: String = "aarch64"
44504450
): Image @join__field(graph: GRAPHENE)
44514451
images(
44524452
"""

src/ai/backend/manager/api/gql/model_deployment/access_token.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from datetime import datetime, timedelta
2+
from typing import Self
23
from uuid import UUID
34

45
import strawberry
56
from strawberry import ID, Info
67
from strawberry.relay import Connection, Edge, Node, NodeID
78

89
from ai.backend.manager.api.gql.types import StrawberryGQLContext
10+
from ai.backend.manager.data.deployment.creator import ModelDeploymentAccessTokenCreator
11+
from ai.backend.manager.services.deployment.actions.create_access_token import (
12+
CreateAccessTokenAction,
13+
)
914

1015

1116
@strawberry.type
1217
class AccessToken(Node):
13-
id: NodeID
18+
id: NodeID[str]
1419
token: str = strawberry.field(description="Added in 25.13.0: The access token.")
1520
created_at: datetime = strawberry.field(
1621
description="Added in 25.13.0: The creation timestamp of the access token."
@@ -19,6 +24,15 @@ class AccessToken(Node):
1924
description="Added in 25.13.0: The expiration timestamp of the access token."
2025
)
2126

27+
@classmethod
28+
def from_dataclass(cls, data) -> Self:
29+
return cls(
30+
id=ID(str(data.id)),
31+
token=data.token,
32+
created_at=data.created_at,
33+
valid_until=data.valid_until,
34+
)
35+
2236

2337
AccessTokenEdge = Edge[AccessToken]
2438

@@ -32,36 +46,36 @@ def __init__(self, *args, count: int, **kwargs):
3246
self.count = count
3347

3448

35-
mock_access_token_1 = AccessToken(
36-
id=UUID("13cd8325-9307-49e4-94eb-ded2581363f8"),
49+
mock_access_token_1: AccessToken = AccessToken(
50+
id="13cd8325-9307-49e4-94eb-ded2581363f8",
3751
token="mock-token-1",
3852
created_at=datetime.now(),
3953
valid_until=datetime.now() + timedelta(hours=12),
4054
)
4155

42-
mock_access_token_2 = AccessToken(
43-
id=UUID("dc1a223a-7437-4e6f-aedf-23417d0486dd"),
56+
mock_access_token_2: AccessToken = AccessToken(
57+
id="dc1a223a-7437-4e6f-aedf-23417d0486dd",
4458
token="mock-token-2",
4559
created_at=datetime.now(),
4660
valid_until=datetime.now() + timedelta(hours=1),
4761
)
4862

49-
mock_access_token_3 = AccessToken(
50-
id=UUID("39f8b49e-0ddf-4dfb-92d6-003c771684b7"),
63+
mock_access_token_3: AccessToken = AccessToken(
64+
id="39f8b49e-0ddf-4dfb-92d6-003c771684b7",
5165
token="mock-token-3",
5266
created_at=datetime.now(),
5367
valid_until=datetime.now() + timedelta(hours=100),
5468
)
5569

56-
mock_access_token_4 = AccessToken(
57-
id=UUID("85a6ed1e-133b-4f58-9c06-f667337c6111"),
70+
mock_access_token_4: AccessToken = AccessToken(
71+
id="85a6ed1e-133b-4f58-9c06-f667337c6111",
5872
token="mock-token-4",
5973
created_at=datetime.now(),
6074
valid_until=datetime.now() + timedelta(hours=10),
6175
)
6276

63-
mock_access_token_5 = AccessToken(
64-
id=UUID("c42f8578-b31d-4203-b858-93f93b4b9549"),
77+
mock_access_token_5: AccessToken = AccessToken(
78+
id="c42f8578-b31d-4203-b858-93f93b4b9549",
6579
token="mock-token-5",
6680
created_at=datetime.now(),
6781
valid_until=datetime.now() + timedelta(hours=3),
@@ -77,6 +91,12 @@ class CreateAccessTokenInput:
7791
description="Added in 25.13.0: The expiration timestamp of the access token."
7892
)
7993

94+
def to_creator(self) -> "ModelDeploymentAccessTokenCreator":
95+
return ModelDeploymentAccessTokenCreator(
96+
model_deployment_id=UUID(self.model_deployment_id),
97+
valid_until=self.valid_until,
98+
)
99+
80100

81101
@strawberry.type
82102
class CreateAccessTokenPayload:
@@ -87,4 +107,9 @@ class CreateAccessTokenPayload:
87107
async def create_access_token(
88108
input: CreateAccessTokenInput, info: Info[StrawberryGQLContext]
89109
) -> CreateAccessTokenPayload:
90-
return CreateAccessTokenPayload(access_token=mock_access_token_1)
110+
deployment_processor = info.context.processors.deployment
111+
assert deployment_processor is not None
112+
result = await deployment_processor.create_access_token.wait_for_complete(
113+
action=CreateAccessTokenAction(input.to_creator())
114+
)
115+
return CreateAccessTokenPayload(access_token=AccessToken.from_dataclass(result.data))

0 commit comments

Comments
 (0)