Skip to content

Commit f3c4e7a

Browse files
Merge pull request #49 from piercefreeman/feature/group-by-function-support
Support function use in group by
2 parents eb99385 + 4ee29b7 commit f3c4e7a

File tree

7 files changed

+181
-28
lines changed

7 files changed

+181
-28
lines changed

iceaxe/__tests__/conf_models.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from pyinstrument import Profiler
77

8-
from iceaxe.base import Field, TableBase
8+
from iceaxe.base import Field, TableBase, UniqueConstraint
99

1010

1111
class UserDemo(TableBase):
@@ -103,6 +103,19 @@ class DemoModelB(TableBase):
103103
code: str = Field(unique=True)
104104

105105

106+
class JsonDemo(TableBase):
107+
"""
108+
Model for testing JSON field updates.
109+
"""
110+
111+
id: int | None = Field(primary_key=True, default=None)
112+
settings: dict[Any, Any] = Field(is_json=True)
113+
metadata: dict[Any, Any] | None = Field(is_json=True)
114+
unique_val: str
115+
116+
table_args = [UniqueConstraint(columns=["unique_val"])]
117+
118+
106119
@contextmanager
107120
def run_profile(request):
108121
TESTS_ROOT = Path.cwd()

iceaxe/__tests__/test_queries.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,3 +586,20 @@ def test_multiple_group_by():
586586
'GROUP BY "employee"."department", "employee"."last_name"',
587587
[],
588588
)
589+
590+
591+
def test_group_by_with_function():
592+
new_query = (
593+
QueryBuilder()
594+
.select(
595+
(
596+
func.date_trunc("month", FunctionDemoModel.created_at),
597+
func.count(FunctionDemoModel.id),
598+
)
599+
)
600+
.group_by(func.date_trunc("month", FunctionDemoModel.created_at))
601+
)
602+
assert new_query.build() == (
603+
'SELECT date_trunc(\'month\', "functiondemomodel"."created_at") AS aggregate_0, count("functiondemomodel"."id") AS aggregate_1 FROM "functiondemomodel" GROUP BY date_trunc(\'month\', "functiondemomodel"."created_at")',
604+
[],
605+
)

iceaxe/__tests__/test_session.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ComplexDemo,
99
DemoModelA,
1010
DemoModelB,
11+
JsonDemo,
1112
UserDemo,
1213
)
1314
from iceaxe.alias_values import alias
@@ -1028,3 +1029,86 @@ async def test_select_with_order_by_func_count(db_connection: DBConnection):
10281029
assert result[1] == ("Jane", 1)
10291030
# Bob has 0 posts
10301031
assert result[2] == ("Bob", 0)
1032+
1033+
1034+
@pytest.mark.asyncio
1035+
async def test_json_update(db_connection: DBConnection):
1036+
"""
1037+
Test that JSON fields are correctly serialized during updates.
1038+
"""
1039+
# Create the table first
1040+
await db_connection.conn.execute("DROP TABLE IF EXISTS jsondemo")
1041+
await create_all(db_connection, [JsonDemo])
1042+
1043+
# Create initial object with JSON data
1044+
demo = JsonDemo(
1045+
settings={"theme": "dark", "notifications": True},
1046+
metadata={"version": 1},
1047+
unique_val="1",
1048+
)
1049+
await db_connection.insert([demo])
1050+
1051+
# Update JSON fields
1052+
demo.settings = {"theme": "light", "notifications": False}
1053+
demo.metadata = {"version": 2, "last_updated": "2024-01-01"}
1054+
await db_connection.update([demo])
1055+
1056+
# Verify the update through a fresh select
1057+
result = await db_connection.exec(
1058+
QueryBuilder().select(JsonDemo).where(JsonDemo.id == demo.id)
1059+
)
1060+
assert len(result) == 1
1061+
assert result[0].settings == {"theme": "light", "notifications": False}
1062+
assert result[0].metadata == {"version": 2, "last_updated": "2024-01-01"}
1063+
1064+
1065+
@pytest.mark.asyncio
1066+
async def test_json_upsert(db_connection: DBConnection):
1067+
"""
1068+
Test that JSON fields are correctly serialized during upsert operations.
1069+
"""
1070+
# Create the table first
1071+
await db_connection.conn.execute("DROP TABLE IF EXISTS jsondemo")
1072+
await create_all(db_connection, [JsonDemo])
1073+
1074+
# Initial insert via upsert
1075+
demo = JsonDemo(
1076+
settings={"theme": "dark", "notifications": True},
1077+
metadata={"version": 1},
1078+
unique_val="1",
1079+
)
1080+
result = await db_connection.upsert(
1081+
[demo],
1082+
conflict_fields=(JsonDemo.unique_val,),
1083+
update_fields=(JsonDemo.metadata,),
1084+
returning_fields=(JsonDemo.unique_val, JsonDemo.metadata),
1085+
)
1086+
1087+
assert result is not None
1088+
assert len(result) == 1
1089+
assert result[0][0] == "1"
1090+
assert result[0][1] == {"version": 1}
1091+
1092+
# Update via upsert
1093+
demo2 = JsonDemo(
1094+
settings={"theme": "dark", "notifications": True},
1095+
metadata={"version": 2, "last_updated": "2024-01-01"}, # New metadata
1096+
unique_val="1", # Same value to trigger update
1097+
)
1098+
result = await db_connection.upsert(
1099+
[demo2],
1100+
conflict_fields=(JsonDemo.unique_val,),
1101+
update_fields=(JsonDemo.metadata,),
1102+
returning_fields=(JsonDemo.unique_val, JsonDemo.metadata),
1103+
)
1104+
1105+
assert result is not None
1106+
assert len(result) == 1
1107+
assert result[0][0] == "1"
1108+
assert result[0][1] == {"version": 2, "last_updated": "2024-01-01"}
1109+
1110+
# Verify through a fresh select
1111+
result = await db_connection.exec(QueryBuilder().select(JsonDemo))
1112+
assert len(result) == 1
1113+
assert result[0].settings == {"theme": "dark", "notifications": True}
1114+
assert result[0].metadata == {"version": 2, "last_updated": "2024-01-01"}

iceaxe/mountaineer/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from pydantic_settings import BaseSettings
22

3+
from iceaxe.modifications import MODIFICATION_TRACKER_VERBOSITY
4+
35

46
class DatabaseConfig(BaseSettings):
57
"""
@@ -36,3 +38,9 @@ class DatabaseConfig(BaseSettings):
3638
The port number where PostgreSQL server is listening.
3739
Defaults to the standard PostgreSQL port 5432 if not specified.
3840
"""
41+
42+
ICEAXE_UNCOMMITTED_VERBOSITY: MODIFICATION_TRACKER_VERBOSITY | None = None
43+
"""
44+
The verbosity level for uncommitted modifications.
45+
If set to None, uncommitted modifications will not be tracked.
46+
"""

iceaxe/mountaineer/dependencies/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ async def get_users(db: DBConnection = Depends(get_db_connection)):
5353
password=config.POSTGRES_PASSWORD,
5454
database=config.POSTGRES_DB,
5555
)
56-
connection = DBConnection(conn)
56+
connection = DBConnection(
57+
conn, uncommitted_verbosity=config.ICEAXE_UNCOMMITTED_VERBOSITY
58+
)
5759
try:
5860
yield connection
5961
finally:

iceaxe/queries.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(self):
144144
self._join_clauses: list[str] = []
145145
self._limit_value: int | None = None
146146
self._offset_value: int | None = None
147-
self._group_by_fields: list[DBFieldClassDefinition] = []
147+
self._group_by_clauses: list[str] = []
148148
self._having_conditions: list[FieldComparison] = []
149149
self._distinct_on_fields: list[QueryElementBase] = []
150150
self._for_update_config: ForUpdateConfig = ForUpdateConfig()
@@ -771,9 +771,14 @@ def group_by(self, *fields: Any):
771771
"""
772772

773773
for field in fields:
774-
if not is_column(field):
775-
raise ValueError(f"Invalid field for group by: {field}")
776-
self._group_by_fields.append(field)
774+
if is_column(field):
775+
field_token, _ = field.to_query()
776+
elif is_function_metadata(field):
777+
field_token = field.literal
778+
else:
779+
raise ValueError(f"Invalid group by field: {field}")
780+
781+
self._group_by_clauses.append(str(field_token))
777782

778783
return self
779784

@@ -987,9 +992,9 @@ def build(self) -> tuple[str, list[Any]]:
987992
query += f" WHERE {comparison_literal}"
988993
variables += comparison_variables
989994

990-
if self._group_by_fields:
995+
if self._group_by_clauses:
991996
query += " GROUP BY "
992-
query += ", ".join(str(sql(field)) for field in self._group_by_fields)
997+
query += ", ".join(str(field) for field in self._group_by_clauses)
993998

994999
if self._having_conditions:
9951000
query += " HAVING "

iceaxe/session.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import defaultdict
22
from contextlib import asynccontextmanager
3+
from json import loads as json_loads
34
from typing import (
45
Any,
56
Literal,
@@ -14,7 +15,7 @@
1415
import asyncpg
1516
from typing_extensions import TypeVarTuple
1617

17-
from iceaxe.base import TableBase
18+
from iceaxe.base import DBFieldClassDefinition, TableBase
1819
from iceaxe.logging import LOGGER
1920
from iceaxe.modifications import ModificationTracker
2021
from iceaxe.queries import (
@@ -81,6 +82,7 @@ class User(TableBase):
8182
def __init__(
8283
self,
8384
conn: asyncpg.Connection,
85+
*,
8486
uncommitted_verbosity: Literal["ERROR", "WARNING", "INFO"] | None = None,
8587
):
8688
"""
@@ -273,8 +275,8 @@ async def upsert(
273275
*,
274276
conflict_fields: tuple[Any, ...],
275277
update_fields: tuple[Any, ...] | None = None,
276-
returning_fields: tuple[T, *Ts],
277-
) -> list[tuple[T, *Ts]]: ...
278+
returning_fields: tuple[T, *Ts] | None = None,
279+
) -> list[tuple[T, *Ts]] | None: ...
278280

279281
@overload
280282
async def upsert(
@@ -332,13 +334,26 @@ async def upsert(
332334
return None
333335

334336
# Evaluate column types
335-
conflict_fields_cols = [field for field in conflict_fields if is_column(field)]
336-
update_fields_cols = [
337-
field for field in update_fields or [] if is_column(field)
338-
]
339-
returning_fields_cols = [
340-
field for field in returning_fields or [] if is_column(field)
341-
]
337+
conflict_fields_cols: list[DBFieldClassDefinition] = []
338+
update_fields_cols: list[DBFieldClassDefinition] = []
339+
returning_fields_cols: list[DBFieldClassDefinition] = []
340+
341+
# Explicitly validate types of all columns
342+
for field in conflict_fields:
343+
if is_column(field):
344+
conflict_fields_cols.append(field)
345+
else:
346+
raise ValueError(f"Field {field} is not a column")
347+
for field in update_fields or []:
348+
if is_column(field):
349+
update_fields_cols.append(field)
350+
else:
351+
raise ValueError(f"Field {field} is not a column")
352+
for field in returning_fields or []:
353+
if is_column(field):
354+
returning_fields_cols.append(field)
355+
else:
356+
raise ValueError(f"Field {field} is not a column")
342357

343358
results: list[tuple[T, *Ts]] = []
344359
async with self._ensure_transaction():
@@ -387,14 +402,17 @@ async def upsert(
387402
if returning_fields_cols:
388403
result = await self.conn.fetchrow(query, *values)
389404
if result:
390-
results.append(
391-
tuple(
392-
[
393-
result[field.key]
394-
for field in returning_fields_cols
395-
]
396-
)
397-
)
405+
# Process returned values, deserializing JSON if needed
406+
processed_values = []
407+
for field in returning_fields_cols:
408+
value = result[field.key]
409+
if (
410+
value is not None
411+
and field.root_model.model_fields[field.key].is_json
412+
):
413+
value = json_loads(value)
414+
processed_values.append(value)
415+
results.append(tuple(processed_values))
398416
else:
399417
await self.conn.execute(query, *values)
400418

@@ -441,7 +459,7 @@ async def update(self, objects: Sequence[TableBase]):
441459

442460
for obj in model_objects:
443461
modified_attrs = {
444-
k: v
462+
k: obj.model_fields[k].to_db_value(v)
445463
for k, v in obj.get_modified_attributes().items()
446464
if not obj.model_fields[k].exclude
447465
}
@@ -455,7 +473,13 @@ async def update(self, objects: Sequence[TableBase]):
455473

456474
query = f"UPDATE {table_name} SET {set_clause} WHERE {primary_key_name} = $1"
457475
values = [getattr(obj, primary_key)] + list(modified_attrs.values())
458-
await self.conn.execute(query, *values)
476+
try:
477+
await self.conn.execute(query, *values)
478+
except Exception as e:
479+
LOGGER.error(
480+
f"Error executing query: {query} with variables: {values}"
481+
)
482+
raise e
459483
obj.clear_modified_attributes()
460484

461485
self.modification_tracker.clear_status(objects)

0 commit comments

Comments
 (0)