Skip to content

Commit 477a037

Browse files
committed
narrow to literal value
1 parent c1cbfa3 commit 477a037

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+630
-404
lines changed

.idea/watcherTasks.xml

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mypy/binder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import DefaultDict, Generator, Iterator, List, NamedTuple, Optional, Tuple, Union
66
from typing_extensions import TypeAlias as _TypeAlias
77

8-
from mypy.erasetype import remove_instance_last_known_values
98
from mypy.join import join_simple
109
from mypy.literals import Key, literal, literal_hash, subkeys
1110
from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var
@@ -331,7 +330,8 @@ def assign_type(
331330
) -> None:
332331
# We should erase last known value in binder, because if we are using it,
333332
# it means that the target is not final, and therefore can't hold a literal.
334-
type = remove_instance_last_known_values(type)
333+
# HUUHHH?????
334+
# type = remove_instance_last_known_values(type)
335335

336336
if self.type_assignments is not None:
337337
# We are in a multiassign from union, defer the actual binding,

mypy/checker.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3589,19 +3589,30 @@ def check_assignment(
35893589
):
35903590
lvalue.node.type = remove_instance_last_known_values(lvalue_type)
35913591

3592+
elif lvalue.node and lvalue.node.is_inferred and rvalue_type:
3593+
# for literal values
3594+
# Don't use type binder for definitions of special forms, like named tuples.
3595+
if not (isinstance(lvalue, NameExpr) and lvalue.is_special_form):
3596+
self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False)
3597+
35923598
elif index_lvalue:
35933599
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
35943600

35953601
if inferred:
35963602
type_context = self.get_variable_type_context(inferred)
35973603
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
3604+
original_rvalue_type = rvalue_type
35983605
if not (
35993606
inferred.is_final
36003607
or inferred.is_index_var
36013608
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
36023609
):
36033610
rvalue_type = remove_instance_last_known_values(rvalue_type)
3604-
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
3611+
if self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue) or self.binder.type_assignments:
3612+
# we don't always want to assign the type here as it might be something like partial
3613+
self.binder.assign_type(
3614+
lvalue, original_rvalue_type, original_rvalue_type, False
3615+
)
36053616
self.check_assignment_to_slots(lvalue)
36063617

36073618
# (type, operator) tuples for augmented assignments supported with partial types
@@ -4553,12 +4564,13 @@ def is_definition(self, s: Lvalue) -> bool:
45534564

45544565
def infer_variable_type(
45554566
self, name: Var, lvalue: Lvalue, init_type: Type, context: Context
4556-
) -> None:
4567+
) -> bool:
45574568
"""Infer the type of initialized variables from initializer type."""
4569+
valid = True
45584570
if isinstance(init_type, DeletedType):
45594571
self.msg.deleted_as_rvalue(init_type, context)
45604572
elif (
4561-
not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
4573+
not (valid := is_valid_inferred_type(init_type, is_lvalue_final=name.is_final))
45624574
and not self.no_partial_types
45634575
):
45644576
# We cannot use the type of the initialization expression for full type
@@ -4585,6 +4597,7 @@ def infer_variable_type(
45854597
init_type = strip_type(init_type)
45864598

45874599
self.set_inferred_type(name, lvalue, init_type)
4600+
return valid
45884601

45894602
def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool:
45904603
init_type = get_proper_type(init_type)

mypy/checkexpr.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,8 @@
210210
# Type of callback user for checking individual function arguments. See
211211
# check_args() below for details.
212212
ArgChecker: _TypeAlias = Callable[
213-
[Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context], None
213+
[Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context, bool],
214+
None,
214215
]
215216

216217
# Maximum nesting level for math union in overloads, setting this to large values
@@ -2175,6 +2176,13 @@ def infer_function_type_arguments(
21752176
Return a derived callable type that has the arguments applied.
21762177
"""
21772178
if self.chk.in_checked_function():
2179+
if isinstance(callee_type.ret_type, TypeVarType):
2180+
# if the return type is constant, infer as literal
2181+
rvalue_type = [
2182+
remove_instance_last_known_values(arg) if isinstance(arg, Instance) else arg
2183+
for arg in args
2184+
]
2185+
21782186
# Disable type errors during type inference. There may be errors
21792187
# due to partial available context information at this time, but
21802188
# these errors can be safely ignored as the arguments will be
@@ -2581,6 +2589,8 @@ def check_argument_types(
25812589
context: Context,
25822590
check_arg: ArgChecker | None = None,
25832591
object_type: Type | None = None,
2592+
*,
2593+
type_function=False,
25842594
) -> None:
25852595
"""Check argument types against a callable type.
25862596
@@ -2712,6 +2722,7 @@ def check_argument_types(
27122722
object_type,
27132723
args[actual],
27142724
context,
2725+
type_function,
27152726
)
27162727

27172728
def check_arg(
@@ -2726,12 +2737,16 @@ def check_arg(
27262737
object_type: Type | None,
27272738
context: Context,
27282739
outer_context: Context,
2740+
type_function=False,
27292741
) -> None:
27302742
"""Check the type of a single argument in a call."""
27312743
caller_type = get_proper_type(caller_type)
27322744
original_caller_type = get_proper_type(original_caller_type)
27332745
callee_type = get_proper_type(callee_type)
2734-
2746+
if type_function:
2747+
# TODO: make this work at all
2748+
if not isinstance(caller_type, Instance) or not caller_type.last_known_value:
2749+
caller_type = self.named_type("builtins.object")
27352750
if isinstance(caller_type, DeletedType):
27362751
self.msg.deleted_as_rvalue(caller_type, context)
27372752
# Only non-abstract non-protocol class can be given where Type[...] is expected...
@@ -3348,6 +3363,7 @@ def check_arg(
33483363
object_type: Type | None,
33493364
context: Context,
33503365
outer_context: Context,
3366+
type_function: bool,
33513367
) -> None:
33523368
if not arg_approximate_similarity(caller_type, callee_type):
33533369
# No match -- exit early since none of the remaining work can change
@@ -3580,10 +3596,14 @@ def visit_bytes_expr(self, e: BytesExpr) -> Type:
35803596

35813597
def visit_float_expr(self, e: FloatExpr) -> Type:
35823598
"""Type check a float literal (trivial)."""
3599+
if mypy.options._based:
3600+
return self.infer_literal_expr_type(e.value, "builtins.float")
35833601
return self.named_type("builtins.float")
35843602

35853603
def visit_complex_expr(self, e: ComplexExpr) -> Type:
35863604
"""Type check a complex literal."""
3605+
if mypy.options._based:
3606+
return self.infer_literal_expr_type(e.value, "builtins.complex")
35873607
return self.named_type("builtins.complex")
35883608

35893609
def visit_ellipsis(self, e: EllipsisExpr) -> Type:

mypy/expandtype.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from contextlib import contextmanager
34
from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload
45

56
from mypy.nodes import ARG_STAR, FakeInfo, Var
@@ -185,6 +186,16 @@ def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
185186
super().__init__()
186187
self.variables = variables
187188
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}
189+
self._erase_literals = False
190+
191+
@contextmanager
192+
def erase_literals(self):
193+
_erase_literals = self._erase_literals
194+
self._erase_literals = True
195+
try:
196+
yield
197+
finally:
198+
self._erase_literals = _erase_literals
188199

189200
def visit_unbound_type(self, t: UnboundType) -> Type:
190201
return t
@@ -211,7 +222,8 @@ def visit_erased_type(self, t: ErasedType) -> Type:
211222
return t
212223

213224
def visit_instance(self, t: Instance) -> Type:
214-
args = self.expand_types_with_unpack(list(t.args))
225+
with self.erase_literals():
226+
args = self.expand_types_with_unpack(list(t.args))
215227

216228
if isinstance(t.type, FakeInfo):
217229
# The type checker expands function definitions and bodies
@@ -238,7 +250,7 @@ def visit_type_var(self, t: TypeVarType) -> Type:
238250
if t.id.is_self():
239251
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
240252
repl = self.variables.get(t.id, t)
241-
if isinstance(repl, ProperType) and isinstance(repl, Instance):
253+
if self._erase_literals and isinstance(repl, ProperType) and isinstance(repl, Instance):
242254
# TODO: do we really need to do this?
243255
# If I try to remove this special-casing ~40 tests fail on reveal_type().
244256
return repl.copy_modified(last_known_value=None)
@@ -410,17 +422,18 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
410422

411423
var_arg = t.var_arg()
412424
needs_normalization = False
413-
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
414-
needs_normalization = True
415-
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
416-
else:
417-
arg_types = self.expand_types(t.arg_types)
418-
expanded = t.copy_modified(
419-
arg_types=arg_types,
420-
ret_type=t.ret_type.accept(self),
421-
type_guard=t.type_guard and cast(TypeGuardType, t.type_guard.accept(self)),
422-
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
423-
)
425+
with self.erase_literals():
426+
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
427+
needs_normalization = True
428+
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
429+
else:
430+
arg_types = self.expand_types(t.arg_types)
431+
expanded = t.copy_modified(
432+
arg_types=arg_types,
433+
ret_type=t.ret_type.accept(self),
434+
type_guard=t.type_guard and cast(TypeGuardType, t.type_guard.accept(self)),
435+
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
436+
)
424437
if needs_normalization:
425438
return expanded.with_normalized_var_args()
426439
return expanded
@@ -467,7 +480,9 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
467480
return cached
468481
fallback = t.fallback.accept(self)
469482
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
470-
result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
483+
with self.erase_literals():
484+
# TODO: we don't want to erase literals for `ReadOnly` keys
485+
result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
471486
self.set_cached(t, result)
472487
return result
473488

mypy/fastparse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2129,7 +2129,8 @@ def numeric_type(self, value: object, n: AST) -> Type:
21292129
# Other kinds of numbers (floats, complex) are not valid parameters for
21302130
# RawExpressionType so we just pass in 'None' for now. We'll report the
21312131
# appropriate error at a later stage.
2132-
numeric_value = None
2132+
# based: they are valid
2133+
numeric_value = value
21332134
type_name = f"builtins.{type(value).__name__}"
21342135
return RawExpressionType(
21352136
numeric_value, type_name, line=self.line, column=getattr(n, "col_offset", -1)

mypy/join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
232232
if declaration is None or is_subtype(value, declaration):
233233
return value
234234

235-
return declaration
235+
return value
236236

237237

238238
def trivial_join(s: Type, t: Type) -> Type:

mypy/meet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,12 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
197197
# Special case: 'int' can't be narrowed down to a native int type such as
198198
# i64, since they have different runtime representations.
199199
return original_declared
200+
if isinstance(narrowed, Instance) and narrowed.last_known_value:
201+
return narrowed
200202
return meet_types(original_declared, original_narrowed)
201203
elif isinstance(declared, (TupleType, TypeType, LiteralType)):
202-
return meet_types(original_declared, original_narrowed)
204+
# this way around to preserve the last know of the items in the tuple
205+
return meet_types(original_narrowed, original_declared, intersect=True)
203206
elif isinstance(declared, TypedDictType) and isinstance(narrowed, Instance):
204207
# Special case useful for selecting TypedDicts from unions using isinstance(x, dict).
205208
if narrowed.type.fullname == "builtins.dict" and all(

mypy/test/testtypes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,8 +694,14 @@ def test_simplified_intersection(self):
694694
[fx.b, IntersectionType([fx.c, IntersectionType([fx.d])])],
695695
IntersectionType([fx.b, fx.c, fx.d]),
696696
)
697+
self.assert_simplified_intersection([fx.bool_type, fx.lit_true], fx.lit_true)
698+
699+
# special case: it's not currently symmetric when there are last known values
700+
narrowed = fx.bool_type.copy_modified(last_known_value=fx.lit_true)
701+
assert_equal(make_simplified_intersection([narrowed, fx.bool_type]), narrowed)
697702

698703
def assert_simplified_intersection(self, original: list[Type], intersection: Type) -> None:
704+
__tracebackhide__ = True
699705
assert_equal(make_simplified_intersection(original), intersection)
700706
assert_equal(make_simplified_intersection(list(reversed(original))), intersection)
701707

mypy/typeanal.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,6 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
527527
return res
528528
elif isinstance(node, TypeInfo):
529529
return self.analyze_type_with_type_info(node, t.args, t, t.empty_tuple_index)
530-
531530
elif node.fullname in TYPE_ALIAS_NAMES:
532531
return AnyType(TypeOfAny.special_form)
533532
# Concatenate is an operator, no need for a proper type
@@ -1432,7 +1431,11 @@ def visit_raw_expression_type(self, t: RawExpressionType) -> Type:
14321431

14331432
if self.report_invalid_types:
14341433
msg = None
1435-
if t.base_type_name in ("builtins.int", "builtins.bool"):
1434+
if (
1435+
t.base_type_name in ("builtins.int", "builtins.bool")
1436+
or mypy.options._based
1437+
and t.base_type_name in ("builtins.float", "builtins.complex")
1438+
):
14361439
if not self.options.bare_literals:
14371440
# The only time it makes sense to use an int or bool is inside of
14381441
# a literal type.
@@ -1453,7 +1456,12 @@ def visit_raw_expression_type(self, t: RawExpressionType) -> Type:
14531456
self.fail(msg, t, code=codes.VALID_TYPE)
14541457
if t.note is not None:
14551458
self.note(t.note, t, code=codes.VALID_TYPE)
1456-
if t.base_type_name in ("builtins.int", "builtins.bool"):
1459+
if t.base_type_name in (
1460+
"builtins.int",
1461+
"builtins.bool",
1462+
"builtins.float",
1463+
"builtins.complex",
1464+
):
14571465
v = t.literal_value
14581466
assert v is not None
14591467
result = LiteralType(

mypy/typeops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,8 @@ def _remove_redundant_intersection_items(items: list[Type], keep_erased: bool) -
724724
if inner_i in removed:
725725
continue
726726
proper_inner = get_proper_type(items[inner_i])
727+
# hacky: we check this one first, because it's more likely that the value on the left
728+
# has a last known value/metadata/extra args
727729
if is_proper_subtype(
728730
proper_outer, proper_inner, keep_erased_types=keep_erased, ignore_promotions=True
729731
):

mypy/types.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
#
7676
# Note: Float values are only used internally. They are not accepted within
7777
# Literal[...].
78-
LiteralValue: _TypeAlias = Union[int, str, bool, float]
78+
LiteralValue: _TypeAlias = Union[int, str, bool, float, complex]
7979

8080

8181
# If we only import type_visitor in the middle of the file, mypy
@@ -3137,14 +3137,18 @@ def value_repr(self) -> str:
31373137
def serialize(self) -> JsonDict | str:
31383138
return {
31393139
".class": "LiteralType",
3140-
"value": self.value,
3140+
"value": self.value if not isinstance(self.value, complex) else str(self.value),
31413141
"fallback": self.fallback.serialize(),
31423142
}
31433143

31443144
@classmethod
31453145
def deserialize(cls, data: JsonDict) -> LiteralType:
31463146
assert data[".class"] == "LiteralType"
3147-
return LiteralType(value=data["value"], fallback=Instance.deserialize(data["fallback"]))
3147+
fallback = Instance.deserialize(data["fallback"])
3148+
value = data["value"]
3149+
if fallback.type_ref == "builtins.complex":
3150+
value = complex(value)
3151+
return LiteralType(value=value, fallback=fallback)
31483152

31493153
def is_singleton_type(self) -> bool:
31503154
return self.is_enum_literal() or isinstance(self.value, bool)

mypy/typeshed/stdlib/operator.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,16 @@ if sys.version_info >= (3, 11):
183183
@final
184184
class attrgetter(Generic[_T_co]):
185185
@overload
186-
def __new__(cls, attr: str, /) -> attrgetter[Any]: ...
186+
def __new__(cls, attr: str, /) -> attrgetter[object]: ...
187187
@overload
188-
def __new__(cls, attr: str, attr2: str, /) -> attrgetter[tuple[Any, Any]]: ...
188+
def __new__(cls, attr: str, attr2: str, /) -> attrgetter[tuple[object, object]]: ...
189189
@overload
190-
def __new__(cls, attr: str, attr2: str, attr3: str, /) -> attrgetter[tuple[Any, Any, Any]]: ...
190+
def __new__(cls, attr: str, attr2: str, attr3: str, /) -> attrgetter[tuple[object, object, object]]: ...
191191
@overload
192-
def __new__(cls, attr: str, attr2: str, attr3: str, attr4: str, /) -> attrgetter[tuple[Any, Any, Any, Any]]: ...
192+
def __new__(cls, attr: str, attr2: str, attr3: str, attr4: str, /) -> attrgetter[tuple[object, object, object, object]]: ...
193193
@overload
194-
def __new__(cls, attr: str, /, *attrs: str) -> attrgetter[tuple[Any, ...]]: ...
195-
def __call__(self, obj: Any, /) -> _T_co: ...
194+
def __new__(cls, attr: str, /, *attrs: str) -> attrgetter[tuple[object, ...]]: ...
195+
def __call__(self, obj: object, /) -> _T_co: ...
196196

197197
@final
198198
class itemgetter(Generic[_T_co]):

0 commit comments

Comments
 (0)