|
| 1 | +from contextlib import asynccontextmanager |
1 | 2 | from enum import StrEnum
|
| 3 | +from typing import Type |
| 4 | +from unittest.mock import AsyncMock |
2 | 5 |
|
3 | 6 | import asyncpg
|
4 | 7 | import pytest
|
|
12 | 15 | UserDemo,
|
13 | 16 | )
|
14 | 17 | from iceaxe.alias_values import alias
|
15 |
| -from iceaxe.base import TableBase |
| 18 | +from iceaxe.base import INTERNAL_TABLE_FIELDS, TableBase |
16 | 19 | from iceaxe.field import Field
|
17 | 20 | from iceaxe.functions import func
|
18 | 21 | from iceaxe.queries import QueryBuilder
|
19 | 22 | from iceaxe.queries_str import sql
|
20 | 23 | from iceaxe.schemas.cli import create_all
|
21 | 24 | from iceaxe.session import (
|
| 25 | + PG_MAX_PARAMETERS, |
22 | 26 | DBConnection,
|
23 | 27 | )
|
24 | 28 | from iceaxe.typing import column
|
@@ -1112,3 +1116,162 @@ async def test_json_upsert(db_connection: DBConnection):
|
1112 | 1116 | assert len(result) == 1
|
1113 | 1117 | assert result[0].settings == {"theme": "dark", "notifications": True}
|
1114 | 1118 | assert result[0].metadata == {"version": 2, "last_updated": "2024-01-01"}
|
| 1119 | + |
| 1120 | + |
@pytest.mark.asyncio
async def test_db_connection_update_batched(db_connection: DBConnection):
    """Test that updates are properly batched when dealing with many objects and different field combinations."""
    # Build three cohorts of users so each cohort can exercise a distinct
    # modified-field combination (name only, email only, both fields).
    name_only = [
        UserDemo(name=f"User{i}", email=f"user{i}@example.com") for i in range(10)
    ]
    email_only = [
        UserDemo(name=f"User{i}", email=f"user{i}@example.com") for i in range(10, 20)
    ]
    both_fields = [
        UserDemo(name=f"User{i}", email=f"user{i}@example.com") for i in range(20, 30)
    ]
    everyone = [*name_only, *email_only, *both_fields]
    await db_connection.insert(everyone)

    # Touch a different set of columns per cohort so update() has to group
    # objects by their modified-field signature.
    for user in name_only:
        user.name = f"Updated{user.name}"
    for user in email_only:
        user.email = f"updated_{user.email}"
    for user in both_fields:
        user.name = f"Updated{user.name}"
        user.email = f"updated_{user.email}"

    await db_connection.update(everyone)

    # Verify all updates were applied correctly
    rows = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY id")
    assert len(rows) == 30

    # Cohort 1: names changed, emails untouched.
    for offset, row in enumerate(rows[:10]):
        assert row["name"] == f"UpdatedUser{offset}"
        assert row["email"] == f"user{offset}@example.com"

    # Cohort 2: emails changed, names untouched.
    for offset, row in enumerate(rows[10:20]):
        assert row["name"] == f"User{offset + 10}"
        assert row["email"] == f"updated_user{offset + 10}@example.com"

    # Cohort 3: both columns changed.
    for offset, row in enumerate(rows[20:30]):
        assert row["name"] == f"UpdatedUser{offset + 20}"
        assert row["email"] == f"updated_user{offset + 20}@example.com"

    # A successful update() should reset every object's dirty-field tracking.
    assert all(user.get_modified_attributes() == {} for user in everyone)
| 1171 | + |
| 1172 | + |
| 1173 | +# |
| 1174 | +# Batch query construction |
| 1175 | +# |
| 1176 | + |
| 1177 | + |
def assert_expected_user_fields(user: Type[UserDemo]) -> bool:
    """
    Sanity-check that the given table class still has the field layout the
    batching math in these tests depends on.

    The parameter-count calculations in the batch tests assume exactly three
    user-facing columns (id/name/email) with a database-assigned primary key.
    If the model changes, this fails loudly so those calculations get updated.

    Returns True so it composes with a bare ``assert`` at the call site.
    """
    # Bug fix: inspect the class that was passed in rather than hard-coding
    # UserDemo — previously the `user` parameter was accepted but ignored.
    visible_fields = {
        key for key in user.model_fields.keys() if key not in INTERNAL_TABLE_FIELDS
    }
    assert visible_fields == {"id", "name", "email"}
    assert user.model_fields["id"].primary_key
    assert user.model_fields["id"].default is None
    return True
| 1186 | + |
| 1187 | + |
@asynccontextmanager
async def mock_transaction():
    """No-op stand-in for a connection's transaction() async context manager."""
    yield None
| 1191 | + |
| 1192 | + |
@pytest.mark.asyncio
async def test_batch_insert_exceeds_parameters():
    """
    Test that insert() correctly batches operations when we exceed Postgres parameter limits.
    We'll create enough objects with enough fields that a single query would exceed PG_MAX_PARAMETERS.
    """
    assert assert_expected_user_fields(UserDemo)

    # Stand in for a real asyncpg connection; fetchmany hands back fake ids.
    fake_conn = AsyncMock()
    fake_conn.fetchmany = AsyncMock(return_value=[{"id": i} for i in range(1000)])
    fake_conn.executemany = AsyncMock()
    fake_conn.transaction = mock_transaction

    session = DBConnection(fake_conn)

    # Each UserDemo contributes two INSERT parameters (name, email), so this
    # object count pushes the total just past the Postgres parameter limit.
    total_objects = PG_MAX_PARAMETERS // 2 + 1
    users = [
        UserDemo(name=f"User {i}", email=f"user{i}@example.com")
        for i in range(total_objects)
    ]

    await session.insert(users)

    # Crossing the limit means the insert must have been split into >= 2 queries.
    assert len(fake_conn.fetchmany.mock_calls) >= 2

    # The first batched query should still be a well-formed INSERT ... RETURNING.
    first_query = fake_conn.fetchmany.mock_calls[0].args[0]
    assert "INSERT INTO" in first_query
    assert '"name"' in first_query
    assert '"email"' in first_query
    assert "RETURNING" in first_query
| 1230 | + |
| 1231 | + |
@pytest.mark.asyncio
async def test_batch_update_exceeds_parameters():
    """
    Test that update() correctly batches operations when we exceed Postgres parameter limits.
    We'll create enough objects with enough modified fields that a single query would exceed PG_MAX_PARAMETERS.
    """
    assert assert_expected_user_fields(UserDemo)

    # Minimal mock connection: update() only needs executemany + transaction.
    fake_conn = AsyncMock()
    fake_conn.executemany = AsyncMock()
    fake_conn.transaction = mock_transaction

    session = DBConnection(fake_conn)

    # Each updated row consumes 3 parameters: the id in the WHERE clause plus
    # name and email in the SET clause. This object count tips past the limit.
    total_objects = PG_MAX_PARAMETERS // 3 + 1

    users: list[UserDemo] = []
    for i in range(total_objects):
        user = UserDemo(id=i, name=f"User {i}", email=f"user{i}@example.com")
        user.clear_modified_attributes()
        # Dirty both columns so every object shares one batch signature.
        user.name = f"New User {i}"
        user.email = f"newuser{i}@example.com"
        users.append(user)

    await session.update(users)

    # Crossing the parameter limit forces at least two executemany batches.
    assert len(fake_conn.executemany.mock_calls) >= 2

    # Spot-check that the first statement is a parameterized UPDATE keyed on id.
    first_statement = fake_conn.executemany.mock_calls[0].args[0]
    assert "UPDATE" in first_statement
    assert "SET" in first_statement
    assert "WHERE" in first_statement
    assert '"id"' in first_statement
0 commit comments