Skip to content

Commit 6fb7180

Browse files
committed
Fix on_conflict for models with unique_together + tests
1 parent b08de63 commit 6fb7180

File tree

4 files changed

+41
-8
lines changed

4 files changed

+41
-8
lines changed

psqlextra/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _rewrite_insert_nothing(self, sql, params, returning):
124124
# for conflicts
125125
conflict_target = self._build_conflict_target()
126126

127-
where_clause = ', '.join([
127+
where_clause = ' AND '.join([
128128
'{0} = %s'.format(self._format_field_name(field_name))
129129
for field_name in self.query.conflict_target
130130
])
@@ -143,7 +143,7 @@ def _rewrite_insert_nothing(self, sql, params, returning):
143143
return (
144144
(
145145
'WITH insdata AS ('
146-
'{insert} ON CONFLICT ({conflict_target}) DO UPDATE'
146+
'{insert} ON CONFLICT {conflict_target} DO UPDATE'
147147
' SET id = NULL WHERE FALSE RETURNING {returning})'
148148
' SELECT * FROM insdata UNION ALL'
149149
' SELECT {returning} FROM {table} WHERE {where_clause} LIMIT 1;'

tests/fake_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
from psqlextra.models import PostgresModel
88

99

10-
def define_fake_model(fields=None, model_base=PostgresModel):
10+
def define_fake_model(fields=None, model_base=PostgresModel, meta_options={}):
1111
name = str(uuid.uuid4()).replace('-', '')[:8]
1212

1313
attributes = {
1414
'app_label': 'tests',
1515
'__module__': __name__,
16-
'__name__': name
16+
'__name__': name,
17+
'Meta': type('Meta', (object,), meta_options)
1718
}
1819

1920
if fields:
@@ -23,10 +24,10 @@ def define_fake_model(fields=None, model_base=PostgresModel):
2324
return model
2425

2526

26-
def get_fake_model(fields=None, model_base=PostgresModel):
27+
def get_fake_model(fields=None, model_base=PostgresModel, meta_options={}):
2728
"""Creates a fake model to use during unit tests."""
2829

29-
model = define_fake_model(fields, model_base)
30+
model = define_fake_model(fields, model_base, meta_options)
3031

3132
class TestProject:
3233

tests/test_on_conflict.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from psqlextra import HStoreField
66
from psqlextra.query import ConflictAction
7+
from psqlextra.models import PostgresModel
78

89
from .fake_model import get_fake_model
910

@@ -215,6 +216,7 @@ def test_on_conflict_outdated_model(conflict_action):
215216
.insert_and_get(title='beer')
216217
)
217218

219+
218220
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
219221
def test_on_conflict_custom_column_names(conflict_action):
220222
"""Asserts that models with custom column names (models
@@ -225,8 +227,39 @@ def test_on_conflict_custom_column_names(conflict_action):
225227
'description': models.CharField(max_length=255, db_column='desc')
226228
})
227229

228-
id = (
230+
(
229231
model.objects
230232
.on_conflict(['title'], conflict_action)
231233
.insert(title='yeey', description='great thing')
232234
)
235+
236+
237+
@pytest.mark.parametrize("conflict_action", CONFLICT_ACTIONS)
238+
def test_on_conflict_unique_together(conflict_action):
239+
"""Asserts that inserts on models with a unique_together
240+
works properly."""
241+
242+
model = get_fake_model(
243+
{
244+
'first_name': models.CharField(max_length=140),
245+
'last_name': models.CharField(max_length=255)
246+
},
247+
PostgresModel,
248+
{
249+
'unique_together': ('first_name', 'last_name')
250+
}
251+
)
252+
253+
id1 = (
254+
model.objects
255+
.on_conflict(['first_name', 'last_name'], conflict_action)
256+
.insert(first_name='swen', last_name='kooij')
257+
)
258+
259+
id2 = (
260+
model.objects
261+
.on_conflict(['first_name', 'last_name'], conflict_action)
262+
.insert(first_name='swen', last_name='kooij')
263+
)
264+
265+
assert id1 == id2

tests/test_upsert.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from django.db import models
22

33
from psqlextra import HStoreField
4-
from psqlextra.query import ConflictAction
54

65
from .fake_model import get_fake_model
76

0 commit comments

Comments
 (0)