1
1
from __future__ import annotations
2
2
3
- from typing import Final , Iterable , Mapping , Sequence , TypeVar , cast , overload
3
+ from contextlib import contextmanager
4
+ from typing import Final , Generator , Iterable , Mapping , Sequence , TypeVar , cast , overload
4
5
5
- from mypy .nodes import ARG_STAR , FakeInfo , Var
6
+ from mypy .nodes import ARG_STAR , CONTRAVARIANT , COVARIANT , FakeInfo , Var , INVARIANT
6
7
from mypy .state import state
7
8
from mypy .types import (
8
9
ANY_STRATEGY ,
38
39
UninhabitedType ,
39
40
UnionType ,
40
41
UnpackType ,
42
+ VarianceModifier ,
41
43
flatten_nested_unions ,
42
44
get_proper_type ,
43
45
split_with_prefix_and_suffix ,
53
55
54
56
55
57
@overload
56
- def expand_type (typ : CallableType , env : Mapping [TypeVarId , Type ]) -> CallableType : ...
58
+ def expand_type (
59
+ typ : CallableType , env : Mapping [TypeVarId , Type ], * , variance : int | None = ...
60
+ ) -> CallableType : ...
57
61
58
62
59
63
@overload
60
- def expand_type (typ : ProperType , env : Mapping [TypeVarId , Type ]) -> ProperType : ...
64
+ def expand_type (
65
+ typ : ProperType , env : Mapping [TypeVarId , Type ], * , variance : int | None = ...
66
+ ) -> ProperType : ...
61
67
62
68
63
69
@overload
64
- def expand_type (typ : Type , env : Mapping [TypeVarId , Type ]) -> Type : ...
70
+ def expand_type (
71
+ typ : Type , env : Mapping [TypeVarId , Type ], * , variance : int | None = ...
72
+ ) -> Type : ...
65
73
66
74
67
- def expand_type (typ : Type , env : Mapping [TypeVarId , Type ]) -> Type :
75
+ def expand_type (typ : Type , env : Mapping [TypeVarId , Type ], * , variance = None ) -> Type :
68
76
"""Substitute any type variable references in a type given by a type
69
77
environment.
70
78
"""
71
- return typ .accept (ExpandTypeVisitor (env ))
79
+ return typ .accept (ExpandTypeVisitor (env , variance = variance ))
72
80
73
81
74
82
@overload
75
- def expand_type_by_instance (typ : CallableType , instance : Instance ) -> CallableType : ...
83
+ def expand_type_by_instance (
84
+ typ : CallableType , instance : Instance , * , use_variance : int | None = ...
85
+ ) -> CallableType : ...
76
86
77
87
78
88
@overload
79
- def expand_type_by_instance (typ : ProperType , instance : Instance ) -> ProperType : ...
89
+ def expand_type_by_instance (
90
+ typ : ProperType , instance : Instance , * , use_variance : int | None = ...
91
+ ) -> ProperType : ...
80
92
81
93
82
94
@overload
83
- def expand_type_by_instance (typ : Type , instance : Instance ) -> Type : ...
95
+ def expand_type_by_instance (
96
+ typ : Type , instance : Instance , * , use_variance : int | None = ...
97
+ ) -> Type : ...
84
98
85
99
86
- def expand_type_by_instance (typ : Type , instance : Instance ) -> Type :
100
+ def expand_type_by_instance (typ : Type , instance : Instance , use_variance = None ) -> Type :
87
101
"""Substitute type variables in type using values from an Instance.
88
102
Type variables are considered to be bound by the class declaration."""
89
103
if not instance .args and not instance .type .has_type_var_tuple_type :
@@ -108,12 +122,11 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
108
122
else :
109
123
tvars = tuple (instance .type .defn .type_vars )
110
124
instance_args = instance .args
111
-
112
125
for binder , arg in zip (tvars , instance_args ):
113
126
assert isinstance (binder , TypeVarLikeType )
114
127
variables [binder .id ] = arg
115
128
116
- return expand_type (typ , variables )
129
+ return expand_type (typ , variables , variance = use_variance )
117
130
118
131
119
132
F = TypeVar ("F" , bound = FunctionLike )
@@ -181,10 +194,28 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
181
194
182
195
variables : Mapping [TypeVarId , Type ] # TypeVar id -> TypeVar value
183
196
184
- def __init__ (self , variables : Mapping [TypeVarId , Type ]) -> None :
197
+ def __init__ (
198
+ self , variables : Mapping [TypeVarId , Type ], * , variance : int | None = None
199
+ ) -> None :
185
200
super ().__init__ ()
186
201
self .variables = variables
187
202
self .recursive_tvar_guard : dict [TypeVarId , Type | None ] = {}
203
+ self .variance = variance
204
+ self .using_variance : int | None = None
205
+
206
+ @contextmanager
207
+ def in_variance (self ) -> Generator [None ]:
208
+ using_variance = self .using_variance
209
+ self .using_variance = CONTRAVARIANT
210
+ yield
211
+ self .using_variance = using_variance
212
+
213
+ @contextmanager
214
+ def out_variance (self ) -> Generator [None ]:
215
+ using_variance = self .using_variance
216
+ self .using_variance = COVARIANT
217
+ yield
218
+ self .using_variance = using_variance
188
219
189
220
def visit_unbound_type (self , t : UnboundType ) -> Type :
190
221
return t
@@ -238,6 +269,19 @@ def visit_type_var(self, t: TypeVarType) -> Type:
238
269
if t .id .is_self ():
239
270
t = t .copy_modified (upper_bound = t .upper_bound .accept (self ))
240
271
repl = self .variables .get (t .id , t )
272
+ use_site_variance = repl .variance if isinstance (repl , VarianceModifier ) else None
273
+ positional_variance = self .using_variance or self .variance
274
+ if (
275
+ positional_variance is not None
276
+ and use_site_variance is not None
277
+ and use_site_variance is not INVARIANT
278
+ and positional_variance != use_site_variance
279
+ ):
280
+ repl = (
281
+ t .upper_bound .accept (self )
282
+ if positional_variance == COVARIANT
283
+ else UninhabitedType ()
284
+ )
241
285
if isinstance (repl , ProperType ) and isinstance (repl , Instance ):
242
286
# TODO: do we really need to do this?
243
287
# If I try to remove this special-casing ~40 tests fail on reveal_type().
@@ -414,10 +458,15 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
414
458
needs_normalization = True
415
459
arg_types = self .interpolate_args_for_unpack (t , var_arg .typ )
416
460
else :
417
- arg_types = self .expand_types (t .arg_types )
461
+ with self .in_variance ():
462
+ arg_types = self .expand_types (t .arg_types )
463
+ with self .out_variance ():
464
+ ret_type = t .ret_type .accept (self )
465
+ if isinstance (ret_type , VarianceModifier ):
466
+ ret_type = ret_type .value
418
467
expanded = t .copy_modified (
419
468
arg_types = arg_types ,
420
- ret_type = t . ret_type . accept ( self ) ,
469
+ ret_type = ret_type ,
421
470
type_guard = t .type_guard and cast (TypeGuardType , t .type_guard .accept (self )),
422
471
type_is = (t .type_is .accept (self ) if t .type_is is not None else None ),
423
472
)
@@ -538,7 +587,10 @@ def visit_typeguard_type(self, t: TypeGuardType) -> Type:
538
587
def expand_types (self , types : Iterable [Type ]) -> list [Type ]:
539
588
a : list [Type ] = []
540
589
for t in types :
541
- a .append (t .accept (self ))
590
+ typ = t .accept (self )
591
+ if isinstance (typ , VarianceModifier ):
592
+ typ = typ .value
593
+ a .append (typ )
542
594
return a
543
595
544
596
0 commit comments