Skip to content

Commit 7ee95cd

Browse files
Merge pull request #52 from piercefreeman/feature/batch-update
Add batch updates
2 parents 4f0c1aa + c801d72 commit 7ee95cd

File tree

4 files changed

+330
-142
lines changed

4 files changed

+330
-142
lines changed

iceaxe/__tests__/test_session.py

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from contextlib import asynccontextmanager
12
from enum import StrEnum
3+
from typing import Type
4+
from unittest.mock import AsyncMock
25

36
import asyncpg
47
import pytest
@@ -12,13 +15,14 @@
1215
UserDemo,
1316
)
1417
from iceaxe.alias_values import alias
15-
from iceaxe.base import TableBase
18+
from iceaxe.base import INTERNAL_TABLE_FIELDS, TableBase
1619
from iceaxe.field import Field
1720
from iceaxe.functions import func
1821
from iceaxe.queries import QueryBuilder
1922
from iceaxe.queries_str import sql
2023
from iceaxe.schemas.cli import create_all
2124
from iceaxe.session import (
25+
PG_MAX_PARAMETERS,
2226
DBConnection,
2327
)
2428
from iceaxe.typing import column
@@ -1112,3 +1116,162 @@ async def test_json_upsert(db_connection: DBConnection):
11121116
assert len(result) == 1
11131117
assert result[0].settings == {"theme": "dark", "notifications": True}
11141118
assert result[0].metadata == {"version": 2, "last_updated": "2024-01-01"}
1119+
1120+
1121+
@pytest.mark.asyncio
async def test_db_connection_update_batched(db_connection: DBConnection):
    """
    Push 30 users through a single update() call where three groups of ten carry
    different sets of modified fields (name only, email only, both), then check
    the stored rows and that modification tracking was reset on every object.
    """
    # Build the three groups up front; all of them are inserted in one call
    name_only_users = [
        UserDemo(name=f"User{i}", email=f"user{i}@example.com") for i in range(10)
    ]
    email_only_users = [
        UserDemo(name=f"User{i}", email=f"user{i}@example.com") for i in range(10, 20)
    ]
    both_fields_users = [
        UserDemo(name=f"User{i}", email=f"user{i}@example.com") for i in range(20, 30)
    ]
    all_users = [*name_only_users, *email_only_users, *both_fields_users]
    await db_connection.insert(all_users)

    # Each group ends up with a different combination of modified attributes
    for user in name_only_users:
        user.name = f"Updated{user.name}"

    for user in email_only_users:
        user.email = f"updated_{user.email}"

    for user in both_fields_users:
        user.name = f"Updated{user.name}"
        user.email = f"updated_{user.email}"

    await db_connection.update(all_users)

    # Read the table back directly through the underlying connection
    result = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY id")
    assert len(result) == 30

    name_only_rows = result[:10]
    email_only_rows = result[10:20]
    both_fields_rows = result[20:30]

    # Group 1: name changed, email untouched
    for i, record in enumerate(name_only_rows):
        assert record["name"] == f"UpdatedUser{i}"
        assert record["email"] == f"user{i}@example.com"

    # Group 2: email changed, name untouched
    for i, record in enumerate(email_only_rows):
        assert record["name"] == f"User{i+10}"
        assert record["email"] == f"updated_user{i+10}@example.com"

    # Group 3: both columns changed
    for i, record in enumerate(both_fields_rows):
        assert record["name"] == f"UpdatedUser{i+20}"
        assert record["email"] == f"updated_user{i+20}@example.com"

    # After a successful update no object should still report pending changes
    for user in all_users:
        assert user.get_modified_attributes() == {}
1171+
1172+
1173+
#
1174+
# Batch query construction
1175+
#
1176+
1177+
1178+
def assert_expected_user_fields(user: Type[UserDemo]):
    """
    Guard for the parameter-count arithmetic used by the batch tests.

    Asserts that the given model class exposes exactly the fields ``id``,
    ``name`` and ``email`` once INTERNAL_TABLE_FIELDS are excluded, that ``id``
    is flagged as a primary key, and that its default is None. If any of these
    fail, the per-object parameter calculations in the batch tests need updating.

    :param user: model class to inspect (the batch tests pass ``UserDemo``)
    :return: True when every check passes, so the call can be wrapped in an
        outer ``assert``
    :raises AssertionError: when the field layout differs from the expected one
    """
    # Inspect the class that was actually passed in instead of silently
    # ignoring the argument and always checking the module-level UserDemo.
    fields = user.model_fields

    public_field_names = {
        key for key in fields.keys() if key not in INTERNAL_TABLE_FIELDS
    }
    assert public_field_names == {"id", "name", "email"}

    id_field = fields["id"]
    assert id_field.primary_key
    assert id_field.default is None
    return True
1186+
1187+
1188+
@asynccontextmanager
async def mock_transaction():
    """No-op async context manager: entering yields None and exiting does nothing."""
    yield None
1191+
1192+
1193+
@pytest.mark.asyncio
async def test_batch_insert_exceeds_parameters():
    """
    insert() is expected to split its work over several queries when the
    objects handed to it would need more than PG_MAX_PARAMETERS parameters
    in a single statement.
    """
    assert assert_expected_user_fields(UserDemo)

    # Stand-in connection: every fetchmany call answers with 1000 fake id rows
    connection = AsyncMock()
    connection.transaction = mock_transaction
    connection.executemany = AsyncMock()
    connection.fetchmany = AsyncMock(return_value=[{"id": i} for i in range(1000)])

    db = DBConnection(connection)

    # Each UserDemo contributes 2 parameters (name, email), so one object more
    # than PG_MAX_PARAMETERS // 2 cannot fit into a single query.
    params_per_object = 2
    objects_needed = (PG_MAX_PARAMETERS // params_per_object) + 1
    users = [
        UserDemo(name=f"User {i}", email=f"user{i}@example.com")
        for i in range(objects_needed)
    ]

    await db.insert(users)

    # Exceeding the limit must have produced at least two fetchmany calls
    fetch_calls = connection.fetchmany.mock_calls
    assert len(fetch_calls) >= 2

    # The first positional argument of the first call is the query text
    first_query = fetch_calls[0].args[0]
    for fragment in ("INSERT INTO", '"name"', '"email"', "RETURNING"):
        assert fragment in first_query
1230+
1231+
1232+
@pytest.mark.asyncio
async def test_batch_update_exceeds_parameters():
    """
    update() is expected to split its work over several executemany calls when
    the modified objects would need more than PG_MAX_PARAMETERS parameters in
    a single statement.
    """
    assert assert_expected_user_fields(UserDemo)

    # Stand-in connection so no database is required
    connection = AsyncMock()
    connection.transaction = mock_transaction
    connection.executemany = AsyncMock()

    db = DBConnection(connection)

    def build_modified_user(i: int) -> UserDemo:
        # Start from a clean object, then touch both fields so that name and
        # email are the attributes recorded as modified.
        user = UserDemo(id=i, name=f"User {i}", email=f"user{i}@example.com")
        user.clear_modified_attributes()
        user.name = f"New User {i}"
        user.email = f"newuser{i}@example.com"
        return user

    # Per-row budget for an UPDATE: 1 parameter for the WHERE clause (id)
    # plus 2 for the SET clause (name, email).
    params_per_object = 1 + 2
    objects_needed = (PG_MAX_PARAMETERS // params_per_object) + 1
    users: list[UserDemo] = [build_modified_user(i) for i in range(objects_needed)]

    await db.update(users)

    # Exceeding the limit must have produced at least two executemany calls
    execute_calls = connection.executemany.mock_calls
    assert len(execute_calls) >= 2

    # The first positional argument of the first call is the query text
    first_query = execute_calls[0].args[0]
    for fragment in ("UPDATE", "SET", "WHERE", '"id"'):
        assert fragment in first_query

0 commit comments

Comments
 (0)