Skip to content

Commit

Permalink
Fix delete on instances using nested ModelMocker managers (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfantone authored Sep 11, 2024
1 parent 8736200 commit 978e030
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 45 deletions.
65 changes: 20 additions & 45 deletions django_mock_queries/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from types import MethodType

from .constants import DjangoModelDeletionCollector, DjangoDbRouter
from .query import MockSet

# noinspection PyUnresolvedReferences
Expand Down Expand Up @@ -330,10 +329,6 @@ class Mocker:
A decorator that patches multiple class methods with a magic mock instance that does nothing.
"""

shared_mocks = {}
shared_patchers = {}
shared_original = {}

def __init__(self, cls, *methods, **kwargs):
self.cls = cls
self.methods = methods
Expand All @@ -342,8 +337,6 @@ def __init__(self, cls, *methods, **kwargs):
self.inst_patchers = {}
self.inst_original = {}

self.outer = kwargs.get('outer', True)

def __enter__(self):
self._patch_object_methods(self.cls, *self.methods)
return self
Expand All @@ -359,11 +352,8 @@ def decorated(*args, **kwargs):
return decorated

def __exit__(self, exc_type, exc_val, exc_tb):
for key, patcher in self.inst_patchers.items():
for patcher in self.inst_patchers.values():
patcher.stop()
if self.outer:
for key, patcher in self.shared_patchers.items():
patcher.stop()

def _key(self, method, obj=None):
return '{}.{}'.format(obj or self.cls, method)
Expand All @@ -374,10 +364,10 @@ def _method_obj(self, name, obj, *sources):
return d[self._key(name, obj=obj)]

def method(self, name, obj=None):
return self._method_obj(name, obj, self.shared_mocks, self.inst_mocks)
return self._method_obj(name, obj, self.inst_mocks)

def original_method(self, name, obj=None):
return self._method_obj(name, obj, self.shared_original, self.inst_original)
return self._method_obj(name, obj, self.inst_original)

def _get_source_method(self, obj, method):
source_obj = obj
Expand Down Expand Up @@ -406,28 +396,28 @@ def _patch_method(self, method_name, source_obj, source_method):
return patch_object(source_obj, source_method, **mock_args)

def _patch_object_methods(self, obj, *methods, **kwargs):
if kwargs.get('shared', False):
original, patchers, mocks = self.shared_original, self.shared_patchers, self.shared_mocks
else:
original, patchers, mocks = self.inst_original, self.inst_patchers, self.inst_mocks
original, patchers, mocks = self.inst_original, self.inst_patchers, self.inst_mocks

for method in methods:
key = self._key(method, obj=obj)

source_obj, source_method = self._get_source_method(obj, method)
original[key] = original.get(key, None) or getattr(source_obj, source_method)

patcher = self._patch_method(method, source_obj, source_method)
patchers[key] = patcher
mocks[key] = patcher.start()
if key not in original:
original[key] = getattr(source_obj, source_method)

if key not in patchers:
patcher = self._patch_method(method, source_obj, source_method)
patchers[key] = patcher
mocks[key] = patcher.start()


class ModelMocker(Mocker):
"""
A decorator that patches django base model's db read/write methods and wires them to a MockSet.
"""

default_methods = ['objects', '_do_update']
default_methods = ['objects', '_do_update', 'delete']

if django.VERSION[0] >= 3:
default_methods += ['_base_manager._insert', ]
Expand All @@ -442,11 +432,8 @@ def __init__(self, cls, *methods, **kwargs):
self.objects = MockSet(model=self.cls)
self.objects.on('added', self._on_added)

self.state = {}

def __enter__(self):
result = super(ModelMocker, self).__enter__()
self._patch_object_methods(DjangoModelDeletionCollector, 'collect', 'delete', shared=True)
return result

def _obj_pk(self, obj):
Expand Down Expand Up @@ -482,23 +469,11 @@ def _do_update(self, *args, **_):
else:
return False

def collect(self, objects, *args, **kwargs):
model = getattr(objects, 'model', None) or objects[0]

if not (model is self.cls or isinstance(model, self.cls)):
using = getattr(objects, 'db', None) or DjangoDbRouter.db_for_write(model._meta.model, instance=model)
self.state['collector'] = DjangoModelDeletionCollector(using=using)

collect = self.original_method('collect', obj=DjangoModelDeletionCollector)
collect(self.state['collector'], objects, *args, **kwargs)

self.state['model'] = model

def delete(self, *args, **kwargs):
model = self.state.pop('model')

if not (model is self.cls or isinstance(model, self.cls)):
delete = self.original_method('delete', obj=DjangoModelDeletionCollector)
return delete(self.state.pop('collector'), *args, **kwargs)
else:
return self.objects.filter(pk=getattr(model, self.cls._meta.pk.attname)).delete()
def delete(self, *_args, **_kwargs):
pk = self._obj_pk(self.objects[0])
if not pk:
raise ValueError(
f"{self.cls._meta.object_name} object can't be deleted because "
f'its {self.cls._meta.pk.attname} attribute is set to None.'
)
return self.objects.filter(pk=pk).delete()
21 changes: 21 additions & 0 deletions tests/test_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,27 @@ def car_added(obj):
self.assertIsInstance(objects['added'], Car)
self.assertEqual(objects['added'], objects['car'])

def test_model_mocker_delete_from_instance_with_nested_context_manager(self):

def create_delete_models():
car = Car.objects.create(speed=10)
car.delete()

manufacturer = Manufacturer.objects.create(name='foo')
manufacturer.delete()

def models_exist():
return Manufacturer.objects.exists() or Car.objects.exists()

with ModelMocker(Manufacturer), ModelMocker(Car):
create_delete_models()
assert not models_exist()

# Test same scenario with reversed context manager order
with ModelMocker(Car), ModelMocker(Manufacturer):
create_delete_models()
assert not models_exist()

def test_model_mocker_event_updated_from_manager(self):
objects = {}

Expand Down

0 comments on commit 978e030

Please sign in to comment.