5
5
import sys
6
6
import unittest .mock
7
7
import warnings
8
+ from dataclasses import dataclass
9
+ from dataclasses import field
8
10
from typing import Any
9
11
from typing import Callable
10
12
from typing import Dict
11
13
from typing import Generator
12
14
from typing import Iterable
15
+ from typing import Iterator
13
16
from typing import List
14
17
from typing import Mapping
15
18
from typing import Optional
@@ -43,16 +46,55 @@ class PytestMockWarning(UserWarning):
43
46
"""Base class for all warnings emitted by pytest-mock."""
44
47
45
48
49
+ @dataclass
50
+ class MockCacheItem :
51
+ mock : MockType
52
+ patch : Optional [Any ] = None
53
+
54
+
55
+ @dataclass
56
+ class MockCache :
57
+ cache : List [MockCacheItem ] = field (default_factory = list )
58
+
59
+ def find (self , mock : MockType ) -> MockCacheItem :
60
+ the_mock = next (
61
+ (mock_item for mock_item in self .cache if mock_item .mock == mock ), None
62
+ )
63
+ if the_mock is None :
64
+ raise ValueError ("This mock object is not registered" )
65
+ return the_mock
66
+
67
+ def add (self , mock : MockType , ** kwargs : Any ) -> MockCacheItem :
68
+ try :
69
+ return self .find (mock )
70
+ except ValueError :
71
+ self .cache .append (MockCacheItem (mock = mock , ** kwargs ))
72
+ return self .cache [- 1 ]
73
+
74
+ def remove (self , mock : MockType ) -> None :
75
+ mock_item = self .find (mock )
76
+ self .cache .remove (mock_item )
77
+
78
+ def clear (self ) -> None :
79
+ self .cache .clear ()
80
+
81
+ def __iter__ (self ) -> Iterator [MockCacheItem ]:
82
+ return iter (self .cache )
83
+
84
+ def __reversed__ (self ) -> Iterator [MockCacheItem ]:
85
+ return reversed (self .cache )
86
+
87
+
46
88
class MockerFixture :
47
89
"""
48
90
Fixture that provides the same interface to functions in the mock module,
49
91
ensuring that they are uninstalled at the end of each test.
50
92
"""
51
93
52
94
def __init__ (self , config : Any ) -> None :
53
- self ._patches_and_mocks : List [ Tuple [ Any , unittest . mock . MagicMock ]] = []
95
+ self ._mock_cache : MockCache = MockCache ()
54
96
self .mock_module = mock_module = get_mock_module (config )
55
- self .patch = self ._Patcher (self ._patches_and_mocks , mock_module ) # type: MockerFixture._Patcher
97
+ self .patch = self ._Patcher (self ._mock_cache , mock_module ) # type: MockerFixture._Patcher
56
98
# aliases for convenience
57
99
self .Mock = mock_module .Mock
58
100
self .MagicMock = mock_module .MagicMock
@@ -75,7 +117,7 @@ def create_autospec(
75
117
m : MockType = self .mock_module .create_autospec (
76
118
spec , spec_set , instance , ** kwargs
77
119
)
78
- self ._patches_and_mocks . append (( None , m ) )
120
+ self ._mock_cache . add ( m )
79
121
return m
80
122
81
123
def resetall (
@@ -93,37 +135,39 @@ def resetall(
93
135
else :
94
136
supports_reset_mock_with_args = (self .Mock ,)
95
137
96
- for p , m in self ._patches_and_mocks :
138
+ for mock_item in self ._mock_cache :
97
139
# See issue #237.
98
- if not hasattr (m , "reset_mock" ):
140
+ if not hasattr (mock_item . mock , "reset_mock" ):
99
141
continue
100
- if isinstance (m , supports_reset_mock_with_args ):
101
- m .reset_mock (return_value = return_value , side_effect = side_effect )
142
+ # NOTE: The mock may be a dictionary
143
+ if hasattr (mock_item .mock , "spy_return_list" ):
144
+ mock_item .mock .spy_return_list = []
145
+ if isinstance (mock_item .mock , supports_reset_mock_with_args ):
146
+ mock_item .mock .reset_mock (
147
+ return_value = return_value , side_effect = side_effect
148
+ )
102
149
else :
103
- m .reset_mock ()
150
+ mock_item . mock .reset_mock ()
104
151
105
152
def stopall (self ) -> None :
106
153
"""
107
154
Stop all patchers started by this fixture. Can be safely called multiple
108
155
times.
109
156
"""
110
- for p , m in reversed (self ._patches_and_mocks ):
111
- if p is not None :
112
- p .stop ()
113
- self ._patches_and_mocks .clear ()
157
+ for mock_item in reversed (self ._mock_cache ):
158
+ if mock_item . patch is not None :
159
+ mock_item . patch .stop ()
160
+ self ._mock_cache .clear ()
114
161
115
162
def stop (self , mock : unittest .mock .MagicMock ) -> None :
116
163
"""
117
164
Stops a previous patch or spy call by passing the ``MagicMock`` object
118
165
returned by it.
119
166
"""
120
- for index , (p , m ) in enumerate (self ._patches_and_mocks ):
121
- if mock is m :
122
- p .stop ()
123
- del self ._patches_and_mocks [index ]
124
- break
125
- else :
126
- raise ValueError ("This mock object is not registered" )
167
+ mock_item = self ._mock_cache .find (mock )
168
+ if mock_item .patch :
169
+ mock_item .patch .stop ()
170
+ self ._mock_cache .remove (mock )
127
171
128
172
def spy (self , obj : object , name : str ) -> MockType :
129
173
"""
@@ -146,6 +190,7 @@ def wrapper(*args, **kwargs):
146
190
raise
147
191
else :
148
192
spy_obj .spy_return = r
193
+ spy_obj .spy_return_list .append (r )
149
194
return r
150
195
151
196
async def async_wrapper (* args , ** kwargs ):
@@ -158,6 +203,7 @@ async def async_wrapper(*args, **kwargs):
158
203
raise
159
204
else :
160
205
spy_obj .spy_return = r
206
+ spy_obj .spy_return_list .append (r )
161
207
return r
162
208
163
209
if asyncio .iscoroutinefunction (method ):
@@ -169,6 +215,7 @@ async def async_wrapper(*args, **kwargs):
169
215
170
216
spy_obj = self .patch .object (obj , name , side_effect = wrapped , autospec = autospec )
171
217
spy_obj .spy_return = None
218
+ spy_obj .spy_return_list = []
172
219
spy_obj .spy_exception = None
173
220
return spy_obj
174
221
@@ -206,8 +253,8 @@ class _Patcher:
206
253
207
254
DEFAULT = object ()
208
255
209
- def __init__ (self , patches_and_mocks , mock_module ):
210
- self .__patches_and_mocks = patches_and_mocks
256
+ def __init__ (self , mock_cache , mock_module ):
257
+ self .__mock_cache = mock_cache
211
258
self .mock_module = mock_module
212
259
213
260
def _start_patch (
@@ -219,7 +266,7 @@ def _start_patch(
219
266
"""
220
267
p = mock_func (* args , ** kwargs )
221
268
mocked : MockType = p .start ()
222
- self .__patches_and_mocks . append (( p , mocked ) )
269
+ self .__mock_cache . add ( mock = mocked , patch = p )
223
270
if hasattr (mocked , "reset_mock" ):
224
271
# check if `mocked` is actually a mock object, as depending on autospec or target
225
272
# parameters `mocked` can be anything
0 commit comments