Skip to content

Commit d7ca220

Browse files
committed
variance modifiers
1 parent 8cd497f commit d7ca220

17 files changed

+398
-41
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Basedmypy Changelog
22

33
## [Unreleased]
4+
### Added
5+
- explicit and use-site variance modifiers `In`/`Out`/`InOut`
46

57
## [2.9.0]
68
### Added

docs/source/based_features.rst

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,78 @@ Using the ``&`` operator or ``basedtyping.Intersection`` you can denote intersec
2525
x.reset()
2626
x.add("first")
2727
28+
29+
Explicit Variance
30+
-----------------
31+
32+
it is frequently desirable to explicitly declare the variance of type parameters on types and classes.
33+
but until dedicated syntax is added:
34+
35+
.. code-block:: python
36+
37+
from basedtyping import In, InOut, Out
38+
39+
class Example[
40+
Contravariant: In, # In designates contravariant, as values can only pass 'into' the class
41+
Invariant: InOut, # I nOut designates invariant, as values can pass both 'into' and 'out' of the class
42+
Covariant: Out, # Out designates covariant, as the values can only pass 'out' of the class
43+
]: ...
44+
45+
The same applies to type declarations:
46+
47+
.. code-block:: python
48+
49+
type Example[Contravariant: In, Invariant: InOut, Covariant: Out] = ...
50+
51+
when a bound is supplied, it is provided as an argument to the variance modifier:
52+
53+
.. code-block:: python
54+
55+
class Example[T: Out[int]]: ...
56+
57+
58+
Use-site Variance
59+
-----------------
60+
61+
use-site variance is a concept that can be used to modify an invariant type
62+
parameter to be modified as covariant or contravariant
63+
64+
given:
65+
66+
.. code-block:: python
67+
68+
def f(data: list[object]): # we can't use `Sequence[object]` because we need `clear`
69+
for element in data:
70+
print(element)
71+
data.clear()
72+
73+
a = [1, 2, 3]
74+
f(a) # error: list[int] is incompatible with list[object]
75+
76+
we can implement use-site variance here to make the api both type-safe and ergonomic:
77+
78+
.. code-block:: python
79+
80+
def f(data: list[Out[object]]):
81+
for element in data:
82+
print(element)
83+
data.clear()
84+
85+
a = [1, 2, 3]
86+
f(a) # no error, list[int] is a valid subtype of the covariant list[out object]
87+
88+
what makes this typesafe is that the usages of the type parameter in input positions
89+
are replaced with `Never` (or output positions and the upper bound in the case of contravariance):
90+
91+
.. code-block:: python
92+
93+
class A[T: int | str]:
94+
def f(self, t: T) -> T: ...
95+
96+
A[Out[int]]().f # (t: Never) -> int
97+
A[In[int]]().f # (t: int) -> int | str
98+
99+
28100
Type Joins
29101
----------
30102

mypy/checkmember.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
ARG_POS,
1818
ARG_STAR,
1919
ARG_STAR2,
20+
CONTRAVARIANT,
21+
COVARIANT,
2022
EXCLUDED_ENUM_ATTRIBUTES,
2123
SYMBOL_FUNCBASE_TYPES,
2224
Context,
@@ -811,7 +813,9 @@ def analyze_var(
811813
mx.msg.cant_assign_to_classvar(name, mx.context)
812814
t = freshen_all_functions_type_vars(typ)
813815
t = expand_self_type_if_needed(t, mx, var, original_itype)
814-
t = expand_type_by_instance(t, itype)
816+
t = expand_type_by_instance(
817+
t, itype, use_variance=CONTRAVARIANT if mx.is_lvalue else COVARIANT
818+
)
815819
freeze_all_type_vars(t)
816820
result = t
817821
typ = get_proper_type(typ)

mypy/expandtype.py

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload
3+
from contextlib import contextmanager
4+
from typing import Final, Generator, Iterable, Mapping, Sequence, TypeVar, cast, overload
45

5-
from mypy.nodes import ARG_STAR, FakeInfo, Var
6+
from mypy.nodes import ARG_STAR, CONTRAVARIANT, COVARIANT, FakeInfo, Var, INVARIANT
67
from mypy.state import state
78
from mypy.types import (
89
ANY_STRATEGY,
@@ -38,6 +39,7 @@
3839
UninhabitedType,
3940
UnionType,
4041
UnpackType,
42+
VarianceModifier,
4143
flatten_nested_unions,
4244
get_proper_type,
4345
split_with_prefix_and_suffix,
@@ -53,37 +55,49 @@
5355

5456

5557
@overload
56-
def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType: ...
58+
def expand_type(
59+
typ: CallableType, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
60+
) -> CallableType: ...
5761

5862

5963
@overload
60-
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType: ...
64+
def expand_type(
65+
typ: ProperType, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
66+
) -> ProperType: ...
6167

6268

6369
@overload
64-
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: ...
70+
def expand_type(
71+
typ: Type, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
72+
) -> Type: ...
6573

6674

67-
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
75+
def expand_type(typ: Type, env: Mapping[TypeVarId, Type], *, variance=None) -> Type:
6876
"""Substitute any type variable references in a type given by a type
6977
environment.
7078
"""
71-
return typ.accept(ExpandTypeVisitor(env))
79+
return typ.accept(ExpandTypeVisitor(env, variance=variance))
7280

7381

7482
@overload
75-
def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType: ...
83+
def expand_type_by_instance(
84+
typ: CallableType, instance: Instance, *, use_variance: int | None = ...
85+
) -> CallableType: ...
7686

7787

7888
@overload
79-
def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType: ...
89+
def expand_type_by_instance(
90+
typ: ProperType, instance: Instance, *, use_variance: int | None = ...
91+
) -> ProperType: ...
8092

8193

8294
@overload
83-
def expand_type_by_instance(typ: Type, instance: Instance) -> Type: ...
95+
def expand_type_by_instance(
96+
typ: Type, instance: Instance, *, use_variance: int | None = ...
97+
) -> Type: ...
8498

8599

86-
def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
100+
def expand_type_by_instance(typ: Type, instance: Instance, use_variance=None) -> Type:
87101
"""Substitute type variables in type using values from an Instance.
88102
Type variables are considered to be bound by the class declaration."""
89103
if not instance.args and not instance.type.has_type_var_tuple_type:
@@ -108,12 +122,11 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
108122
else:
109123
tvars = tuple(instance.type.defn.type_vars)
110124
instance_args = instance.args
111-
112125
for binder, arg in zip(tvars, instance_args):
113126
assert isinstance(binder, TypeVarLikeType)
114127
variables[binder.id] = arg
115128

116-
return expand_type(typ, variables)
129+
return expand_type(typ, variables, variance=use_variance)
117130

118131

119132
F = TypeVar("F", bound=FunctionLike)
@@ -181,10 +194,28 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
181194

182195
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value
183196

184-
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
197+
def __init__(
198+
self, variables: Mapping[TypeVarId, Type], *, variance: int | None = None
199+
) -> None:
185200
super().__init__()
186201
self.variables = variables
187202
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}
203+
self.variance = variance
204+
self.using_variance: int | None = None
205+
206+
@contextmanager
207+
def in_variance(self) -> Generator[None]:
208+
using_variance = self.using_variance
209+
self.using_variance = CONTRAVARIANT
210+
yield
211+
self.using_variance = using_variance
212+
213+
@contextmanager
214+
def out_variance(self) -> Generator[None]:
215+
using_variance = self.using_variance
216+
self.using_variance = COVARIANT
217+
yield
218+
self.using_variance = using_variance
188219

189220
def visit_unbound_type(self, t: UnboundType) -> Type:
190221
return t
@@ -238,6 +269,19 @@ def visit_type_var(self, t: TypeVarType) -> Type:
238269
if t.id.is_self():
239270
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
240271
repl = self.variables.get(t.id, t)
272+
use_site_variance = repl.variance if isinstance(repl, VarianceModifier) else None
273+
positional_variance = self.using_variance or self.variance
274+
if (
275+
positional_variance is not None
276+
and use_site_variance is not None
277+
and use_site_variance is not INVARIANT
278+
and positional_variance != use_site_variance
279+
):
280+
repl = (
281+
t.upper_bound.accept(self)
282+
if positional_variance == COVARIANT
283+
else UninhabitedType()
284+
)
241285
if isinstance(repl, ProperType) and isinstance(repl, Instance):
242286
# TODO: do we really need to do this?
243287
# If I try to remove this special-casing ~40 tests fail on reveal_type().
@@ -414,10 +458,15 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
414458
needs_normalization = True
415459
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
416460
else:
417-
arg_types = self.expand_types(t.arg_types)
461+
with self.in_variance():
462+
arg_types = self.expand_types(t.arg_types)
463+
with self.out_variance():
464+
ret_type = t.ret_type.accept(self)
465+
if isinstance(ret_type, VarianceModifier):
466+
ret_type = ret_type.value
418467
expanded = t.copy_modified(
419468
arg_types=arg_types,
420-
ret_type=t.ret_type.accept(self),
469+
ret_type=ret_type,
421470
type_guard=t.type_guard and cast(TypeGuardType, t.type_guard.accept(self)),
422471
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
423472
)
@@ -538,7 +587,10 @@ def visit_typeguard_type(self, t: TypeGuardType) -> Type:
538587
def expand_types(self, types: Iterable[Type]) -> list[Type]:
539588
a: list[Type] = []
540589
for t in types:
541-
a.append(t.accept(self))
590+
typ = t.accept(self)
591+
if isinstance(typ, VarianceModifier):
592+
typ = typ.value
593+
a.append(typ)
542594
return a
543595

544596

mypy/message_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
113113
)
114114
FORMAT_REQUIRES_MAPPING: Final = "Format requires a mapping"
115115
RETURN_TYPE_CANNOT_BE_CONTRAVARIANT: Final = ErrorMessage(
116-
"This usage of this contravariant type variable is unsafe as a return type.",
116+
"This usage of this contravariant type variable is unsafe as a return type",
117117
codes.UNSAFE_VARIANCE,
118118
)
119119
FUNCTION_PARAMETER_CANNOT_BE_COVARIANT: Final = ErrorMessage(
120-
"This usage of this covariant type variable is unsafe as an input parameter.",
120+
"This usage of this covariant type variable is unsafe as an input parameter",
121121
codes.UNSAFE_VARIANCE,
122122
)
123123
UNSAFE_VARIANCE_NOTE = ErrorMessage(

mypy/messages.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
UninhabitedType,
9494
UnionType,
9595
UnpackType,
96+
VarianceModifier,
9697
flatten_nested_unions,
9798
get_proper_type,
9899
get_proper_types,
@@ -2676,6 +2677,9 @@ def format_literal_value(typ: LiteralType) -> str:
26762677
type_str += f"[{format_list(typ.args)}]"
26772678
return type_str
26782679

2680+
if isinstance(typ, VarianceModifier):
2681+
return typ.render(format)
2682+
26792683
# TODO: always mention type alias names in errors.
26802684
typ = get_proper_type(typ)
26812685

mypy/plugins/proper_plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def is_special_target(right: ProperType) -> bool:
107107
"mypy.types.DeletedType",
108108
"mypy.types.RequiredType",
109109
"mypy.types.ReadOnlyType",
110+
"mypy.types.VarianceModifier",
110111
):
111112
# Special case: these are not valid targets for a type alias and thus safe.
112113
# TODO: introduce a SyntheticType base to simplify this?

0 commit comments

Comments
 (0)