13
13
14
14
from types import MethodType
15
15
16
- from .constants import DjangoModelDeletionCollector , DjangoDbRouter
17
16
from .query import MockSet
18
17
19
18
# noinspection PyUnresolvedReferences
@@ -330,10 +329,6 @@ class Mocker:
330
329
A decorator that patches multiple class methods with a magic mock instance that does nothing.
331
330
"""
332
331
333
- shared_mocks = {}
334
- shared_patchers = {}
335
- shared_original = {}
336
-
337
332
def __init__ (self , cls , * methods , ** kwargs ):
338
333
self .cls = cls
339
334
self .methods = methods
@@ -342,8 +337,6 @@ def __init__(self, cls, *methods, **kwargs):
342
337
self .inst_patchers = {}
343
338
self .inst_original = {}
344
339
345
- self .outer = kwargs .get ('outer' , True )
346
-
347
340
def __enter__ (self ):
348
341
self ._patch_object_methods (self .cls , * self .methods )
349
342
return self
@@ -359,11 +352,8 @@ def decorated(*args, **kwargs):
359
352
return decorated
360
353
361
354
def __exit__ (self , exc_type , exc_val , exc_tb ):
362
- for key , patcher in self .inst_patchers .items ():
355
+ for patcher in self .inst_patchers .values ():
363
356
patcher .stop ()
364
- if self .outer :
365
- for key , patcher in self .shared_patchers .items ():
366
- patcher .stop ()
367
357
368
358
def _key (self , method , obj = None ):
369
359
return '{}.{}' .format (obj or self .cls , method )
@@ -374,10 +364,10 @@ def _method_obj(self, name, obj, *sources):
374
364
return d [self ._key (name , obj = obj )]
375
365
376
366
def method (self , name , obj = None ):
377
- return self ._method_obj (name , obj , self .shared_mocks , self . inst_mocks )
367
+ return self ._method_obj (name , obj , self .inst_mocks )
378
368
379
369
def original_method (self , name , obj = None ):
380
- return self ._method_obj (name , obj , self .shared_original , self . inst_original )
370
+ return self ._method_obj (name , obj , self .inst_original )
381
371
382
372
def _get_source_method (self , obj , method ):
383
373
source_obj = obj
@@ -406,28 +396,28 @@ def _patch_method(self, method_name, source_obj, source_method):
406
396
return patch_object (source_obj , source_method , ** mock_args )
407
397
408
398
def _patch_object_methods (self , obj , * methods , ** kwargs ):
409
- if kwargs .get ('shared' , False ):
410
- original , patchers , mocks = self .shared_original , self .shared_patchers , self .shared_mocks
411
- else :
412
- original , patchers , mocks = self .inst_original , self .inst_patchers , self .inst_mocks
399
+ original , patchers , mocks = self .inst_original , self .inst_patchers , self .inst_mocks
413
400
414
401
for method in methods :
415
402
key = self ._key (method , obj = obj )
416
403
417
404
source_obj , source_method = self ._get_source_method (obj , method )
418
- original [key ] = original .get (key , None ) or getattr (source_obj , source_method )
419
405
420
- patcher = self ._patch_method (method , source_obj , source_method )
421
- patchers [key ] = patcher
422
- mocks [key ] = patcher .start ()
406
+ if key not in original :
407
+ original [key ] = getattr (source_obj , source_method )
408
+
409
+ if key not in patchers :
410
+ patcher = self ._patch_method (method , source_obj , source_method )
411
+ patchers [key ] = patcher
412
+ mocks [key ] = patcher .start ()
423
413
424
414
425
415
class ModelMocker (Mocker ):
426
416
"""
427
417
A decorator that patches django base model's db read/write methods and wires them to a MockSet.
428
418
"""
429
419
430
- default_methods = ['objects' , '_do_update' ]
420
+ default_methods = ['objects' , '_do_update' , 'delete' ]
431
421
432
422
if django .VERSION [0 ] >= 3 :
433
423
default_methods += ['_base_manager._insert' , ]
@@ -442,11 +432,8 @@ def __init__(self, cls, *methods, **kwargs):
442
432
self .objects = MockSet (model = self .cls )
443
433
self .objects .on ('added' , self ._on_added )
444
434
445
- self .state = {}
446
-
447
435
def __enter__ (self ):
448
436
result = super (ModelMocker , self ).__enter__ ()
449
- self ._patch_object_methods (DjangoModelDeletionCollector , 'collect' , 'delete' , shared = True )
450
437
return result
451
438
452
439
def _obj_pk (self , obj ):
@@ -482,23 +469,11 @@ def _do_update(self, *args, **_):
482
469
else :
483
470
return False
484
471
485
- def collect (self , objects , * args , ** kwargs ):
486
- model = getattr (objects , 'model' , None ) or objects [0 ]
487
-
488
- if not (model is self .cls or isinstance (model , self .cls )):
489
- using = getattr (objects , 'db' , None ) or DjangoDbRouter .db_for_write (model ._meta .model , instance = model )
490
- self .state ['collector' ] = DjangoModelDeletionCollector (using = using )
491
-
492
- collect = self .original_method ('collect' , obj = DjangoModelDeletionCollector )
493
- collect (self .state ['collector' ], objects , * args , ** kwargs )
494
-
495
- self .state ['model' ] = model
496
-
497
- def delete (self , * args , ** kwargs ):
498
- model = self .state .pop ('model' )
499
-
500
- if not (model is self .cls or isinstance (model , self .cls )):
501
- delete = self .original_method ('delete' , obj = DjangoModelDeletionCollector )
502
- return delete (self .state .pop ('collector' ), * args , ** kwargs )
503
- else :
504
- return self .objects .filter (pk = getattr (model , self .cls ._meta .pk .attname )).delete ()
472
+ def delete (self , * _args , ** _kwargs ):
473
+ pk = self ._obj_pk (self .objects [0 ])
474
+ if not pk :
475
+ raise ValueError (
476
+ f"{ self .cls ._meta .object_name } object can't be deleted because "
477
+ f'its { self .cls ._meta .pk .attname } attribute is set to None.'
478
+ )
479
+ return self .objects .filter (pk = pk ).delete ()
0 commit comments