Skip to content

Commit e5a2cc8

Browse files
authored
refactor: removed session._ensure_transaction (#68)
* refactor: removed session._ensure_transaction * chore: address pr comments
1 parent 98450f3 commit e5a2cc8

File tree

5 files changed

+73
-42
lines changed

5 files changed

+73
-42
lines changed

iceaxe/__tests__/benchmarks/test_select.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ async def test_benchmark(
107107
CONSOLE.print(f"Performance difference: {performance_diff:.2f}%")
108108

109109
# Assert that DBConnection.exec is at most X% slower than raw query
110-
assert (
111-
performance_diff <= allowed_overhead
112-
), f"DBConnection.exec is {performance_diff:.2f}% slower than raw query, which exceeds the {allowed_overhead}% threshold"
110+
assert performance_diff <= allowed_overhead, (
111+
f"DBConnection.exec is {performance_diff:.2f}% slower than raw query, which exceeds the {allowed_overhead}% threshold"
112+
)
113113

114114
LOGGER.info("Benchmark completed successfully.")

iceaxe/__tests__/schemas/test_db_memory_serializer.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class ModelA(TableBase):
8585
},
8686
),
8787
DryRunComment(
88-
text="\n" "NEW TABLE: modela\n",
88+
text="\nNEW TABLE: modela\n",
8989
previous_line=False,
9090
),
9191
DryRunAction(
@@ -294,7 +294,7 @@ class Model2(TableBase):
294294
},
295295
),
296296
DryRunComment(
297-
text="\n" "NEW TABLE: model1\n",
297+
text="\nNEW TABLE: model1\n",
298298
previous_line=False,
299299
),
300300
DryRunAction(
@@ -350,7 +350,7 @@ class Model2(TableBase):
350350
},
351351
),
352352
DryRunComment(
353-
text="\n" "NEW TABLE: model2\n",
353+
text="\nNEW TABLE: model2\n",
354354
previous_line=False,
355355
),
356356
DryRunAction(
@@ -998,7 +998,7 @@ class ModelA(TableBase, GenericSuperclass[OldValues]):
998998
},
999999
),
10001000
DryRunComment(
1001-
text="\n" "NEW TABLE: modela\n",
1001+
text="\nNEW TABLE: modela\n",
10021002
previous_line=False,
10031003
),
10041004
DryRunAction(
@@ -1377,18 +1377,18 @@ class SourceModel(TableBase):
13771377
)
13781378

13791379
# The foreign key constraint should come after both tables and the target column are created
1380-
assert (
1381-
target_table_pos < fk_constraint_pos
1382-
), "Foreign key constraint should be created after target table"
1383-
assert (
1384-
source_table_pos < fk_constraint_pos
1385-
), "Foreign key constraint should be created after source table"
1386-
assert (
1387-
target_column_pos < fk_constraint_pos
1388-
), "Foreign key constraint should be created after target column"
1389-
assert (
1390-
target_pk_pos < fk_constraint_pos
1391-
), "Foreign key constraint should be created after target primary key"
1380+
assert target_table_pos < fk_constraint_pos, (
1381+
"Foreign key constraint should be created after target table"
1382+
)
1383+
assert source_table_pos < fk_constraint_pos, (
1384+
"Foreign key constraint should be created after source table"
1385+
)
1386+
assert target_column_pos < fk_constraint_pos, (
1387+
"Foreign key constraint should be created after target column"
1388+
)
1389+
assert target_pk_pos < fk_constraint_pos, (
1390+
"Foreign key constraint should be created after target primary key"
1391+
)
13921392

13931393
# Verify the actual migration actions
13941394
actor = DatabaseActions()

iceaxe/__tests__/test_session.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,13 +1128,13 @@ async def test_db_connection_update_batched(db_connection: DBConnection):
11281128

11291129
# Check group 2 (only emails updated)
11301130
for i, row in enumerate(result[10:20]):
1131-
assert row["name"] == f"User{i+10}"
1132-
assert row["email"] == f"updated_user{i+10}@example.com"
1131+
assert row["name"] == f"User{i + 10}"
1132+
assert row["email"] == f"updated_user{i + 10}@example.com"
11331133

11341134
# Check group 3 (both fields updated)
11351135
for i, row in enumerate(result[20:30]):
1136-
assert row["name"] == f"UpdatedUser{i+20}"
1137-
assert row["email"] == f"updated_user{i+20}@example.com"
1136+
assert row["name"] == f"UpdatedUser{i + 20}"
1137+
assert row["email"] == f"updated_user{i + 20}@example.com"
11381138

11391139
# Verify all modifications were cleared
11401140
assert all(user.get_modified_attributes() == {} for user in all_users)
@@ -1477,3 +1477,33 @@ async def test_get_dsn(db_connection: DBConnection):
14771477
assert "localhost" in dsn
14781478
assert "5438" in dsn
14791479
assert "iceaxe_test_db" in dsn
1480+
1481+
1482+
@pytest.mark.asyncio
1483+
async def test_nested_transactions(db_connection):
1484+
"""
1485+
Test that nested transactions raise an error by default, but work with ensure=True.
1486+
"""
1487+
# Start an outer transaction
1488+
async with db_connection.transaction():
1489+
# This should work fine
1490+
assert db_connection.in_transaction is True
1491+
1492+
# Nested transaction with ensure=True should work
1493+
async with db_connection.transaction(ensure=True):
1494+
assert db_connection.in_transaction is True
1495+
1496+
# Nested transaction without ensure should fail
1497+
with pytest.raises(
1498+
RuntimeError,
1499+
match="Cannot start a new transaction while already in a transaction",
1500+
):
1501+
async with db_connection.transaction():
1502+
pass # Should not reach here
1503+
1504+
# After outer transaction ends, we should be out of transaction
1505+
assert db_connection.in_transaction is False
1506+
1507+
# Now a new transaction should start without error
1508+
async with db_connection.transaction():
1509+
assert db_connection.in_transaction is True

iceaxe/logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def log_time_duration(message: str):
7979
"""
8080
start = monotonic_ns()
8181
yield
82-
LOGGER.debug(f"{message} : Took {(monotonic_ns() - start)/1e9:.2f}s")
82+
LOGGER.debug(f"{message} : Took {(monotonic_ns() - start) / 1e9:.2f}s")
8383

8484

8585
# Our global logger should only surface warnings and above by default

iceaxe/session.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,14 @@ def get_dsn(self) -> str:
188188
return "".join(dsn_parts)
189189

190190
@asynccontextmanager
191-
async def transaction(self):
191+
async def transaction(self, *, ensure: bool = False):
192192
"""
193193
Context manager for managing database transactions. Ensures that a series of database
194194
operations are executed atomically.
195195
196+
:param ensure: If True and already in a transaction, the context manager will yield without creating a new transaction.
197+
If False (default) and already in a transaction, raises a RuntimeError.
198+
196199
```python {{sticky: True}}
197200
async with conn.transaction():
198201
# All operations here are executed in a transaction
@@ -205,6 +208,17 @@ async def transaction(self):
205208
# If any operation fails, all changes are rolled back
206209
```
207210
"""
211+
# If ensure is True and we're already in a transaction, just yield
212+
if self.in_transaction:
213+
if ensure:
214+
yield
215+
return
216+
else:
217+
raise RuntimeError(
218+
"Cannot start a new transaction while already in a transaction. Use ensure=True if this is intentional."
219+
)
220+
221+
# Otherwise, start a new transaction
208222
self.in_transaction = True
209223
async with self.conn.transaction():
210224
try:
@@ -325,7 +339,7 @@ async def insert(self, objects: Sequence[TableBase]):
325339
return
326340

327341
# Reuse a single transaction for all inserts
328-
async with self._ensure_transaction():
342+
async with self.transaction(ensure=True):
329343
for model, model_objects in self._aggregate_models_by_table(objects):
330344
# For each table, build batched insert queries
331345
table_name = QueryIdentifier(model.get_table_name())
@@ -466,7 +480,7 @@ async def upsert(
466480
raise ValueError(f"Field {field} is not a column")
467481

468482
results: list[tuple[T, *Ts]] = []
469-
async with self._ensure_transaction():
483+
async with self.transaction(ensure=True):
470484
for model, model_objects in self._aggregate_models_by_table(objects):
471485
table_name = QueryIdentifier(model.get_table_name())
472486
fields = {
@@ -559,7 +573,7 @@ async def update(self, objects: Sequence[TableBase]):
559573
if not objects:
560574
return
561575

562-
async with self._ensure_transaction():
576+
async with self.transaction(ensure=True):
563577
for model, model_objects in self._aggregate_models_by_table(objects):
564578
table_name = QueryIdentifier(model.get_table_name())
565579
primary_key = self._get_primary_key(model)
@@ -638,7 +652,7 @@ async def delete(self, objects: Sequence[TableBase]):
638652
:param objects: A sequence of TableBase instances to delete
639653
640654
"""
641-
async with self._ensure_transaction():
655+
async with self.transaction(ensure=True):
642656
for model, model_objects in self._aggregate_models_by_table(objects):
643657
table_name = QueryIdentifier(model.get_table_name())
644658
primary_key = self._get_primary_key(model)
@@ -797,19 +811,6 @@ def _get_primary_key(self, obj: Type[TableBase]) -> str | None:
797811
)
798812
return self.obj_to_primary_key[table_name]
799813

800-
@asynccontextmanager
801-
async def _ensure_transaction(self):
802-
"""
803-
Context manager that ensures operations are executed within a transaction.
804-
If no transaction is active, creates a new one for the duration of the context.
805-
If a transaction is already active, uses the existing transaction.
806-
"""
807-
if not self.in_transaction:
808-
async with self.transaction():
809-
yield
810-
else:
811-
yield
812-
813814
def _batch_objects_and_values(
814815
self,
815816
objects: Sequence[TableBase],

0 commit comments

Comments
 (0)