Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get primary key fields from mapper #508

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions starlette_admin/contrib/sqla/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,14 @@ def __init__(
self.fields = (converter or ModelConverter()).convert_fields_list(
fields=self.fields, model=self.model, mapper=mapper
)
self._setup_primary_key()
self._setup_primary_key(mapper)
self.exclude_fields_from_list = normalize_list(self.exclude_fields_from_list) # type: ignore
self.exclude_fields_from_detail = normalize_list(self.exclude_fields_from_detail) # type: ignore
self.exclude_fields_from_create = normalize_list(self.exclude_fields_from_create) # type: ignore
self.exclude_fields_from_detail = normalize_list(
self.exclude_fields_from_detail
) # type: ignore
self.exclude_fields_from_create = normalize_list(
self.exclude_fields_from_create
) # type: ignore
self.exclude_fields_from_edit = normalize_list(self.exclude_fields_from_edit) # type: ignore
_default_list = [
field.name
Expand All @@ -136,19 +140,20 @@ def __init__(
)
super().__init__()

def _setup_primary_key(self) -> None:
def _setup_primary_key(self, mapper: Mapper) -> None:
# Detect the primary key attribute(s) of the model
_pk_attrs = []

self._pk_column: Union[
Tuple[InstrumentedAttribute, ...], InstrumentedAttribute
] = ()
self._pk_coerce: Union[Tuple[type, ...], type] = ()
for key in self.model.__dict__:
attr = getattr(self.model, key)
if isinstance(attr, InstrumentedAttribute) and getattr(
attr, "primary_key", False
):
_pk_attrs.append(key)

# mapper._primary_key_propkeys but then ordered by occurrence in the model
pks_by_table = [mapper._pks_by_table[table] for table in mapper.tables] # type: ignore[attr-defined]
_pk_attrs: List[str] = []
for table in pks_by_table:
_pk_attrs += [mapper._columntoproperty[c].key for c in table] # type: ignore[attr-defined]

if len(_pk_attrs) > 1:
self._pk_column = tuple(getattr(self.model, attr) for attr in _pk_attrs)
self._pk_coerce = tuple(
Expand Down Expand Up @@ -301,7 +306,8 @@ async def find_all(
if isinstance(session, AsyncSession):
return (await session.execute(stmt)).scalars().unique().all()
return (
(await anyio.to_thread.run_sync(session.execute, stmt)) # type: ignore[arg-type]
(await anyio.to_thread.run_sync(session.execute, stmt))
# type: ignore[arg-type]
.scalars()
.unique()
.all()
Expand All @@ -327,7 +333,9 @@ async def find_by_pk(self, request: Request, pk: Any) -> Any:
== (_pk == "True") # to avoid bool("False") which is True
)
for _pk_col, _coerce, _pk in zip(
self._pk_column, self._pk_coerce, iterdecode(pk) # type: ignore[type-var,arg-type]
self._pk_column,
self._pk_coerce,
iterdecode(pk), # type: ignore[type-var,arg-type]
)
)
else:
Expand All @@ -340,7 +348,8 @@ async def find_by_pk(self, request: Request, pk: Any) -> Any:
if isinstance(session, AsyncSession):
return (await session.execute(stmt)).scalars().unique().one_or_none()
return (
(await anyio.to_thread.run_sync(session.execute, stmt)) # type: ignore[arg-type]
(await anyio.to_thread.run_sync(session.execute, stmt))
# type: ignore[arg-type]
.scalars()
.unique()
.one_or_none()
Expand Down Expand Up @@ -376,7 +385,8 @@ async def _exec_find_by_pks(
if isinstance(session, AsyncSession):
return (await session.execute(stmt)).scalars().unique().all()
return (
(await anyio.to_thread.run_sync(session.execute, stmt)) # type: ignore[arg-type]
(await anyio.to_thread.run_sync(session.execute, stmt))
# type: ignore[arg-type]
.scalars()
.unique()
.all()
Expand Down Expand Up @@ -411,7 +421,8 @@ async def _get_multiple_pks_in_clause(
tuple(
(_coerce(_pk) if _coerce is not bool else _pk == "True")
for _coerce, _pk in zip(
self._pk_coerce, decoded_pk # type: ignore[type-var,arg-type]
self._pk_coerce,
decoded_pk, # type: ignore[type-var,arg-type]
)
)
for decoded_pk in decoded_pks
Expand All @@ -427,7 +438,9 @@ async def _get_multiple_pks_in_clause(
else (_pk_col == (_pk == "True"))
) # to avoid bool("False") which is True
for _pk_col, _coerce, _pk in zip(
self._pk_column, self._pk_coerce, decoded_pk # type: ignore[type-var,arg-type]
self._pk_column,
self._pk_coerce,
decoded_pk, # type: ignore[type-var,arg-type]
)
)
)
Expand Down