Skip to content

Commit a7ec9ff

Browse files
committed
✨ Add support for hybrid_property
1 parent 75ce455 commit a7ec9ff

File tree

3 files changed

+63
-3
lines changed

3 files changed

+63
-3
lines changed

sqlmodel/main.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sqlalchemy import Boolean, Column, Date, DateTime
3535
from sqlalchemy import Enum as sa_Enum
3636
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
37+
from sqlalchemy.ext.hybrid import hybrid_property
3738
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
3839
from sqlalchemy.orm.attributes import set_attribute
3940
from sqlalchemy.orm.decl_api import DeclarativeMeta
@@ -207,6 +208,7 @@ def Relationship(
207208
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
208209
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
209210
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
211+
__sqlalchemy_constructs__: Dict[str, Any]
210212
__config__: Type[BaseConfig]
211213
__fields__: Dict[str, ModelField]
212214

@@ -232,6 +234,7 @@ def __new__(
232234
**kwargs: Any,
233235
) -> Any:
234236
relationships: Dict[str, RelationshipInfo] = {}
237+
sqlalchemy_constructs = {}
235238
dict_for_pydantic = {}
236239
original_annotations = resolve_annotations(
237240
class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
@@ -241,6 +244,8 @@ def __new__(
241244
for k, v in class_dict.items():
242245
if isinstance(v, RelationshipInfo):
243246
relationships[k] = v
247+
elif isinstance(v, hybrid_property):
248+
sqlalchemy_constructs[k] = v
244249
else:
245250
dict_for_pydantic[k] = v
246251
for k, v in original_annotations.items():
@@ -253,6 +258,7 @@ def __new__(
253258
"__weakref__": None,
254259
"__sqlmodel_relationships__": relationships,
255260
"__annotations__": pydantic_annotations,
261+
"__sqlalchemy_constructs__": sqlalchemy_constructs
256262
}
257263
# Duplicate logic from Pydantic to filter config kwargs because if they are
258264
# passed directly including the registry Pydantic will pass them over to the
@@ -276,6 +282,9 @@ def __new__(
276282
**new_cls.__annotations__,
277283
}
278284

285+
for k, v in sqlalchemy_constructs.items():
286+
setattr(new_cls, k, v)
287+
279288
def get_config(name: str) -> Any:
280289
config_class_value = getattr(new_cls.__config__, name, Undefined)
281290
if config_class_value is not Undefined:
@@ -290,8 +299,9 @@ def get_config(name: str) -> Any:
290299
# If it was passed by kwargs, ensure it's also set in config
291300
new_cls.__config__.table = config_table
292301
for k, v in new_cls.__fields__.items():
293-
col = get_column_from_field(v)
294-
setattr(new_cls, k, col)
302+
if k in sqlalchemy_constructs:
303+
continue
304+
setattr(new_cls, k, get_column_from_field(v))
295305
# Set a config flag to tell FastAPI that this should be read with a field
296306
# in orm_mode instead of preemptively converting it to a dict.
297307
# This could be done by reading new_cls.__config__.table in FastAPI, but
@@ -326,6 +336,8 @@ def __init__(
326336
if getattr(cls.__config__, "table", False) and not base_is_table:
327337
dict_used = dict_.copy()
328338
for field_name, field_value in cls.__fields__.items():
339+
if field_name in cls.__sqlalchemy_constructs__:
340+
continue
329341
dict_used[field_name] = get_column_from_field(field_value)
330342
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
331343
if rel_info.sa_relationship:

tests/conftest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77
from pydantic import BaseModel
8-
from sqlmodel import SQLModel
8+
from sqlmodel import SQLModel, create_engine
99
from sqlmodel.main import default_registry
1010

1111
top_level_path = Path(__file__).resolve().parent.parent
@@ -23,6 +23,13 @@ def clear_sqlmodel():
2323
default_registry.dispose()
2424

2525

26+
@pytest.fixture()
27+
def in_memory_engine(clear_sqlmodel):
28+
engine = create_engine("sqlite:///memory")
29+
yield engine
30+
SQLModel.metadata.drop_all(engine, checkfirst=True)
31+
32+
2633
@pytest.fixture()
2734
def cov_tmp_path(tmp_path: Path):
2835
yield tmp_path

tests/test_sqlalchemy_properties.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import Optional
2+
3+
from sqlalchemy import func
4+
from sqlalchemy.ext.hybrid import hybrid_property
5+
from sqlmodel import Field, Session, SQLModel, select
6+
7+
8+
def test_hybrid_property(in_memory_engine):
9+
class Interval(SQLModel, table=True):
10+
id: Optional[int] = Field(default=None, primary_key=True)
11+
length: float
12+
13+
@hybrid_property
14+
def radius(self) -> float:
15+
return abs(self.length) / 2
16+
17+
@radius.expression
18+
def radius(cls) -> float:
19+
return func.abs(cls.length) / 2
20+
21+
class Config:
22+
arbitrary_types_allowed = True
23+
24+
SQLModel.metadata.create_all(in_memory_engine)
25+
session = Session(in_memory_engine)
26+
27+
interval = Interval(length=-2)
28+
assert interval.radius == 1
29+
30+
session.add(interval)
31+
session.commit()
32+
interval_2 = session.exec(select(Interval)).all()[0]
33+
assert interval_2.radius == 1
34+
35+
interval_3 = session.exec(select(Interval).where(Interval.radius == 1)).all()[0]
36+
assert interval_3.radius == 1
37+
38+
intervals = session.exec(select(Interval).where(Interval.radius > 1)).all()
39+
assert len(intervals) == 0
40+
41+
assert session.exec(select(Interval.radius + 1)).all()[0] == 2.0

0 commit comments

Comments
 (0)