Skip to content

Commit b18709b

Browse files
committed
Move handling of field name normalisation into a method
1 parent 601386c commit b18709b

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

psqlextra/compiler.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,7 @@ def _build_conflict_target(self):
171171
) % str(self.query.conflict_target))
172172

173173
def _assert_valid_field(field_name):
174-
if isinstance(field_name, tuple):
175-
field_name, _ = field_name
176-
174+
field_name = self._normalize_field_name(field_name)
177175
if self._get_model_field(field_name):
178176
return
179177

@@ -215,9 +213,7 @@ def _get_model_field(self, name: str):
215213
no such field exists.
216214
"""
217215

218-
field_name = name
219-
if isinstance(field_name, tuple):
220-
field_name = field_name[0]
216+
field_name = self._normalize_field_name(name)
221217

222218
for field in self.query.model._meta.local_concrete_fields:
223219
if field.name == field_name or field.column == field_name:
@@ -253,12 +249,29 @@ def _format_field_value(self, field_name) -> str:
253249
in SQL.
254250
"""
255251

256-
if isinstance(field_name, tuple):
257-
field_name, _ = field_name
258-
252+
field_name = self._normalize_field_name(field_name)
259253
field = self._get_model_field(field_name)
254+
260255
return SQLInsertCompiler.prepare_value(
261256
self,
262257
field,
263258
getattr(self.query.objs[0], field_name)
264259
)
260+
261+
def _normalize_field_name(self, field_name) -> str:
262+
"""Normalizes a field name into a string by
263+
extracting the field name if it was specified
264+
as a reference to a HStore key (as a tuple).
265+
266+
Arguments:
267+
field_name:
268+
The field name to normalize.
269+
270+
Returns:
271+
The normalized field name.
272+
"""
273+
274+
if isinstance(field_name, tuple):
275+
field_name, _ = field_name
276+
277+
return field_name

tests/test_on_conflict.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,9 @@ def test_on_conflict_unique_together2(conflict_action):
270270
"""Asserts that inserts on models with a unique_together
271271
works properly."""
272272

273-
model = get_fake_model(
274-
{
275-
'name': models.CharField(max_length=140),
276-
},
277-
)
273+
model = get_fake_model({
274+
'name': models.CharField(max_length=140)
275+
})
278276

279277
model2 = get_fake_model(
280278
{
@@ -294,14 +292,14 @@ def test_on_conflict_unique_together2(conflict_action):
294292

295293
id3 = (
296294
model2.objects
297-
.on_conflict(['model1_id', 'model2_id'], conflict_action)
298-
.insert(model1_id=id1, model2_id=id2)
295+
.on_conflict(['model1_id', 'model2_id'], conflict_action)
296+
.insert(model1_id=id1, model2_id=id2)
299297
)
300298

301299
id4 = (
302300
model2.objects
303-
.on_conflict(['model1_id', 'model2_id'], conflict_action)
304-
.insert(model1_id=id1, model2_id=id2)
301+
.on_conflict(['model1_id', 'model2_id'], conflict_action)
302+
.insert(model1_id=id1, model2_id=id2)
305303
)
306304

307305
assert id3 == id4

0 commit comments

Comments
 (0)