Skip to content

Commit 045e924

Browse files
Merge pull request #53 from piercefreeman/bug/fix-foreign-key-dependencies
Add destination table/col to migration dependencies
2 parents 7ee95cd + fb8c1d5 commit 045e924

File tree

5 files changed: +217 additions, -19 deletions

iceaxe/__tests__/schemas/test_db_memory_serializer.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,3 +1314,115 @@ class AutoAssignedModel(TableBase):
13141314
)
13151315
assert not auto_id_column.nullable
13161316
assert auto_id_column.autoincrement
1317+
1318+
1319+
@pytest.mark.asyncio
async def test_foreign_key_table_dependency():
    """
    Verify that a foreign key constraint is ordered after the creation of the
    referencing table, the referenced table, the referenced column, and the
    referenced primary key — both in the resolved ordering and in the
    generated migration actions.
    """

    class TargetModel(TableBase):
        id: int = Field(primary_key=True)
        value: str

    class SourceModel(TableBase):
        id: int = Field(primary_key=True)
        target_id: int = Field(foreign_key="targetmodel.id")

    migrator = DatabaseMemorySerializer()

    # Feed Source before Target so a correct result can only come from the
    # dependency resolution re-ordering the objects, not from the input order.
    db_objects = list(migrator.delegate([SourceModel, TargetModel]))
    ordering = migrator.order_db_objects(db_objects)

    # All objects, arranged by their resolved ordering
    sorted_objects = sorted(
        [obj for obj, _ in db_objects], key=lambda obj: ordering[obj]
    )

    def locate(predicate):
        # Index of the first sorted object matching the predicate
        return next(
            idx
            for idx, candidate in enumerate(sorted_objects)
            if predicate(candidate)
        )

    target_table_pos = locate(
        lambda obj: isinstance(obj, DBTable) and obj.table_name == "targetmodel"
    )
    source_table_pos = locate(
        lambda obj: isinstance(obj, DBTable) and obj.table_name == "sourcemodel"
    )
    target_column_pos = locate(
        lambda obj: isinstance(obj, DBColumn)
        and obj.table_name == "targetmodel"
        and obj.column_name == "id"
    )
    target_pk_pos = locate(
        lambda obj: isinstance(obj, DBConstraint)
        and obj.constraint_type == ConstraintType.PRIMARY_KEY
        and obj.table_name == "targetmodel"
    )
    fk_constraint_pos = locate(
        lambda obj: isinstance(obj, DBConstraint)
        and obj.constraint_type == ConstraintType.FOREIGN_KEY
        and obj.table_name == "sourcemodel"
    )

    # The foreign key must come after both tables, the referenced column, and
    # the referenced primary key.
    assert (
        target_table_pos < fk_constraint_pos
    ), "Foreign key constraint should be created after target table"
    assert (
        source_table_pos < fk_constraint_pos
    ), "Foreign key constraint should be created after source table"
    assert (
        target_column_pos < fk_constraint_pos
    ), "Foreign key constraint should be created after target column"
    assert (
        target_pk_pos < fk_constraint_pos
    ), "Foreign key constraint should be created after target primary key"

    # Verify the actual migration actions
    actor = DatabaseActions()
    actions = await migrator.build_actions(
        actor, [], {}, [obj for obj, _ in db_objects], ordering
    )

    # Extract the table-creation and foreign-key-constraint actions
    table_creations = [
        action
        for action in actions
        if isinstance(action, DryRunAction) and action.fn == actor.add_table
    ]
    fk_constraints = [
        action
        for action in actions
        if isinstance(action, DryRunAction)
        and action.fn == actor.add_constraint
        and action.kwargs.get("constraint") == ConstraintType.FOREIGN_KEY
    ]

    assert len(table_creations) == 2
    assert len(fk_constraints) == 1

    table_creation_indices = [
        i for i, action in enumerate(actions) if action in table_creations
    ]
    fk_constraint_indices = [
        i for i, action in enumerate(actions) if action in fk_constraints
    ]

    # Every table-creation action must precede every foreign-key action.
    assert all(
        table_idx < fk_idx
        for table_idx in table_creation_indices
        for fk_idx in fk_constraint_indices
    )

iceaxe/migrations/action_sorter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, graph: dict[DBObject, list[DBObject]]):
2020
self.in_degree = defaultdict(int)
2121
self.nodes = set(graph.keys())
2222

23-
for node, dependencies in graph.items():
23+
for node, dependencies in list(graph.items()):
2424
for dep in dependencies:
2525
self.in_degree[node] += 1
2626
if dep not in self.nodes:

iceaxe/migrations/cli.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
from iceaxe.base import DBModelMetaclass, TableBase
55
from iceaxe.io import resolve_package_path
66
from iceaxe.logging import CONSOLE
7-
from iceaxe.migrations.client_io import fetch_migrations, sort_migrations
8-
from iceaxe.migrations.generator import MigrationGenerator
9-
from iceaxe.migrations.migrator import Migrator
107
from iceaxe.schemas.db_serializer import DatabaseSerializer
118
from iceaxe.session import DBConnection
129

@@ -35,6 +32,11 @@ def generate_migration(message: str):
3532
handle_generate("my_project", db_connection, message=message)
3633
```
3734
"""
35+
# Any local imports must be done here to avoid circular imports because migrations.__init__
36+
# imports this file.
37+
from iceaxe.migrations.client_io import fetch_migrations
38+
from iceaxe.migrations.generator import MigrationGenerator
39+
from iceaxe.migrations.migrator import Migrator
3840

3941
CONSOLE.print("[bold blue]Generating migration to current schema")
4042

@@ -122,6 +124,8 @@ async def handle_apply(
122124
project that's specified in pyproject.toml or setup.py.
123125
124126
"""
127+
from iceaxe.migrations.client_io import fetch_migrations, sort_migrations
128+
from iceaxe.migrations.migrator import Migrator
125129

126130
migrations_path = resolve_package_path(package) / "migrations"
127131
if not migrations_path.exists():
@@ -178,6 +182,8 @@ async def handle_rollback(
178182
project that's specified in pyproject.toml or setup.py.
179183
180184
"""
185+
from iceaxe.migrations.client_io import fetch_migrations, sort_migrations
186+
from iceaxe.migrations.migrator import Migrator
181187

182188
migrations_path = resolve_package_path(package) / "migrations"
183189
if not migrations_path.exists():

iceaxe/schemas/db_memory_serializer.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22
from datetime import date, datetime, time, timedelta
33
from inspect import isgenerator
4-
from typing import Any, Generator, Sequence, Type, TypeVar
4+
from typing import Any, Generator, Sequence, Type, TypeVar, Union
55
from uuid import UUID
66

77
from pydantic_core import PydanticUndefined
@@ -29,9 +29,12 @@
2929
)
3030
from iceaxe.schemas.db_stubs import (
3131
DBColumn,
32+
DBColumnPointer,
3233
DBConstraint,
34+
DBConstraintPointer,
3335
DBObject,
3436
DBObjectPointer,
37+
DBPointerOr,
3538
DBTable,
3639
DBType,
3740
DBTypePointer,
@@ -44,11 +47,13 @@
4447
PRIMITIVE_WRAPPER_TYPES,
4548
)
4649

50+
NodeYieldType = Union[DBObject, DBObjectPointer, "NodeDefinition"]
51+
4752

4853
@dataclass
4954
class NodeDefinition:
5055
node: DBObject
51-
dependencies: list[DBObject]
56+
dependencies: list[DBObject | DBObjectPointer]
5257
force_no_dependencies: bool
5358

5459

@@ -110,18 +115,42 @@ def order_db_objects(
110115
for _, dependencies in db_objects:
111116
for dep in dependencies:
112117
if isinstance(dep, DBObjectPointer):
113-
if dep.representation() not in db_objects_by_name:
118+
if isinstance(dep, DBPointerOr):
119+
# For OR pointers, at least one of the pointers must be resolvable
120+
if not any(
121+
pointer.representation() in db_objects_by_name
122+
for pointer in dep.pointers
123+
):
124+
raise ValueError(
125+
f"None of the OR pointers {[p.representation() for p in dep.pointers]} found in the defined database objects"
126+
)
127+
elif dep.representation() not in db_objects_by_name:
114128
raise ValueError(
115129
f"Pointer {dep.representation()} not found in the defined database objects"
116130
)
117131

118132
# Map the potentially different objects to the same object
119-
graph_edges = {
120-
db_objects_by_name[obj.representation()]: [
121-
db_objects_by_name[dep.representation()] for dep in dependencies
122-
]
123-
for obj, dependencies in db_objects
124-
}
133+
graph_edges = {}
134+
for obj, dependencies in db_objects:
135+
resolved_deps = []
136+
for dep in dependencies:
137+
if isinstance(dep, DBObjectPointer):
138+
if isinstance(dep, DBPointerOr):
139+
# Add all resolvable pointers as dependencies
140+
resolved_deps.extend(
141+
db_objects_by_name[pointer.representation()]
142+
for pointer in dep.pointers
143+
if pointer.representation() in db_objects_by_name
144+
)
145+
else:
146+
resolved_deps.append(db_objects_by_name[dep.representation()])
147+
else:
148+
resolved_deps.append(dep)
149+
150+
if isinstance(obj, DBObjectPointer):
151+
continue
152+
153+
graph_edges[db_objects_by_name[obj.representation()]] = resolved_deps
125154

126155
# Construct the directed acyclic graph
127156
ts = ActionTopologicalSorter(graph_edges)
@@ -217,9 +246,6 @@ def migrate(self, previous, actor: DatabaseActions):
217246
raise NotImplementedError()
218247

219248

220-
NodeYieldType = DBObject | NodeDefinition
221-
222-
223249
class DatabaseHandler:
224250
def __init__(self):
225251
self.python_to_sql = {
@@ -430,7 +456,32 @@ def _build_constraint(
430456
target_table=target_table,
431457
target_columns=frozenset({target_column}),
432458
),
433-
)
459+
),
460+
dependencies=[
461+
# Additional dependencies to ensure the target table/column is created first
462+
DBTable(table_name=target_table),
463+
DBColumnPointer(
464+
table_name=target_table,
465+
column_name=target_column,
466+
),
467+
# Ensure the primary key constraint exists before the foreign key
468+
# constraint. Postgres also accepts a unique constraint on the same.
469+
DBPointerOr(
470+
pointers=tuple(
471+
[
472+
DBConstraintPointer(
473+
table_name=target_table,
474+
columns=frozenset([target_column]),
475+
constraint_type=constraint_type,
476+
)
477+
for constraint_type in [
478+
ConstraintType.PRIMARY_KEY,
479+
ConstraintType.UNIQUE,
480+
]
481+
]
482+
),
483+
),
484+
],
434485
)
435486

436487
if info.index:
@@ -509,10 +560,10 @@ def _yield_nodes(
509560
"""
510561

511562
def _format_dependencies(dependencies: Sequence[NodeYieldType]):
512-
all_dependencies: list[DBObject] = []
563+
all_dependencies: list[DBObject | DBObjectPointer] = []
513564

514565
for value in dependencies:
515-
if isinstance(value, DBObject):
566+
if isinstance(value, (DBObject, DBObjectPointer)):
516567
all_dependencies.append(value)
517568
elif isinstance(value, NodeDefinition):
518569
all_dependencies.append(value.node)

iceaxe/schemas/db_stubs.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,32 @@ def merge(self, other: "DBType") -> "DBType":
374374
values=self.values,
375375
reference_columns=self.reference_columns | other.reference_columns,
376376
)
377+
378+
379+
class DBConstraintPointer(DBObjectPointer):
    """
    A lightweight reference to a constraint that will exist after migration.
    Lets callers declare a dependency on a constraint without knowing its
    full definition.
    """

    # Table the referenced constraint belongs to
    table_name: str
    # Columns covered by the referenced constraint
    columns: frozenset[str]
    # Kind of constraint being pointed at (PK, FK, unique, ...)
    constraint_type: ConstraintType

    def representation(self) -> str:
        # Must mirror DBConstraint's representation so the pointer and the
        # concrete constraint resolve to the same dependency key.
        ordered_columns = sorted(self.columns)
        return "{}.{}.{}".format(
            self.table_name, ordered_columns, self.constraint_type
        )
392+
393+
394+
class DBPointerOr(DBObjectPointer):
    """
    A disjunction of pointers: the dependency is satisfied when any one of
    the wrapped pointers can be resolved.
    """

    # The alternative pointers, any of which satisfies the dependency
    pointers: tuple[DBObjectPointer, ...]

    def representation(self) -> str:
        # Sort the member representations so equal pointer sets always yield
        # an identical key, regardless of declaration order.
        member_keys = sorted(child.representation() for child in self.pointers)
        return f"OR({','.join(member_keys)})"

0 commit comments

Comments
 (0)