16
16
17
17
@dataclass_transform (kw_only_default = True , field_specifiers = (PydanticField ,))
18
18
class DBModelMetaclass (_model_construction .ModelMetaclass ):
19
+ """
20
+ Metaclass for database model classes that provides automatic field tracking and SQL query generation.
21
+ Extends Pydantic's model metaclass to add database-specific functionality.
22
+
23
+ This metaclass provides:
24
+ - Automatic field to SQL column mapping
25
+ - Dynamic field access that returns query-compatible field definitions
26
+ - Registry tracking of all database model classes
27
+ - Support for generic model instantiation
28
+
29
+ ```python {{sticky: True}}
30
+ class User(TableBase): # Uses DBModelMetaclass
31
+ id: int = Field(primary_key=True)
32
+ name: str
33
+ email: str | None
34
+
35
+ # Fields can be accessed for queries
36
+ User.id # Returns DBFieldClassDefinition
37
+ User.name # Returns DBFieldClassDefinition
38
+
39
+ # Metaclass handles model registration
40
+ registered_models = DBModelMetaclass.get_registry()
41
+ ```
42
+ """
43
+
19
44
_registry : list [Type ["TableBase" ]] = []
20
- # {class: kwargs}
21
45
_cached_args : dict [Type ["TableBase" ], dict [str , Any ]] = {}
46
+ is_constructing : bool = False
22
47
23
- def __new__ (
24
- mcs , name : str , bases : tuple , namespace : dict [str , Any ], ** kwargs : Any
25
- ) -> type :
48
+ def __new__ (mcs , name , bases , namespace , ** kwargs ):
49
+ """
50
+ Create a new database model class with proper field tracking.
51
+ Handles registration of the model and processes any table-specific arguments.
52
+ """
26
53
raw_kwargs = {** kwargs }
27
54
28
55
mcs .is_constructing = True
@@ -59,8 +86,15 @@ def __new__(
59
86
return cls
60
87
61
88
def __getattr__ (self , key : str ) -> Any :
62
- # Inspired by the approach in our render logic
63
- # https://github.com/piercefreeman/mountaineer/blob/fdda3a58c0fafebb43a58b4f3d410dbf44302fd6/mountaineer/render.py#L252
89
+ """
90
+ Provides dynamic access to model fields as query-compatible definitions.
91
+ When accessing an undefined attribute, checks if it's a model field and returns
92
+ a DBFieldClassDefinition if it is.
93
+
94
+ :param key: The attribute name to access
95
+ :return: Field definition or raises AttributeError
96
+ :raises AttributeError: If the attribute doesn't exist and isn't a model field
97
+ """
64
98
if self .is_constructing :
65
99
return super ().__getattr__ (key ) # type: ignore
66
100
@@ -78,16 +112,26 @@ def __getattr__(self, key: str) -> Any:
78
112
raise
79
113
80
114
@classmethod
81
- def get_registry (cls ):
115
+ def get_registry (cls ) -> list [Type ["TableBase" ]]:
116
+ """
117
+ Get the set of all registered database model classes.
118
+
119
+ :return: Set of registered TableBase classes
120
+ """
82
121
return cls ._registry
83
122
84
123
@classmethod
85
- def _extract_kwarg (cls , kwargs : dict [str , Any ], key : str , default : Any = None ):
124
+ def _extract_kwarg (
125
+ cls , kwargs : dict [str , Any ], key : str , default : Any = None
126
+ ) -> Any :
86
127
"""
87
- Kwarg extraction that supports standard instantiation and pydantic's approach
88
- for Generic models where it hydrates a fully new class in memory with the type
89
- annotations set to generic values.
128
+ Extract a keyword argument from either standard kwargs or pydantic generic metadata.
129
+ Handles both normal instantiation and pydantic's generic model instantiation.
90
130
131
+ :param kwargs: Dictionary of keyword arguments
132
+ :param key: Key to extract
133
+ :param default: Default value if key not found
134
+ :return: Extracted value or default
91
135
"""
92
136
if key in kwargs :
93
137
return kwargs .pop (key )
@@ -101,69 +145,186 @@ def _extract_kwarg(cls, kwargs: dict[str, Any], key: str, default: Any = None):
101
145
102
146
@property
103
147
def model_fields (self ) -> dict [str , DBFieldInfo ]: # type: ignore
104
- # model_fields must be reimplemented in our custom metaclass, otherwise
105
- # clients will get the super typehinting signature when they try
106
- # to access Model.model_fields. This overrides the ClassVar typehint
107
- # that's placed in the TableBase itself.
148
+ """
149
+ Get the dictionary of model fields and their definitions.
150
+ Overrides the ClassVar typehint from TableBase for proper typing.
151
+
152
+ :return: Dictionary of field names to field definitions
153
+ """
108
154
return super ().model_fields # type: ignore
109
155
110
156
111
157
class UniqueConstraint (BaseModel ):
158
+ """
159
+ Represents a UNIQUE constraint in a database table.
160
+ Ensures that the specified combination of columns contains unique values across all rows.
161
+
162
+ ```python {{sticky: True}}
163
+ class User(TableBase):
164
+ email: str
165
+ tenant_id: int
166
+
167
+ table_args = [
168
+ UniqueConstraint(columns=["email", "tenant_id"])
169
+ ]
170
+ ```
171
+ """
172
+
112
173
columns : list [str ]
174
+ """
175
+ List of column names that should have unique values
176
+ """
113
177
114
178
115
179
class IndexConstraint (BaseModel ):
180
+ """
181
+ Represents an INDEX on one or more columns in a database table.
182
+ Improves query performance for the specified columns.
183
+
184
+ ```python {{sticky: True}}
185
+ class User(TableBase):
186
+ email: str
187
+ last_login: datetime
188
+
189
+ table_args = [
190
+ IndexConstraint(columns=["last_login"])
191
+ ]
192
+ ```
193
+ """
194
+
116
195
columns : list [str ]
196
+ """
197
+ List of column names to create an index on
198
+ """
117
199
118
200
119
201
INTERNAL_TABLE_FIELDS = ["modified_attrs" ]
120
202
121
203
122
204
class TableBase (BaseModel , metaclass = DBModelMetaclass ):
205
+ """
206
+ Base class for all database table models.
207
+ Provides the foundation for defining database tables using Python classes with
208
+ type hints and field definitions.
209
+
210
+ Features:
211
+ - Automatic table name generation from class name
212
+ - Support for custom table names
213
+ - Tracking of modified fields for efficient updates
214
+ - Support for unique constraints and indexes
215
+ - Integration with Pydantic for validation
216
+
217
+ ```python {{sticky: True}}
218
+ class User(TableBase):
219
+ # Custom table name (optional)
220
+ table_name = "users"
221
+
222
+ # Fields with types and constraints
223
+ id: int = Field(primary_key=True)
224
+ email: str = Field(unique=True)
225
+ name: str
226
+ is_active: bool = Field(default=True)
227
+
228
+ # Table-level constraints
229
+ table_args = [
230
+ UniqueConstraint(columns=["email"]),
231
+ IndexConstraint(columns=["name"])
232
+ ]
233
+
234
+ # Usage in queries
235
+ query = select(User).where(User.is_active == True)
236
+ users = await conn.execute(query)
237
+ ```
238
+ """
239
+
123
240
if TYPE_CHECKING :
124
241
model_fields : ClassVar [dict [str , DBFieldInfo ]] # type: ignore
125
242
126
243
table_name : ClassVar [str ] = PydanticUndefined # type: ignore
244
+ """
245
+ Optional custom name for the table
246
+ """
247
+
127
248
table_args : ClassVar [list [UniqueConstraint | IndexConstraint ]] = PydanticUndefined # type: ignore
249
+ """
250
+ Table constraints and indexes
251
+ """
128
252
129
253
# Private methods
130
254
modified_attrs : dict [str , Any ] = Field (default_factory = dict , exclude = True )
255
+ """
256
+ Dictionary of modified field values since instantiation or the last clear_modified_attributes() call.
257
+ Used to construct differential update queries.
258
+ """
259
+
260
+ def __setattr__ (self , name : str , value : Any ) -> None :
261
+ """
262
+ Track modified attributes when fields are updated.
263
+ This allows for efficient database updates by only updating changed fields.
131
264
132
- def __setattr__ (self , name , value ):
265
+ :param name: Attribute name
266
+ :param value: New value
267
+ """
133
268
if name in self .model_fields :
134
269
self .modified_attrs [name ] = value
135
270
super ().__setattr__ (name , value )
136
271
137
272
def get_modified_attributes (self ) -> dict [str , Any ]:
273
+ """
274
+ Get the dictionary of attributes that have been modified since instantiation
275
+ or the last clear_modified_attributes() call.
276
+
277
+ :return: Dictionary of modified attribute names and their values
278
+ """
138
279
return self .modified_attrs
139
280
140
- def clear_modified_attributes (self ):
281
+ def clear_modified_attributes (self ) -> None :
282
+ """
283
+ Clear the tracking of modified attributes.
284
+ Typically called after successfully saving changes to the database.
285
+ """
141
286
self .modified_attrs .clear ()
142
287
143
288
@classmethod
144
- def get_table_name (cls ):
289
+ def get_table_name (cls ) -> str :
290
+ """
291
+ Get the table name for this model.
292
+ Uses the custom table_name if set, otherwise converts the class name to lowercase.
293
+
294
+ :return: Table name to use in SQL queries
295
+ """
145
296
if cls .table_name == PydanticUndefined :
146
297
return cls .__name__ .lower ()
147
298
return cls .table_name
148
299
149
300
@classmethod
150
- def get_client_fields (cls ):
301
+ def get_client_fields (cls ) -> dict [str , DBFieldInfo ]:
302
+ """
303
+ Get all fields that should be exposed to clients.
304
+ Excludes internal fields used for model functionality.
305
+
306
+ :return: Dictionary of field names to field definitions
307
+ """
151
308
return {
152
309
field : info
153
310
for field , info in cls .model_fields .items ()
154
311
if field not in INTERNAL_TABLE_FIELDS
155
312
}
156
313
157
314
@classmethod
158
- def select_fields (cls ):
315
+ def select_fields (cls ) -> QueryLiteral :
159
316
"""
160
- Returns a query selectable string that can be used to select all fields
161
- from this model. This is the format that needs to be passed to our parser
162
- to serialize the raw postgres field values as TableBase objects .
317
+ Generate a SQL-safe string for selecting all fields from this table.
318
+ The output format is "{table_name}.{field_name} as {table_name}_{field_name}"
319
+ for each field, which ensures proper field name resolution in complex queries .
163
320
164
- The exact format is formatted as:
165
- "{table_name}.{field_name} as {table_name}_{field_name}".
321
+ :return: SQL-safe field selection string
166
322
323
+ ```python {{sticky: True}}
324
+ # For a User class with fields 'id' and 'name':
325
+ User.select_fields()
326
+ # Returns: '"users"."id" as "users_id", "users"."name" as "users_name"'
327
+ ```
167
328
"""
168
329
table_token = QueryIdentifier (cls .get_table_name ())
169
330
select_fields : list [str ] = []
0 commit comments