Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support hybrid_property, column_property, declared_attr #801

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 7 additions & 10 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None:
Representation as Representation,
)

class SQLModelConfig(BaseConfig): # type: ignore[no-redef]
class SQLModelConfig(ConfigDict): # type: ignore[no-redef]
table: Optional[bool] = None # type: ignore[misc]
registry: Optional[Any] = None # type: ignore[misc]

Expand All @@ -396,12 +396,12 @@ def set_config_value(
setattr(model.__config__, parameter, value) # type: ignore

def get_model_fields(model: InstanceOrType[BaseModel]) -> Dict[str, "FieldInfo"]:
return model.__fields__ # type: ignore
return model.model_fields

def get_fields_set(
object: InstanceOrType["SQLModel"],
) -> Union[Set[str], Callable[[BaseModel], Set[str]]]:
return object.__fields_set__
) -> Union[Set[str], property]:
return object.model_fields_set

def init_pydantic_private_attrs(new_object: InstanceOrType["SQLModel"]) -> None:
object.__setattr__(new_object, "__fields_set__", set())
Expand Down Expand Up @@ -472,7 +472,7 @@ def _calculate_keys(
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
return (
self.__fields__.keys() # noqa
self.model_fields.keys() # noqa
) # | self.__sqlmodel_relationships__.keys()

keys: AbstractSet[str]
Expand All @@ -485,7 +485,7 @@ def _calculate_keys(
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
keys = (
self.__fields__.keys() # noqa
self.model_fields.keys() # noqa
) # | self.__sqlmodel_relationships__.keys()
if include is not None:
keys &= include.keys()
Expand Down Expand Up @@ -547,10 +547,7 @@ def sqlmodel_validate(
def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None:
values, fields_set, validation_error = validate_model(self.__class__, data)
# Only raise errors if not a SQLModel model
if (
not is_table_model_class(self.__class__) # noqa
and validation_error
):
if not is_table_model_class(self.__class__) and validation_error: # noqa
raise validation_error
if not is_table_model_class(self.__class__):
object.__setattr__(self, "__dict__", values)
Expand Down
39 changes: 37 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import (
TYPE_CHECKING,
AbstractSet,
Annotated,
Any,
Callable,
ClassVar,
Expand All @@ -22,6 +23,7 @@
TypeVar,
Union,
cast,
get_args,
overload,
)

Expand All @@ -37,10 +39,13 @@
Integer,
Interval,
Numeric,
Table,
inspect,
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.ext.hybrid import hybrid_method, hybrid_property
from sqlalchemy.orm import (
ColumnProperty,
Mapped,
RelationshipProperty,
declared_attr,
Expand Down Expand Up @@ -91,6 +96,9 @@
_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any], None]
SQLAlchemyConstruct = Union[
hybrid_property, hybrid_method, ColumnProperty, declared_attr
]


def __dataclass_transform__(
Expand Down Expand Up @@ -396,10 +404,12 @@ def Relationship(
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
__sqlalchemy_constructs__: Dict[str, SQLAlchemyConstruct]
model_config: SQLModelConfig
model_fields: Dict[str, FieldInfo]
__config__: Type[SQLModelConfig]
__fields__: Dict[str, ModelField] # type: ignore[assignment]
__table__: Table

# Replicate SQLAlchemy
def __setattr__(cls, name: str, value: Any) -> None:
Expand All @@ -423,13 +433,18 @@ def __new__(
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
sqlalchemy_constructs: Dict[str, SQLAlchemyConstruct] = {}
dict_for_pydantic = {}
original_annotations = get_annotations(class_dict)
pydantic_annotations = {}
relationship_annotations = {}
for k, v in class_dict.items():
if isinstance(v, RelationshipInfo):
relationships[k] = v
elif isinstance(
v, (hybrid_property, hybrid_method, ColumnProperty, declared_attr)
):
sqlalchemy_constructs[k] = v
else:
dict_for_pydantic[k] = v
for k, v in original_annotations.items():
Expand All @@ -442,6 +457,7 @@ def __new__(
"__weakref__": None,
"__sqlmodel_relationships__": relationships,
"__annotations__": pydantic_annotations,
"__sqlalchemy_constructs__": sqlalchemy_constructs,
}
# Duplicate logic from Pydantic to filter config kwargs because if they are
# passed directly including the registry Pydantic will pass them over to the
Expand All @@ -463,6 +479,11 @@ def __new__(
**new_cls.__annotations__,
}

# We did not provide the sqlalchemy constructs to Pydantic's new function above
# so that they wouldn't be modified. Instead we set them directly to the class below:
for k, v in sqlalchemy_constructs.items():
setattr(new_cls, k, v)

def get_config(name: str) -> Any:
config_class_value = get_config_value(
model=new_cls, parameter=name, default=Undefined
Expand All @@ -479,6 +500,8 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
if k in sqlalchemy_constructs:
continue
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
Expand All @@ -501,6 +524,9 @@ def get_config(name: str) -> Any:
setattr(new_cls, "_sa_registry", config_registry) # noqa: B010
setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010
setattr(new_cls, "__abstract__", True) # noqa: B010
setattr(new_cls, "__pydantic_private__", {}) # noqa: B010
setattr(new_cls, "__pydantic_extra__", {}) # noqa: B010

return new_cls

# Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
Expand All @@ -514,6 +540,9 @@ def __init__(
base_is_table = any(is_table_model_class(base) for base in bases)
if is_table_model_class(cls) and not base_is_table:
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_name in cls.__sqlalchemy_constructs__:
# Skip hybrid properties
continue
50Bytes-dev marked this conversation as resolved.
Show resolved Hide resolved
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
# over anything else, use that and continue with the next attribute
Expand Down Expand Up @@ -559,6 +588,10 @@ def __init__(
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)


def is_annotated_type(type_: Any) -> bool:
return get_origin(type_) is Annotated


def get_sqlalchemy_type(field: Any) -> Any:
if IS_PYDANTIC_V2:
field_info = field
Expand All @@ -572,6 +605,8 @@ def get_sqlalchemy_type(field: Any) -> Any:
metadata = get_field_metadata(field)

# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
if is_annotated_type(type_):
type_ = get_args(type_)[0]
if issubclass(type_, Enum):
return sa_Enum(type_)
if issubclass(type_, str):
Expand Down Expand Up @@ -686,7 +721,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
__allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six

if IS_PYDANTIC_V2:
model_config = SQLModelConfig(from_attributes=True)
model_config = SQLModelConfig(from_attributes=True, use_enum_values=True)
else:

class Config:
Expand Down Expand Up @@ -778,7 +813,7 @@ def model_dump(
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: Union[bool, Literal["none", "warn", "error"]] = True,
warnings: bool = True,
serialize_as_any: bool = False,
) -> Dict[str, Any]:
if PYDANTIC_VERSION >= "2.7.0":
Expand Down
80 changes: 80 additions & 0 deletions tests/test_column_property.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import List, Optional

from sqlalchemy import case, create_engine, func
from sqlalchemy.orm import column_property, declared_attr
from sqlmodel import Field, Relationship, Session, SQLModel, select


def test_query(clear_sqlmodel):
class Item(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
value: float
hero_id: int = Field(foreign_key="hero.id")
hero: "Hero" = Relationship(back_populates="items")

class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
items: List[Item] = Relationship(back_populates="hero")

@declared_attr
def total_items(cls):
return column_property(cls._total_items_expression())

@classmethod
def _total_items_expression(cls):
return (
select(func.coalesce(func.sum(Item.value), 0))
.where(Item.hero_id == cls.id)
.correlate_except(Item)
.label("total_items")
)

@declared_attr
def status(cls):
return column_property(
select(
case(
(cls._total_items_expression() > 0, "active"), else_="inactive"
)
).scalar_subquery()
)

hero_1 = Hero(name="Deadpond")
hero_2 = Hero(name="Spiderman")

engine = create_engine("sqlite://")

SQLModel.metadata.create_all(engine)
with Session(engine) as session:
session.add(hero_1)
session.add(hero_2)
session.commit()
session.refresh(hero_1)
session.refresh(hero_2)

item_1 = Item(value=1.0, hero_id=hero_1.id)
item_2 = Item(value=2.0, hero_id=hero_1.id)

with Session(engine) as session:
session.add(item_1)
session.add(item_2)
session.commit()
session.refresh(item_1)
session.refresh(item_2)

with Session(engine) as session:
hero_statement = select(Hero).where(Hero.total_items > 0.0)
hero = session.exec(hero_statement).first()
assert hero.name == "Deadpond"
assert hero.total_items == 3.0
assert hero.status == "active"

with Session(engine) as session:
hero_statement = select(Hero).where(
Hero.status == "inactive",
)
hero = session.exec(hero_statement).first()
assert hero.name == "Spiderman"
assert hero.total_items == 0.0
assert hero.status == "inactive"
13 changes: 13 additions & 0 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,19 @@ def sqlite_dump(sql: TypeEngine, *args, **kwargs):
sqlite_engine = create_mock_engine("sqlite://", sqlite_dump)


def _reset_metadata():
SQLModel.metadata.clear()

class FlatModel(SQLModel, table=True):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum1

class InheritModel(BaseModel, table=True):
pass


def test_postgres_ddl_sql(capsys):
_reset_metadata()
SQLModel.metadata.create_all(bind=postgres_engine, checkfirst=False)

captured = capsys.readouterr()
Expand All @@ -67,6 +79,7 @@ def test_postgres_ddl_sql(capsys):


def test_sqlite_ddl_sql(capsys):
_reset_metadata()
SQLModel.metadata.create_all(bind=sqlite_engine, checkfirst=False)

captured = capsys.readouterr()
Expand Down