19
19
20
20
from typing import Callable , Optional , Type , Union
21
21
22
+ from mypy .nodes import ARG_POS , Decorator , MemberExpr
22
23
from mypy .plugin import FunctionContext , MethodContext , MethodSigContext , Plugin
23
24
from mypy .typeops import bind_self
24
25
from mypy .types import AnyType , CallableType , Instance
@@ -46,15 +47,16 @@ class _AdjustArguments(object):
46
47
"""
47
48
48
49
def __call__ (self , ctx : FunctionContext ) -> MypyType :
50
+ defn = ctx .arg_types [0 ][0 ]
49
51
is_defined_by_class = (
50
- isinstance (ctx . arg_types [ 0 ][ 0 ] , CallableType ) and
51
- not ctx . arg_types [ 0 ][ 0 ] .arg_types and
52
- isinstance (ctx . arg_types [ 0 ][ 0 ] .ret_type , Instance )
52
+ isinstance (defn , CallableType ) and
53
+ not defn .arg_types and
54
+ isinstance (defn .ret_type , Instance )
53
55
)
54
56
55
57
if is_defined_by_class :
56
58
return self ._adjust_protocol_arguments (ctx )
57
- elif isinstance (ctx . arg_types [ 0 ][ 0 ] , CallableType ):
59
+ elif isinstance (defn , CallableType ):
58
60
return self ._adjust_function_arguments (ctx )
59
61
return ctx .default_return_type
60
62
@@ -144,12 +146,87 @@ class _AdjustInstanceSignature(object):
144
146
"""
145
147
146
148
def __call__ (self , ctx : MethodContext ) -> MypyType :
149
+ if not isinstance (ctx .type , Instance ):
150
+ return ctx .default_return_type
151
+ if not isinstance (ctx .default_return_type , CallableType ):
152
+ return ctx .default_return_type
153
+
147
154
instance_type = self ._adjust_typeclass_callable (ctx )
148
155
self ._adjust_typeclass_type (ctx , instance_type )
149
156
if isinstance (instance_type , Instance ):
150
157
self ._add_supports_metadata (ctx , instance_type )
151
158
return ctx .default_return_type
152
159
160
+ @classmethod
161
+ def from_function_decorator (cls , ctx : FunctionContext ) -> MypyType :
162
+ """
163
+ It is used when ``.instance`` is used without params as a decorator.
164
+
165
+ Like:
166
+
167
+ .. code:: python
168
+
169
+ @some.instance
170
+ def _some_str(instance: str) -> str:
171
+ ...
172
+
173
+ """
174
+ is_decorator = (
175
+ isinstance (ctx .context , Decorator ) and
176
+ len (ctx .context .decorators ) == 1 and
177
+ isinstance (ctx .context .decorators [0 ], MemberExpr ) and
178
+ ctx .context .decorators [0 ].name == 'instance'
179
+ )
180
+ if not is_decorator :
181
+ return ctx .default_return_type
182
+
183
+ passed_function = ctx .arg_types [0 ][0 ]
184
+ assert isinstance (passed_function , CallableType )
185
+
186
+ if not passed_function .arg_types :
187
+ return ctx .default_return_type
188
+
189
+ annotation_type = passed_function .arg_types [0 ]
190
+ if isinstance (annotation_type , Instance ):
191
+ if annotation_type .type and annotation_type .type .is_protocol :
192
+ ctx .api .fail (
193
+ 'Protocols must be passed with `is_protocol=True`' ,
194
+ ctx .context ,
195
+ )
196
+ return ctx .default_return_type
197
+ else :
198
+ ctx .api .fail (
199
+ 'Only simple instance types are allowed, got: {0}' .format (
200
+ annotation_type ,
201
+ ),
202
+ ctx .context ,
203
+ )
204
+ return ctx .default_return_type
205
+
206
+ ret_type = CallableType (
207
+ arg_types = [passed_function ],
208
+ arg_kinds = [ARG_POS ],
209
+ arg_names = [None ],
210
+ ret_type = AnyType (TypeOfAny .implementation_artifact ),
211
+ fallback = passed_function .fallback ,
212
+ )
213
+ instance_type = ctx .api .expr_checker .accept ( # type: ignore
214
+ ctx .context .decorators [0 ].expr , # type: ignore
215
+ )
216
+
217
+ # We need to change the `ctx` type from `Function` to `Method`:
218
+ return cls ()(MethodContext (
219
+ type = instance_type ,
220
+ arg_types = ctx .arg_types ,
221
+ arg_kinds = ctx .arg_kinds ,
222
+ arg_names = ctx .arg_names ,
223
+ args = ctx .args ,
224
+ callee_arg_names = ctx .callee_arg_names ,
225
+ default_return_type = ret_type ,
226
+ context = ctx .context ,
227
+ api = ctx .api ,
228
+ ))
229
+
153
230
def _adjust_typeclass_callable (
154
231
self ,
155
232
ctx : MethodContext ,
@@ -302,6 +379,9 @@ def get_function_hook(
302
379
"""Here we adjust the typeclass constructor."""
303
380
if fullname == 'classes._typeclass.typeclass' :
304
381
return _AdjustArguments ()
382
+ if fullname == 'instance of _TypeClass' :
383
+ # `@some.instance` call without params:
384
+ return _AdjustInstanceSignature .from_function_decorator
305
385
return None
306
386
307
387
def get_method_hook (
@@ -310,6 +390,7 @@ def get_method_hook(
310
390
) -> Optional [Callable [[MethodContext ], MypyType ]]:
311
391
"""Here we adjust the typeclass with new allowed types."""
312
392
if fullname == 'classes._typeclass._TypeClass.instance' :
393
+ # `@some.instance` call with explicit params:
313
394
return _AdjustInstanceSignature ()
314
395
return None
315
396
0 commit comments