Skip to content

Commit 1b364eb

Browse files
author
Nicolás Fantone
committed
Fix delete on instances using nested ModelMocker managers
1 parent 8736200 commit 1b364eb

File tree

2 files changed

+41
-45
lines changed

2 files changed

+41
-45
lines changed

django_mock_queries/mocks.py

Lines changed: 20 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from types import MethodType
1515

16-
from .constants import DjangoModelDeletionCollector, DjangoDbRouter
1716
from .query import MockSet
1817

1918
# noinspection PyUnresolvedReferences
@@ -330,10 +329,6 @@ class Mocker:
330329
A decorator that patches multiple class methods with a magic mock instance that does nothing.
331330
"""
332331

333-
shared_mocks = {}
334-
shared_patchers = {}
335-
shared_original = {}
336-
337332
def __init__(self, cls, *methods, **kwargs):
338333
self.cls = cls
339334
self.methods = methods
@@ -342,8 +337,6 @@ def __init__(self, cls, *methods, **kwargs):
342337
self.inst_patchers = {}
343338
self.inst_original = {}
344339

345-
self.outer = kwargs.get('outer', True)
346-
347340
def __enter__(self):
348341
self._patch_object_methods(self.cls, *self.methods)
349342
return self
@@ -359,11 +352,8 @@ def decorated(*args, **kwargs):
359352
return decorated
360353

361354
def __exit__(self, exc_type, exc_val, exc_tb):
362-
for key, patcher in self.inst_patchers.items():
355+
for patcher in self.inst_patchers.values():
363356
patcher.stop()
364-
if self.outer:
365-
for key, patcher in self.shared_patchers.items():
366-
patcher.stop()
367357

368358
def _key(self, method, obj=None):
369359
return '{}.{}'.format(obj or self.cls, method)
@@ -374,10 +364,10 @@ def _method_obj(self, name, obj, *sources):
374364
return d[self._key(name, obj=obj)]
375365

376366
def method(self, name, obj=None):
377-
return self._method_obj(name, obj, self.shared_mocks, self.inst_mocks)
367+
return self._method_obj(name, obj, self.inst_mocks)
378368

379369
def original_method(self, name, obj=None):
380-
return self._method_obj(name, obj, self.shared_original, self.inst_original)
370+
return self._method_obj(name, obj, self.inst_original)
381371

382372
def _get_source_method(self, obj, method):
383373
source_obj = obj
@@ -406,28 +396,28 @@ def _patch_method(self, method_name, source_obj, source_method):
406396
return patch_object(source_obj, source_method, **mock_args)
407397

408398
def _patch_object_methods(self, obj, *methods, **kwargs):
409-
if kwargs.get('shared', False):
410-
original, patchers, mocks = self.shared_original, self.shared_patchers, self.shared_mocks
411-
else:
412-
original, patchers, mocks = self.inst_original, self.inst_patchers, self.inst_mocks
399+
original, patchers, mocks = self.inst_original, self.inst_patchers, self.inst_mocks
413400

414401
for method in methods:
415402
key = self._key(method, obj=obj)
416403

417404
source_obj, source_method = self._get_source_method(obj, method)
418-
original[key] = original.get(key, None) or getattr(source_obj, source_method)
419405

420-
patcher = self._patch_method(method, source_obj, source_method)
421-
patchers[key] = patcher
422-
mocks[key] = patcher.start()
406+
if key not in original:
407+
original[key] = getattr(source_obj, source_method)
408+
409+
if key not in patchers:
410+
patcher = self._patch_method(method, source_obj, source_method)
411+
patchers[key] = patcher
412+
mocks[key] = patcher.start()
423413

424414

425415
class ModelMocker(Mocker):
426416
"""
427417
A decorator that patches django base model's db read/write methods and wires them to a MockSet.
428418
"""
429419

430-
default_methods = ['objects', '_do_update']
420+
default_methods = ['objects', '_do_update', 'delete']
431421

432422
if django.VERSION[0] >= 3:
433423
default_methods += ['_base_manager._insert', ]
@@ -442,11 +432,8 @@ def __init__(self, cls, *methods, **kwargs):
442432
self.objects = MockSet(model=self.cls)
443433
self.objects.on('added', self._on_added)
444434

445-
self.state = {}
446-
447435
def __enter__(self):
448436
result = super(ModelMocker, self).__enter__()
449-
self._patch_object_methods(DjangoModelDeletionCollector, 'collect', 'delete', shared=True)
450437
return result
451438

452439
def _obj_pk(self, obj):
@@ -482,23 +469,11 @@ def _do_update(self, *args, **_):
482469
else:
483470
return False
484471

485-
def collect(self, objects, *args, **kwargs):
486-
model = getattr(objects, 'model', None) or objects[0]
487-
488-
if not (model is self.cls or isinstance(model, self.cls)):
489-
using = getattr(objects, 'db', None) or DjangoDbRouter.db_for_write(model._meta.model, instance=model)
490-
self.state['collector'] = DjangoModelDeletionCollector(using=using)
491-
492-
collect = self.original_method('collect', obj=DjangoModelDeletionCollector)
493-
collect(self.state['collector'], objects, *args, **kwargs)
494-
495-
self.state['model'] = model
496-
497-
def delete(self, *args, **kwargs):
498-
model = self.state.pop('model')
499-
500-
if not (model is self.cls or isinstance(model, self.cls)):
501-
delete = self.original_method('delete', obj=DjangoModelDeletionCollector)
502-
return delete(self.state.pop('collector'), *args, **kwargs)
503-
else:
504-
return self.objects.filter(pk=getattr(model, self.cls._meta.pk.attname)).delete()
472+
def delete(self, *_args, **_kwargs):
473+
pk = self._obj_pk(self.objects[0])
474+
if not pk:
475+
raise ValueError(
476+
f"{self.cls._meta.object_name} object can't be deleted because "
477+
f'its {self.cls._meta.pk.attname} attribute is set to None.'
478+
)
479+
return self.objects.filter(pk=pk).delete()

tests/test_mocks.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,27 @@ def car_added(obj):
486486
self.assertIsInstance(objects['added'], Car)
487487
self.assertEqual(objects['added'], objects['car'])
488488

489+
def test_model_mocker_delete_from_instance_with_nested_context_manager(self):
490+
491+
def create_delete_models():
492+
car = Car.objects.create(speed=10)
493+
car.delete()
494+
495+
manufacturer = Manufacturer.objects.create(name='foo')
496+
manufacturer.delete()
497+
498+
def models_exist():
499+
return Manufacturer.objects.exists() or Car.objects.exists()
500+
501+
with ModelMocker(Manufacturer), ModelMocker(Car):
502+
create_delete_models()
503+
assert not models_exist()
504+
505+
# Test same scenario with reversed context manager order
506+
with ModelMocker(Car), ModelMocker(Manufacturer):
507+
create_delete_models()
508+
assert not models_exist()
509+
489510
def test_model_mocker_event_updated_from_manager(self):
490511
objects = {}
491512

0 commit comments

Comments
 (0)