1
1
from datetime import datetime , timedelta
2
+ from typing import Self
2
3
from uuid import UUID
3
4
4
5
import strawberry
5
6
from strawberry import ID , Info
6
7
from strawberry .relay import Connection , Edge , Node , NodeID
7
8
8
9
from ai .backend .manager .api .gql .types import StrawberryGQLContext
10
+ from ai .backend .manager .data .deployment .creator import ModelDeploymentAccessTokenCreator
11
+ from ai .backend .manager .services .deployment .actions .create_access_token import (
12
+ CreateAccessTokenAction ,
13
+ )
9
14
10
15
11
16
@strawberry .type
12
17
class AccessToken (Node ):
13
- id : NodeID
18
+ id : NodeID [ str ]
14
19
token : str = strawberry .field (description = "Added in 25.13.0: The access token." )
15
20
created_at : datetime = strawberry .field (
16
21
description = "Added in 25.13.0: The creation timestamp of the access token."
@@ -19,6 +24,15 @@ class AccessToken(Node):
19
24
description = "Added in 25.13.0: The expiration timestamp of the access token."
20
25
)
21
26
27
+ @classmethod
28
+ def from_dataclass (cls , data ) -> Self :
29
+ return cls (
30
+ id = ID (str (data .id )),
31
+ token = data .token ,
32
+ created_at = data .created_at ,
33
+ valid_until = data .valid_until ,
34
+ )
35
+
22
36
23
37
AccessTokenEdge = Edge [AccessToken ]
24
38
@@ -32,36 +46,36 @@ def __init__(self, *args, count: int, **kwargs):
32
46
self .count = count
33
47
34
48
35
- mock_access_token_1 = AccessToken (
36
- id = UUID ( "13cd8325-9307-49e4-94eb-ded2581363f8" ) ,
49
+ mock_access_token_1 : AccessToken = AccessToken (
50
+ id = "13cd8325-9307-49e4-94eb-ded2581363f8" ,
37
51
token = "mock-token-1" ,
38
52
created_at = datetime .now (),
39
53
valid_until = datetime .now () + timedelta (hours = 12 ),
40
54
)
41
55
42
- mock_access_token_2 = AccessToken (
43
- id = UUID ( "dc1a223a-7437-4e6f-aedf-23417d0486dd" ) ,
56
+ mock_access_token_2 : AccessToken = AccessToken (
57
+ id = "dc1a223a-7437-4e6f-aedf-23417d0486dd" ,
44
58
token = "mock-token-2" ,
45
59
created_at = datetime .now (),
46
60
valid_until = datetime .now () + timedelta (hours = 1 ),
47
61
)
48
62
49
- mock_access_token_3 = AccessToken (
50
- id = UUID ( "39f8b49e-0ddf-4dfb-92d6-003c771684b7" ) ,
63
+ mock_access_token_3 : AccessToken = AccessToken (
64
+ id = "39f8b49e-0ddf-4dfb-92d6-003c771684b7" ,
51
65
token = "mock-token-3" ,
52
66
created_at = datetime .now (),
53
67
valid_until = datetime .now () + timedelta (hours = 100 ),
54
68
)
55
69
56
- mock_access_token_4 = AccessToken (
57
- id = UUID ( "85a6ed1e-133b-4f58-9c06-f667337c6111" ) ,
70
+ mock_access_token_4 : AccessToken = AccessToken (
71
+ id = "85a6ed1e-133b-4f58-9c06-f667337c6111" ,
58
72
token = "mock-token-4" ,
59
73
created_at = datetime .now (),
60
74
valid_until = datetime .now () + timedelta (hours = 10 ),
61
75
)
62
76
63
- mock_access_token_5 = AccessToken (
64
- id = UUID ( "c42f8578-b31d-4203-b858-93f93b4b9549" ) ,
77
+ mock_access_token_5 : AccessToken = AccessToken (
78
+ id = "c42f8578-b31d-4203-b858-93f93b4b9549" ,
65
79
token = "mock-token-5" ,
66
80
created_at = datetime .now (),
67
81
valid_until = datetime .now () + timedelta (hours = 3 ),
@@ -77,6 +91,12 @@ class CreateAccessTokenInput:
77
91
description = "Added in 25.13.0: The expiration timestamp of the access token."
78
92
)
79
93
94
+ def to_creator (self ) -> "ModelDeploymentAccessTokenCreator" :
95
+ return ModelDeploymentAccessTokenCreator (
96
+ model_deployment_id = UUID (self .model_deployment_id ),
97
+ valid_until = self .valid_until ,
98
+ )
99
+
80
100
81
101
@strawberry .type
82
102
class CreateAccessTokenPayload :
@@ -87,4 +107,9 @@ class CreateAccessTokenPayload:
87
107
async def create_access_token (
88
108
input : CreateAccessTokenInput , info : Info [StrawberryGQLContext ]
89
109
) -> CreateAccessTokenPayload :
90
- return CreateAccessTokenPayload (access_token = mock_access_token_1 )
110
+ deployment_processor = info .context .processors .deployment
111
+ assert deployment_processor is not None
112
+ result = await deployment_processor .create_access_token .wait_for_complete (
113
+ action = CreateAccessTokenAction (input .to_creator ())
114
+ )
115
+ return CreateAccessTokenPayload (access_token = AccessToken .from_dataclass (result .data ))
0 commit comments