Skip to content

Commit ce34dda

Browse files
Merge pull request #35 from piercefreeman/feature/expand-docstrings
Expand class docstrings with typehints
2 parents 0433cb6 + b4a15a1 commit ce34dda

File tree

10 files changed

+1080
-41
lines changed

10 files changed

+1080
-41
lines changed

README.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,3 @@ a simple poetry install. Poetry is set up to create a dynamic `setup.py` based o
238238
```bash
239239
poetry install
240240
```
241-
242-
## TODOs
243-
244-
- [ ] Additional documentation with usage examples.

iceaxe/__tests__/test_session.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,28 @@ async def test_refresh(db_connection: DBConnection):
529529
assert user.name == "Jane Doe"
530530

531531

532+
@pytest.mark.asyncio
533+
async def test_get(db_connection: DBConnection):
534+
"""
535+
Test retrieving a single record by primary key using the get method.
536+
"""
537+
# Create a test user
538+
user = UserDemo(name="John Doe", email="[email protected]")
539+
await db_connection.insert([user])
540+
assert user.id is not None
541+
542+
# Test successful get
543+
retrieved_user = await db_connection.get(UserDemo, user.id)
544+
assert retrieved_user is not None
545+
assert retrieved_user.id == user.id
546+
assert retrieved_user.name == "John Doe"
547+
assert retrieved_user.email == "[email protected]"
548+
549+
# Test get with non-existent ID
550+
non_existent = await db_connection.get(UserDemo, 9999)
551+
assert non_existent is None
552+
553+
532554
@pytest.mark.asyncio
533555
async def test_db_connection_insert_update_enum(db_connection: DBConnection):
534556
"""

iceaxe/base.py

Lines changed: 186 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,40 @@
1616

1717
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticField,))
1818
class DBModelMetaclass(_model_construction.ModelMetaclass):
19+
"""
20+
Metaclass for database model classes that provides automatic field tracking and SQL query generation.
21+
Extends Pydantic's model metaclass to add database-specific functionality.
22+
23+
This metaclass provides:
24+
- Automatic field to SQL column mapping
25+
- Dynamic field access that returns query-compatible field definitions
26+
- Registry tracking of all database model classes
27+
- Support for generic model instantiation
28+
29+
```python {{sticky: True}}
30+
class User(TableBase): # Uses DBModelMetaclass
31+
id: int = Field(primary_key=True)
32+
name: str
33+
email: str | None
34+
35+
# Fields can be accessed for queries
36+
User.id # Returns DBFieldClassDefinition
37+
User.name # Returns DBFieldClassDefinition
38+
39+
# Metaclass handles model registration
40+
registered_models = DBModelMetaclass.get_registry()
41+
```
42+
"""
43+
1944
_registry: list[Type["TableBase"]] = []
20-
# {class: kwargs}
2145
_cached_args: dict[Type["TableBase"], dict[str, Any]] = {}
46+
is_constructing: bool = False
2247

23-
def __new__(
24-
mcs, name: str, bases: tuple, namespace: dict[str, Any], **kwargs: Any
25-
) -> type:
48+
def __new__(mcs, name, bases, namespace, **kwargs):
49+
"""
50+
Create a new database model class with proper field tracking.
51+
Handles registration of the model and processes any table-specific arguments.
52+
"""
2653
raw_kwargs = {**kwargs}
2754

2855
mcs.is_constructing = True
@@ -59,8 +86,15 @@ def __new__(
5986
return cls
6087

6188
def __getattr__(self, key: str) -> Any:
62-
# Inspired by the approach in our render logic
63-
# https://github.com/piercefreeman/mountaineer/blob/fdda3a58c0fafebb43a58b4f3d410dbf44302fd6/mountaineer/render.py#L252
89+
"""
90+
Provides dynamic access to model fields as query-compatible definitions.
91+
When accessing an undefined attribute, checks if it's a model field and returns
92+
a DBFieldClassDefinition if it is.
93+
94+
:param key: The attribute name to access
95+
:return: Field definition or raises AttributeError
96+
:raises AttributeError: If the attribute doesn't exist and isn't a model field
97+
"""
6498
if self.is_constructing:
6599
return super().__getattr__(key) # type: ignore
66100

@@ -78,16 +112,26 @@ def __getattr__(self, key: str) -> Any:
78112
raise
79113

80114
@classmethod
81-
def get_registry(cls):
115+
def get_registry(cls) -> list[Type["TableBase"]]:
116+
"""
117+
Get the set of all registered database model classes.
118+
119+
:return: Set of registered TableBase classes
120+
"""
82121
return cls._registry
83122

84123
@classmethod
85-
def _extract_kwarg(cls, kwargs: dict[str, Any], key: str, default: Any = None):
124+
def _extract_kwarg(
125+
cls, kwargs: dict[str, Any], key: str, default: Any = None
126+
) -> Any:
86127
"""
87-
Kwarg extraction that supports standard instantiation and pydantic's approach
88-
for Generic models where it hydrates a fully new class in memory with the type
89-
annotations set to generic values.
128+
Extract a keyword argument from either standard kwargs or pydantic generic metadata.
129+
Handles both normal instantiation and pydantic's generic model instantiation.
90130
131+
:param kwargs: Dictionary of keyword arguments
132+
:param key: Key to extract
133+
:param default: Default value if key not found
134+
:return: Extracted value or default
91135
"""
92136
if key in kwargs:
93137
return kwargs.pop(key)
@@ -101,69 +145,186 @@ def _extract_kwarg(cls, kwargs: dict[str, Any], key: str, default: Any = None):
101145

102146
@property
103147
def model_fields(self) -> dict[str, DBFieldInfo]: # type: ignore
104-
# model_fields must be reimplemented in our custom metaclass, otherwise
105-
# clients will get the super typehinting signature when they try
106-
# to access Model.model_fields. This overrides the ClassVar typehint
107-
# that's placed in the TableBase itself.
148+
"""
149+
Get the dictionary of model fields and their definitions.
150+
Overrides the ClassVar typehint from TableBase for proper typing.
151+
152+
:return: Dictionary of field names to field definitions
153+
"""
108154
return super().model_fields # type: ignore
109155

110156

111157
class UniqueConstraint(BaseModel):
158+
"""
159+
Represents a UNIQUE constraint in a database table.
160+
Ensures that the specified combination of columns contains unique values across all rows.
161+
162+
```python {{sticky: True}}
163+
class User(TableBase):
164+
email: str
165+
tenant_id: int
166+
167+
table_args = [
168+
UniqueConstraint(columns=["email", "tenant_id"])
169+
]
170+
```
171+
"""
172+
112173
columns: list[str]
174+
"""
175+
List of column names that should have unique values
176+
"""
113177

114178

115179
class IndexConstraint(BaseModel):
180+
"""
181+
Represents an INDEX on one or more columns in a database table.
182+
Improves query performance for the specified columns.
183+
184+
```python {{sticky: True}}
185+
class User(TableBase):
186+
email: str
187+
last_login: datetime
188+
189+
table_args = [
190+
IndexConstraint(columns=["last_login"])
191+
]
192+
```
193+
"""
194+
116195
columns: list[str]
196+
"""
197+
List of column names to create an index on
198+
"""
117199

118200

119201
INTERNAL_TABLE_FIELDS = ["modified_attrs"]
120202

121203

122204
class TableBase(BaseModel, metaclass=DBModelMetaclass):
205+
"""
206+
Base class for all database table models.
207+
Provides the foundation for defining database tables using Python classes with
208+
type hints and field definitions.
209+
210+
Features:
211+
- Automatic table name generation from class name
212+
- Support for custom table names
213+
- Tracking of modified fields for efficient updates
214+
- Support for unique constraints and indexes
215+
- Integration with Pydantic for validation
216+
217+
```python {{sticky: True}}
218+
class User(TableBase):
219+
# Custom table name (optional)
220+
table_name = "users"
221+
222+
# Fields with types and constraints
223+
id: int = Field(primary_key=True)
224+
email: str = Field(unique=True)
225+
name: str
226+
is_active: bool = Field(default=True)
227+
228+
# Table-level constraints
229+
table_args = [
230+
UniqueConstraint(columns=["email"]),
231+
IndexConstraint(columns=["name"])
232+
]
233+
234+
# Usage in queries
235+
query = select(User).where(User.is_active == True)
236+
users = await conn.execute(query)
237+
```
238+
"""
239+
123240
if TYPE_CHECKING:
124241
model_fields: ClassVar[dict[str, DBFieldInfo]] # type: ignore
125242

126243
table_name: ClassVar[str] = PydanticUndefined # type: ignore
244+
"""
245+
Optional custom name for the table
246+
"""
247+
127248
table_args: ClassVar[list[UniqueConstraint | IndexConstraint]] = PydanticUndefined # type: ignore
249+
"""
250+
Table constraints and indexes
251+
"""
128252

129253
# Private methods
130254
modified_attrs: dict[str, Any] = Field(default_factory=dict, exclude=True)
255+
"""
256+
Dictionary of modified field values since instantiation or the last clear_modified_attributes() call.
257+
Used to construct differential update queries.
258+
"""
259+
260+
def __setattr__(self, name: str, value: Any) -> None:
261+
"""
262+
Track modified attributes when fields are updated.
263+
This allows for efficient database updates by only updating changed fields.
131264
132-
def __setattr__(self, name, value):
265+
:param name: Attribute name
266+
:param value: New value
267+
"""
133268
if name in self.model_fields:
134269
self.modified_attrs[name] = value
135270
super().__setattr__(name, value)
136271

137272
def get_modified_attributes(self) -> dict[str, Any]:
273+
"""
274+
Get the dictionary of attributes that have been modified since instantiation
275+
or the last clear_modified_attributes() call.
276+
277+
:return: Dictionary of modified attribute names and their values
278+
"""
138279
return self.modified_attrs
139280

140-
def clear_modified_attributes(self):
281+
def clear_modified_attributes(self) -> None:
282+
"""
283+
Clear the tracking of modified attributes.
284+
Typically called after successfully saving changes to the database.
285+
"""
141286
self.modified_attrs.clear()
142287

143288
@classmethod
144-
def get_table_name(cls):
289+
def get_table_name(cls) -> str:
290+
"""
291+
Get the table name for this model.
292+
Uses the custom table_name if set, otherwise converts the class name to lowercase.
293+
294+
:return: Table name to use in SQL queries
295+
"""
145296
if cls.table_name == PydanticUndefined:
146297
return cls.__name__.lower()
147298
return cls.table_name
148299

149300
@classmethod
150-
def get_client_fields(cls):
301+
def get_client_fields(cls) -> dict[str, DBFieldInfo]:
302+
"""
303+
Get all fields that should be exposed to clients.
304+
Excludes internal fields used for model functionality.
305+
306+
:return: Dictionary of field names to field definitions
307+
"""
151308
return {
152309
field: info
153310
for field, info in cls.model_fields.items()
154311
if field not in INTERNAL_TABLE_FIELDS
155312
}
156313

157314
@classmethod
158-
def select_fields(cls):
315+
def select_fields(cls) -> QueryLiteral:
159316
"""
160-
Returns a query selectable string that can be used to select all fields
161-
from this model. This is the format that needs to be passed to our parser
162-
to serialize the raw postgres field values as TableBase objects.
317+
Generate a SQL-safe string for selecting all fields from this table.
318+
The output format is "{table_name}.{field_name} as {table_name}_{field_name}"
319+
for each field, which ensures proper field name resolution in complex queries.
163320
164-
The exact format is formatted as:
165-
"{table_name}.{field_name} as {table_name}_{field_name}".
321+
:return: SQL-safe field selection string
166322
323+
```python {{sticky: True}}
324+
# For a User class with fields 'id' and 'name':
325+
User.select_fields()
326+
# Returns: '"users"."id" as "users_id", "users"."name" as "users_name"'
327+
```
167328
"""
168329
table_token = QueryIdentifier(cls.get_table_name())
169330
select_fields: list[str] = []

0 commit comments

Comments
 (0)