Skip to content

Commit 3c9040b

Browse files
committed
feat: Add RBAC DB source
1 parent 2cf1a22 commit 3c9040b

File tree

8 files changed

+222
-221
lines changed

8 files changed

+222
-221
lines changed

src/ai/backend/manager/data/permission/object_permission.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,23 @@ class ObjectPermissionCreateInput:
1717
status: PermissionStatus = PermissionStatus.ACTIVE
1818

1919

20+
@dataclass
21+
class ObjectPermissionCreateInputBeforeRoleCreation:
22+
entity_type: EntityType
23+
entity_id: str
24+
operation: OperationType
25+
status: PermissionStatus = PermissionStatus.ACTIVE
26+
27+
def to_input(self, role_id: uuid.UUID) -> ObjectPermissionCreateInput:
28+
return ObjectPermissionCreateInput(
29+
role_id=role_id,
30+
entity_type=self.entity_type,
31+
entity_id=self.entity_id,
32+
operation=self.operation,
33+
status=self.status,
34+
)
35+
36+
2037
@dataclass
2138
class ObjectPermissionUpdater:
2239
id: uuid.UUID

src/ai/backend/manager/data/permission/permission_group.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ class PermissionGroupCreator:
1010
scope_id: ScopeId
1111

1212

13+
@dataclass
14+
class PermissionGroupCreatorBeforeRoleCreation:
15+
scope_id: ScopeId
16+
17+
def to_input(self, role_id: uuid.UUID) -> PermissionGroupCreator:
18+
return PermissionGroupCreator(
19+
role_id=role_id,
20+
scope_id=self.scope_id,
21+
)
22+
23+
1324
@dataclass
1425
class PermissionGroupData:
1526
id: uuid.UUID

src/ai/backend/manager/data/permission/role.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55

66
from ai.backend.manager.types import OptionalState, PartialModifier, TriState
77

8-
from .object_permission import ObjectPermissionData
9-
from .scope_permission import ScopePermissionData
8+
from .object_permission import (
9+
ObjectPermissionCreateInputBeforeRoleCreation,
10+
ObjectPermissionData,
11+
)
12+
from .permission_group import PermissionGroupCreatorBeforeRoleCreation
1013
from .status import RoleStatus
1114
from .types import EntityType, RoleSource
1215

@@ -18,8 +21,10 @@ class RoleCreateInput:
1821
status: RoleStatus = RoleStatus.ACTIVE
1922
description: Optional[str] = None
2023

21-
scope_permissions: list[ScopePermissionData] = field(default_factory=list)
22-
object_permissions: list[ObjectPermissionData] = field(default_factory=list)
24+
permission_groups: list[PermissionGroupCreatorBeforeRoleCreation] = field(default_factory=list)
25+
object_permissions: list[ObjectPermissionCreateInputBeforeRoleCreation] = field(
26+
default_factory=list
27+
)
2328

2429

2530
@dataclass

src/ai/backend/manager/data/permission/scope_permission.py

Lines changed: 0 additions & 39 deletions
This file was deleted.
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import uuid
2+
from typing import Optional, cast
3+
4+
import sqlalchemy as sa
5+
from sqlalchemy.ext.asyncio import AsyncSession as SASession
6+
from sqlalchemy.orm import contains_eager
7+
8+
from ...data.permission.id import ObjectId
9+
from ...data.permission.role import (
10+
RoleCreateInput,
11+
RoleDeleteInput,
12+
RoleUpdateInput,
13+
UserRoleAssignmentInput,
14+
)
15+
from ...data.permission.status import (
16+
RoleStatus,
17+
)
18+
from ...errors.common import ObjectNotFound
19+
from ...models.rbac_models.association_scopes_entities import AssociationScopesEntitiesRow
20+
from ...models.rbac_models.permission.object_permission import ObjectPermissionRow
21+
from ...models.rbac_models.permission.permission import PermissionRow
22+
from ...models.rbac_models.permission.permission_group import PermissionGroupRow
23+
from ...models.rbac_models.role import RoleRow
24+
from ...models.rbac_models.user_role import UserRoleRow
25+
from ...models.utils import ExtendedAsyncSAEngine
26+
27+
28+
class PermissionControllerDBSource:
29+
_db: ExtendedAsyncSAEngine
30+
31+
def __init__(self, db: ExtendedAsyncSAEngine) -> None:
32+
self._db = db
33+
34+
async def create_role(self, data: RoleCreateInput) -> RoleRow:
35+
"""
36+
Create a new role in the database.
37+
38+
Returns the ID of the created role.
39+
"""
40+
async with self._db.begin_session() as db_session:
41+
role_row = RoleRow.from_input(data)
42+
db_session.add(role_row) # type: ignore[arg-type]
43+
await db_session.flush()
44+
role_id = role_row.id
45+
for permission_group in data.permission_groups:
46+
permission_group_row = PermissionGroupRow.from_input(
47+
permission_group.to_input(role_id)
48+
)
49+
db_session.add(permission_group_row) # type: ignore[arg-type]
50+
for object_permission in data.object_permissions:
51+
object_permission_row = ObjectPermissionRow.from_input(
52+
object_permission.to_input(role_id)
53+
)
54+
db_session.add(object_permission_row) # type: ignore[arg-type]
55+
await db_session.flush()
56+
await db_session.refresh(role_row)
57+
return role_row
58+
59+
async def _get_role(self, role_id: uuid.UUID, db_session: SASession) -> Optional[RoleRow]:
60+
stmt = sa.select(RoleRow).where(RoleRow.id == role_id)
61+
role_row = await db_session.scalar(stmt)
62+
return cast(Optional[RoleRow], role_row)
63+
64+
async def update_role(self, data: RoleUpdateInput) -> RoleRow:
65+
to_update = data.fields_to_update()
66+
async with self._db.begin_session() as db_session:
67+
stmt = sa.update(RoleRow).where(RoleRow.id == data.id).values(**to_update)
68+
await db_session.execute(stmt)
69+
role_row = await self._get_role(data.id, db_session)
70+
if role_row is None:
71+
raise ObjectNotFound(f"Role with ID {data.id} does not exist.")
72+
return role_row
73+
74+
async def delete_role(self, data: RoleDeleteInput) -> RoleRow:
75+
async with self._db.begin_session() as db_session:
76+
role_row = await self._get_role(data.id, db_session)
77+
if role_row is None:
78+
raise ObjectNotFound(f"Role with ID {data.id} does not exist.")
79+
role_row.status = RoleStatus.DELETED
80+
await db_session.flush()
81+
await db_session.refresh(role_row)
82+
return role_row
83+
84+
async def assign_role(self, data: UserRoleAssignmentInput) -> UserRoleRow:
85+
async with self._db.begin_session() as db_session:
86+
user_role_row = UserRoleRow.from_input(data)
87+
db_session.add(user_role_row) # type: ignore[arg-type]
88+
await db_session.flush()
89+
await db_session.refresh(user_role_row)
90+
return user_role_row
91+
92+
async def get_role(self, role_id: uuid.UUID) -> Optional[RoleRow]:
93+
async with self._db.begin_readonly_session() as db_session:
94+
result = await self._get_role(role_id, db_session)
95+
if result is None:
96+
return None
97+
return result
98+
99+
async def get_user_roles(self, user_id: uuid.UUID) -> list[RoleRow]:
100+
async with self._db.begin_readonly_session() as db_session:
101+
j = (
102+
sa.join(
103+
RoleRow,
104+
UserRoleRow,
105+
RoleRow.id == UserRoleRow.role_id,
106+
)
107+
.join(
108+
ObjectPermissionRow,
109+
RoleRow.id == ObjectPermissionRow.role_id,
110+
)
111+
.join(
112+
PermissionGroupRow,
113+
RoleRow.id == PermissionGroupRow.role_id,
114+
)
115+
.join(
116+
PermissionRow,
117+
PermissionGroupRow.id == PermissionRow.permission_group_id,
118+
)
119+
)
120+
stmt = (
121+
sa.select(RoleRow)
122+
.select_from(j)
123+
.where(UserRoleRow.user_id == user_id)
124+
.options(
125+
contains_eager(RoleRow.permission_group_rows).options(
126+
contains_eager(PermissionGroupRow.permission_rows)
127+
),
128+
contains_eager(RoleRow.object_permission_rows),
129+
)
130+
)
131+
132+
result = await db_session.scalars(stmt)
133+
return result.all()
134+
135+
async def get_entity_mapped_scopes(
136+
self, target_object_id: ObjectId
137+
) -> list[AssociationScopesEntitiesRow]:
138+
async with self._db.begin_readonly_session() as db_session:
139+
stmt = sa.select(AssociationScopesEntitiesRow.scope_id).where(
140+
AssociationScopesEntitiesRow.entity_id == target_object_id.entity_id,
141+
AssociationScopesEntitiesRow.entity_type == target_object_id.entity_type.value,
142+
)
143+
result = await db_session.scalars(stmt)
144+
return result.all()

0 commit comments

Comments
 (0)