|
| 1 | +import uuid |
| 2 | +from typing import Optional, cast |
| 3 | + |
| 4 | +import sqlalchemy as sa |
| 5 | +from sqlalchemy.ext.asyncio import AsyncSession as SASession |
| 6 | +from sqlalchemy.orm import contains_eager |
| 7 | + |
| 8 | +from ...data.permission.id import ObjectId |
| 9 | +from ...data.permission.role import ( |
| 10 | + RoleCreateInput, |
| 11 | + RoleDeleteInput, |
| 12 | + RoleUpdateInput, |
| 13 | + UserRoleAssignmentInput, |
| 14 | +) |
| 15 | +from ...data.permission.status import ( |
| 16 | + RoleStatus, |
| 17 | +) |
| 18 | +from ...errors.common import ObjectNotFound |
| 19 | +from ...models.rbac_models.association_scopes_entities import AssociationScopesEntitiesRow |
| 20 | +from ...models.rbac_models.permission.object_permission import ObjectPermissionRow |
| 21 | +from ...models.rbac_models.permission.permission import PermissionRow |
| 22 | +from ...models.rbac_models.permission.permission_group import PermissionGroupRow |
| 23 | +from ...models.rbac_models.role import RoleRow |
| 24 | +from ...models.rbac_models.user_role import UserRoleRow |
| 25 | +from ...models.utils import ExtendedAsyncSAEngine |
| 26 | + |
| 27 | + |
| 28 | +class PermissionControllerDBSource: |
| 29 | + _db: ExtendedAsyncSAEngine |
| 30 | + |
| 31 | + def __init__(self, db: ExtendedAsyncSAEngine) -> None: |
| 32 | + self._db = db |
| 33 | + |
| 34 | + async def create_role(self, data: RoleCreateInput) -> RoleRow: |
| 35 | + """ |
| 36 | + Create a new role in the database. |
| 37 | +
|
| 38 | + Returns the ID of the created role. |
| 39 | + """ |
| 40 | + async with self._db.begin_session() as db_session: |
| 41 | + role_row = RoleRow.from_input(data) |
| 42 | + db_session.add(role_row) # type: ignore[arg-type] |
| 43 | + await db_session.flush() |
| 44 | + role_id = role_row.id |
| 45 | + for permission_group in data.permission_groups: |
| 46 | + permission_group_row = PermissionGroupRow.from_input( |
| 47 | + permission_group.to_input(role_id) |
| 48 | + ) |
| 49 | + db_session.add(permission_group_row) # type: ignore[arg-type] |
| 50 | + for object_permission in data.object_permissions: |
| 51 | + object_permission_row = ObjectPermissionRow.from_input( |
| 52 | + object_permission.to_input(role_id) |
| 53 | + ) |
| 54 | + db_session.add(object_permission_row) # type: ignore[arg-type] |
| 55 | + await db_session.flush() |
| 56 | + await db_session.refresh(role_row) |
| 57 | + return role_row |
| 58 | + |
| 59 | + async def _get_role(self, role_id: uuid.UUID, db_session: SASession) -> Optional[RoleRow]: |
| 60 | + stmt = sa.select(RoleRow).where(RoleRow.id == role_id) |
| 61 | + role_row = await db_session.scalar(stmt) |
| 62 | + return cast(Optional[RoleRow], role_row) |
| 63 | + |
| 64 | + async def update_role(self, data: RoleUpdateInput) -> RoleRow: |
| 65 | + to_update = data.fields_to_update() |
| 66 | + async with self._db.begin_session() as db_session: |
| 67 | + stmt = sa.update(RoleRow).where(RoleRow.id == data.id).values(**to_update) |
| 68 | + await db_session.execute(stmt) |
| 69 | + role_row = await self._get_role(data.id, db_session) |
| 70 | + if role_row is None: |
| 71 | + raise ObjectNotFound(f"Role with ID {data.id} does not exist.") |
| 72 | + return role_row |
| 73 | + |
| 74 | + async def delete_role(self, data: RoleDeleteInput) -> RoleRow: |
| 75 | + async with self._db.begin_session() as db_session: |
| 76 | + role_row = await self._get_role(data.id, db_session) |
| 77 | + if role_row is None: |
| 78 | + raise ObjectNotFound(f"Role with ID {data.id} does not exist.") |
| 79 | + role_row.status = RoleStatus.DELETED |
| 80 | + await db_session.flush() |
| 81 | + await db_session.refresh(role_row) |
| 82 | + return role_row |
| 83 | + |
| 84 | + async def assign_role(self, data: UserRoleAssignmentInput) -> UserRoleRow: |
| 85 | + async with self._db.begin_session() as db_session: |
| 86 | + user_role_row = UserRoleRow.from_input(data) |
| 87 | + db_session.add(user_role_row) # type: ignore[arg-type] |
| 88 | + await db_session.flush() |
| 89 | + await db_session.refresh(user_role_row) |
| 90 | + return user_role_row |
| 91 | + |
| 92 | + async def get_role(self, role_id: uuid.UUID) -> Optional[RoleRow]: |
| 93 | + async with self._db.begin_readonly_session() as db_session: |
| 94 | + result = await self._get_role(role_id, db_session) |
| 95 | + if result is None: |
| 96 | + return None |
| 97 | + return result |
| 98 | + |
| 99 | + async def get_user_roles(self, user_id: uuid.UUID) -> list[RoleRow]: |
| 100 | + async with self._db.begin_readonly_session() as db_session: |
| 101 | + j = ( |
| 102 | + sa.join( |
| 103 | + RoleRow, |
| 104 | + UserRoleRow, |
| 105 | + RoleRow.id == UserRoleRow.role_id, |
| 106 | + ) |
| 107 | + .join( |
| 108 | + ObjectPermissionRow, |
| 109 | + RoleRow.id == ObjectPermissionRow.role_id, |
| 110 | + ) |
| 111 | + .join( |
| 112 | + PermissionGroupRow, |
| 113 | + RoleRow.id == PermissionGroupRow.role_id, |
| 114 | + ) |
| 115 | + .join( |
| 116 | + PermissionRow, |
| 117 | + PermissionGroupRow.id == PermissionRow.permission_group_id, |
| 118 | + ) |
| 119 | + ) |
| 120 | + stmt = ( |
| 121 | + sa.select(RoleRow) |
| 122 | + .select_from(j) |
| 123 | + .where(UserRoleRow.user_id == user_id) |
| 124 | + .options( |
| 125 | + contains_eager(RoleRow.permission_group_rows).options( |
| 126 | + contains_eager(PermissionGroupRow.permission_rows) |
| 127 | + ), |
| 128 | + contains_eager(RoleRow.object_permission_rows), |
| 129 | + ) |
| 130 | + ) |
| 131 | + |
| 132 | + result = await db_session.scalars(stmt) |
| 133 | + return result.all() |
| 134 | + |
| 135 | + async def get_entity_mapped_scopes( |
| 136 | + self, target_object_id: ObjectId |
| 137 | + ) -> list[AssociationScopesEntitiesRow]: |
| 138 | + async with self._db.begin_readonly_session() as db_session: |
| 139 | + stmt = sa.select(AssociationScopesEntitiesRow.scope_id).where( |
| 140 | + AssociationScopesEntitiesRow.entity_id == target_object_id.entity_id, |
| 141 | + AssociationScopesEntitiesRow.entity_type == target_object_id.entity_type.value, |
| 142 | + ) |
| 143 | + result = await db_session.scalars(stmt) |
| 144 | + return result.all() |
0 commit comments