
Commit 4f0c1aa

Merge pull request #51 from piercefreeman/feature/optimize-insert

Optimize insertion speed

2 parents c6ba7f4 + 188c4c7

2 files changed: +135 −25 lines

Lines changed: 45 additions & 0 deletions

```diff
@@ -0,0 +1,45 @@
+import time
+from typing import Sequence
+
+import pytest
+
+from iceaxe.__tests__.conf_models import UserDemo
+from iceaxe.logging import CONSOLE, LOGGER
+from iceaxe.session import DBConnection
+
+
+def generate_test_users(count: int) -> Sequence[UserDemo]:
+    """
+    Generate a sequence of test users for bulk insertion.
+
+    :param count: Number of users to generate
+    :return: Sequence of UserDemo instances
+    """
+    return [
+        UserDemo(name=f"User {i}", email=f"user{i}@example.com") for i in range(count)
+    ]
+
+
+@pytest.mark.asyncio
+@pytest.mark.integration_tests
+async def test_bulk_insert_performance(db_connection: DBConnection):
+    """
+    Test the performance of bulk inserting 500k records.
+    """
+    NUM_USERS = 500_000
+    users = generate_test_users(NUM_USERS)
+    LOGGER.info(f"Generated {NUM_USERS} test users")
+
+    start_time = time.time()
+
+    await db_connection.insert(users)
+
+    total_time = time.time() - start_time
+    records_per_second = NUM_USERS / total_time
+
+    CONSOLE.print("\nBulk Insert Performance:")
+    CONSOLE.print(f"Total time: {total_time:.2f} seconds")
+    CONSOLE.print(f"Records per second: {records_per_second:.2f}")
+
+    result = await db_connection.conn.fetchval("SELECT COUNT(*) FROM userdemo")
+    assert result == NUM_USERS
```
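For context, the `UserDemo` fixture imported from `iceaxe.__tests__.conf_models` is roughly the following shape. This is a sketch inferred from how the test uses the model (a `name`/`email` table named `userdemo` with an autoincrementing primary key), not the authoritative definition, and it assumes `TableBase` and `Field` are exported from the package root:

```python
# Hypothetical sketch of the UserDemo test model; the real definition lives
# in iceaxe/__tests__/conf_models.py and may differ in detail.
from iceaxe import Field, TableBase


class UserDemo(TableBase):
    id: int | None = Field(primary_key=True, default=None)  # autoincrement; skipped on insert
    name: str
    email: str
```

With a Postgres-backed `db_connection` fixture available, `pytest -m integration_tests -k test_bulk_insert_performance` should run just this benchmark.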

iceaxe/session.py

Lines changed: 90 additions & 25 deletions

```diff
@@ -1,6 +1,7 @@
 from collections import defaultdict
 from contextlib import asynccontextmanager
 from json import loads as json_loads
+from math import ceil
 from typing import (
     Any,
     Literal,
@@ -33,6 +34,9 @@
 
 TableType = TypeVar("TableType", bound=TableBase)
 
+# PostgreSQL has a limit of 65535 parameters per query
+PG_MAX_PARAMETERS = 65535
+
 
 class DBConnection:
     """
@@ -235,37 +239,98 @@ async def insert(self, objects: Sequence[TableBase]):
         if not objects:
             return
 
-        for model, model_objects in self._aggregate_models_by_table(objects):
-            table_name = QueryIdentifier(model.get_table_name())
-            fields = {
-                field: info
-                for field, info in model.model_fields.items()
-                if (not info.exclude and not info.autoincrement)
-            }
-            field_string = ", ".join(f'"{field}"' for field in fields)
-            primary_key = self._get_primary_key(model)
-
-            placeholders = ", ".join(f"${i}" for i in range(1, len(fields) + 1))
-            query = f"INSERT INTO {table_name} ({field_string}) VALUES ({placeholders})"
-            if primary_key:
-                query += f" RETURNING {primary_key}"
+        # Reuse a single transaction for all inserts
+        async with self._ensure_transaction():
+            for model, model_objects in self._aggregate_models_by_table(objects):
+                # For each table, build batched insert queries
+                table_name = QueryIdentifier(model.get_table_name())
+                fields = {
+                    field: info
+                    for field, info in model.model_fields.items()
+                    if (not info.exclude and not info.autoincrement)
+                }
+                primary_key = self._get_primary_key(model)
+                field_names = list(
+                    fields.keys()
+                )  # Iterate over these in order for each row
+                field_identifiers = ", ".join(f'"{f}"' for f in field_names)
+
+                # Calculate max batch size based on number of fields
+                # Each row uses len(fields) parameters, so max_batch_size * len(fields) <= PG_MAX_PARAMETERS
+                max_batch_size = PG_MAX_PARAMETERS // len(fields)
+                # Cap at 5000 rows per batch to avoid excessive memory usage
+                max_batch_size = min(max_batch_size, 5000)
+
+                total = len(model_objects)
+                num_batches = ceil(total / max_batch_size)
+
+                for batch_idx in range(num_batches):
+                    start_idx = batch_idx * max_batch_size
+                    end_idx = (batch_idx + 1) * max_batch_size
+                    batch_objects = model_objects[start_idx:end_idx]
+
+                    # Build the multi-row VALUES clause
+                    # e.g. for 3 rows with 2 columns, we'd want:
+                    # VALUES ($1, $2), ($3, $4), ($5, $6)
+                    num_rows = len(batch_objects)
+                    if not num_rows:
+                        continue
 
-            async with self._ensure_transaction():
-                for obj in model_objects:
-                    obj_values = obj.model_dump()
-                    values = [
-                        info.to_db_value(obj_values[field])
-                        for field, info in fields.items()
-                    ]
-                    result = await self.conn.fetchrow(query, *values)
+                    # placeholders per row: ($1, $2, ...)
+                    # but we have to shift the placeholder index for each row
+                    placeholders: list[str] = []
+                    values: list[Any] = []
+                    param_index = 1
+
+                    for obj in batch_objects:
+                        obj_values = obj.model_dump()
+                        row_values = []
+                        for field in field_names:
+                            info = fields[field]
+                            row_values.append(info.to_db_value(obj_values[field]))
+                        values.extend(row_values)
+                        row_placeholder = (
+                            "("
+                            + ", ".join(
+                                f"${p}"
+                                for p in range(
+                                    param_index, param_index + len(field_names)
+                                )
+                            )
+                            + ")"
+                        )
+                        placeholders.append(row_placeholder)
+                        param_index += len(field_names)
+
+                    placeholders_clause = ", ".join(placeholders)
+
+                    query = f"""
+                        INSERT INTO {table_name} ({field_identifiers})
+                        VALUES {placeholders_clause}
+                    """
+                    if primary_key:
+                        query += f" RETURNING {primary_key}"
+
+                    # Insert them in one go
+                    if primary_key:
+                        rows = await self.conn.fetch(query, *values)
+                        # 'rows' should be a list of Record objects, one per inserted row
+                        # Update each object in the same order
+                        for obj, row in zip(batch_objects, rows):
+                            setattr(obj, primary_key, row[primary_key])
+                    else:
+                        # No need to fetch anything if there's no primary key
+                        await self.conn.execute(query, *values)
 
-                    if primary_key and result:
-                        setattr(obj, primary_key, result[primary_key])
-                    obj.clear_modified_attributes()
+                    # Mark as unmodified
+                    for obj in batch_objects:
+                        obj.clear_modified_attributes()
 
+        # Register modification callbacks outside the main insert loop
         for obj in objects:
             obj.register_modified_callback(self.modification_tracker.track_modification)
 
+        # Clear modification status
         self.modification_tracker.clear_status(objects)
 
     @overload
```
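As a worked example of the sizing logic above, applied to the 500k-row test: `UserDemo` contributes two insertable columns (`name` and `email`; the autoincrementing `id` is excluded by the `fields` filter). A minimal sketch of the arithmetic, reusing the constants from `insert()`:

```python
from math import ceil

PG_MAX_PARAMETERS = 65535  # PostgreSQL's per-query bind-parameter ceiling
NUM_USERS = 500_000
num_fields = 2  # name, email

# Same sizing rule as insert(): stay under the parameter cap, then cap at 5000 rows
max_batch_size = min(PG_MAX_PARAMETERS // num_fields, 5000)  # 32767 -> capped to 5000
num_batches = ceil(NUM_USERS / max_batch_size)  # 100 batches

# Each batch binds 5000 rows * 2 columns = 10,000 parameters in one INSERT,
# versus 500,000 single-row statements before this change.
print(max_batch_size, num_batches)  # 5000 100
```

The speedup comes from amortizing per-statement and network round-trip overhead across thousands of rows at a time, all inside a single transaction.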
