Skip to content

Commit e71f3ed

Browse files
MatthewMckee4carljmAlexWaygood
authored
[red-knot] Update == and != narrowing (#17567)
## Summary Historically we have avoided narrowing on `==` tests because in many cases it's unsound, since subclasses of a type could compare equal to who-knows-what. But there are a lot of types (literals and unions of them, as well as some known instances like `None` -- single-valued types) whose `__eq__` behavior we know, and which we can safely narrow away based on equality comparisons. This PR implements equality narrowing in the cases where it is sound. The most elegant way to do this (and the way that is most in-line with our approach up until now) would be to introduce new Type variants `NeverEqualTo[...]` and `AlwaysEqualTo[...]`, and then implement all type relations for those variants, narrow by intersection, and let union and intersection simplification sort it all out. This is analogous to our existing handling for `AlwaysFalse` and `AlwaysTrue`. But I'm reluctant to add new `Type` variants for this, mostly because they could end up un-simplified in some types and make types even more complex. So let's try this approach, where we handle more of the narrowing logic as a special case. ## Test Plan Updated and added tests. --------- Co-authored-by: Carl Meyer <[email protected]> Co-authored-by: Carl Meyer <[email protected]> Co-authored-by: Alex Waygood <[email protected]>
1 parent ac6219e commit e71f3ed

File tree

6 files changed

+220
-40
lines changed

6 files changed

+220
-40
lines changed

crates/red_knot_python_semantic/resources/mdtest/narrow/assert.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
2929
assert x is 2
3030
reveal_type(x) # revealed: Literal[2]
3131
assert y == 2
32-
reveal_type(y) # revealed: Literal[1, 2, 3]
32+
reveal_type(y) # revealed: Literal[2]
3333
```
3434

3535
## `assert` with `isinstance`

crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/elif_else.md

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@ def _(flag1: bool, flag2: bool):
2020
x = 1 if flag1 else 2 if flag2 else 3
2121

2222
if x == 1:
23-
# TODO should be Literal[1]
24-
reveal_type(x) # revealed: Literal[1, 2, 3]
23+
reveal_type(x) # revealed: Literal[1]
2524
elif x == 2:
26-
# TODO should be Literal[2]
27-
reveal_type(x) # revealed: Literal[2, 3]
25+
reveal_type(x) # revealed: Literal[2]
2826
else:
2927
reveal_type(x) # revealed: Literal[3]
3028
```
@@ -38,14 +36,11 @@ def _(flag1: bool, flag2: bool):
3836
if x != 1:
3937
reveal_type(x) # revealed: Literal[2, 3]
4038
elif x != 2:
41-
# TODO should be `Literal[1]`
42-
reveal_type(x) # revealed: Literal[1, 3]
39+
reveal_type(x) # revealed: Literal[1]
4340
elif x == 3:
44-
# TODO should be Never
45-
reveal_type(x) # revealed: Literal[1, 2, 3]
41+
reveal_type(x) # revealed: Never
4642
else:
47-
# TODO should be Never
48-
reveal_type(x) # revealed: Literal[1, 2]
43+
reveal_type(x) # revealed: Never
4944
```
5045

5146
## Assignment expressions

crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/not_eq.md renamed to crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/eq.md

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ def _(flag: bool):
99
if x != None:
1010
reveal_type(x) # revealed: Literal[1]
1111
else:
12-
# TODO should be None
13-
reveal_type(x) # revealed: None | Literal[1]
12+
reveal_type(x) # revealed: None
1413
```
1514

1615
## `!=` for other singleton types
@@ -22,8 +21,7 @@ def _(flag: bool):
2221
if x != False:
2322
reveal_type(x) # revealed: Literal[True]
2423
else:
25-
# TODO should be Literal[False]
26-
reveal_type(x) # revealed: bool
24+
reveal_type(x) # revealed: Literal[False]
2725
```
2826

2927
## `x != y` where `y` is of literal type
@@ -47,8 +45,7 @@ def _(flag: bool):
4745
if C != A:
4846
reveal_type(C) # revealed: Literal[B]
4947
else:
50-
# TODO should be Literal[A]
51-
reveal_type(C) # revealed: Literal[A, B]
48+
reveal_type(C) # revealed: Literal[A]
5249
```
5350

5451
## `x != y` where `y` has multiple single-valued options
@@ -61,8 +58,7 @@ def _(flag1: bool, flag2: bool):
6158
if x != y:
6259
reveal_type(x) # revealed: Literal[1, 2]
6360
else:
64-
# TODO should be Literal[2]
65-
reveal_type(x) # revealed: Literal[1, 2]
61+
reveal_type(x) # revealed: Literal[2]
6662
```
6763

6864
## `!=` for non-single-valued types
@@ -101,6 +97,61 @@ def f() -> Literal[1, 2, 3]:
10197
if (x := f()) != 1:
10298
reveal_type(x) # revealed: Literal[2, 3]
10399
else:
104-
# TODO should be Literal[1]
105-
reveal_type(x) # revealed: Literal[1, 2, 3]
100+
reveal_type(x) # revealed: Literal[1]
101+
```
102+
103+
## Union with `Any`
104+
105+
```py
106+
from typing import Any
107+
108+
def _(x: Any | None, y: Any | None):
109+
if x != 1:
110+
reveal_type(x) # revealed: (Any & ~Literal[1]) | None
111+
if y == 1:
112+
reveal_type(y) # revealed: Any & ~None
113+
```
114+
115+
## Booleans and integers
116+
117+
```py
118+
from typing import Literal
119+
120+
def _(b: bool, i: Literal[1, 2]):
121+
if b == 1:
122+
reveal_type(b) # revealed: Literal[True]
123+
else:
124+
reveal_type(b) # revealed: Literal[False]
125+
126+
if b == 6:
127+
reveal_type(b) # revealed: Never
128+
else:
129+
reveal_type(b) # revealed: bool
130+
131+
if b == 0:
132+
reveal_type(b) # revealed: Literal[False]
133+
else:
134+
reveal_type(b) # revealed: Literal[True]
135+
136+
if i == True:
137+
reveal_type(i) # revealed: Literal[1]
138+
else:
139+
reveal_type(i) # revealed: Literal[2]
140+
```
141+
142+
## Narrowing `LiteralString` in union
143+
144+
```py
145+
from typing_extensions import Literal, LiteralString, Any
146+
147+
def _(s: LiteralString | None, t: LiteralString | Any):
148+
if s == "foo":
149+
reveal_type(s) # revealed: Literal["foo"]
150+
151+
if s == 1:
152+
reveal_type(s) # revealed: Never
153+
154+
if t == "foo":
155+
# TODO could be `Literal["foo"] | Any`
156+
reveal_type(t) # revealed: LiteralString | Any
106157
```

crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals/nested.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,14 @@ def _(flag1: bool, flag2: bool):
3131
if x != 1:
3232
reveal_type(x) # revealed: Literal[2, 3]
3333
if x == 2:
34-
# TODO should be `Literal[2]`
35-
reveal_type(x) # revealed: Literal[2, 3]
34+
reveal_type(x) # revealed: Literal[2]
3635
elif x == 3:
3736
reveal_type(x) # revealed: Literal[3]
3837
else:
3938
reveal_type(x) # revealed: Never
4039

4140
elif x != 2:
42-
# TODO should be Literal[1]
43-
reveal_type(x) # revealed: Literal[1, 3]
41+
reveal_type(x) # revealed: Literal[1]
4442
else:
45-
# TODO should be Never
46-
reveal_type(x) # revealed: Literal[1, 2, 3]
43+
reveal_type(x) # revealed: Never
4744
```

crates/red_knot_python_semantic/src/types.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,11 @@ impl<'db> Type<'db> {
542542
.is_some_and(|instance| instance.class().is_known(db, KnownClass::NoneType))
543543
}
544544

545+
fn is_bool(&self, db: &'db dyn Db) -> bool {
546+
self.into_instance()
547+
.is_some_and(|instance| instance.class().is_known(db, KnownClass::Bool))
548+
}
549+
545550
pub fn is_notimplemented(&self, db: &'db dyn Db) -> bool {
546551
self.into_instance().is_some_and(|instance| {
547552
instance
@@ -776,8 +781,13 @@ impl<'db> Type<'db> {
776781
}
777782

778783
pub fn is_union_of_single_valued(&self, db: &'db dyn Db) -> bool {
779-
self.into_union()
780-
.is_some_and(|union| union.elements(db).iter().all(|ty| ty.is_single_valued(db)))
784+
self.into_union().is_some_and(|union| {
785+
union
786+
.elements(db)
787+
.iter()
788+
.all(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string())
789+
}) || self.is_bool(db)
790+
|| self.is_literal_string()
781791
}
782792

783793
pub const fn into_int_literal(self) -> Option<i64> {

crates/red_knot_python_semantic/src/types/narrow.rs

Lines changed: 138 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,142 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
394394
}
395395
}
396396

397+
fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
398+
// We can only narrow on equality checks against single-valued types.
399+
if rhs_ty.is_single_valued(self.db) || rhs_ty.is_union_of_single_valued(self.db) {
400+
// The fully-general (and more efficient) approach here would be to introduce a
401+
// `NeverEqualTo` type that can wrap a single-valued type, and then simply return
402+
// `~NeverEqualTo(rhs_ty)` here and let union/intersection builder sort it out. This is
403+
// how we handle `AlwaysTruthy` and `AlwaysFalsy`. But this means we have to deal with
404+
// this type everywhere, and possibly have it show up unsimplified in some cases, and
405+
// so we instead prefer to just do the simplification here. (Another hybrid option that
406+
// would be similar to this, but more efficient, would be to allow narrowing to return
407+
// something that is not a type, and handle this not-a-type in `symbol_from_bindings`,
408+
// instead of intersecting with a type.)
409+
410+
// Return `true` if it is possible for any two inhabitants of the given types to
411+
// compare equal to each other; otherwise return `false`.
412+
fn could_compare_equal<'db>(
413+
db: &'db dyn Db,
414+
left_ty: Type<'db>,
415+
right_ty: Type<'db>,
416+
) -> bool {
417+
if !left_ty.is_disjoint_from(db, right_ty) {
418+
// If types overlap, they have inhabitants in common; it's definitely possible
419+
// for an object to compare equal to itself.
420+
return true;
421+
}
422+
match (left_ty, right_ty) {
423+
// In order to be sure a union type cannot compare equal to another type, it
424+
// must be true that no element of the union can compare equal to that type.
425+
(Type::Union(union), _) => union
426+
.elements(db)
427+
.iter()
428+
.any(|ty| could_compare_equal(db, *ty, right_ty)),
429+
(_, Type::Union(union)) => union
430+
.elements(db)
431+
.iter()
432+
.any(|ty| could_compare_equal(db, left_ty, *ty)),
433+
// Boolean literals and int literals are disjoint, and single valued, and yet
434+
// `True == 1` and `False == 0`.
435+
(Type::BooleanLiteral(b), Type::IntLiteral(i))
436+
| (Type::IntLiteral(i), Type::BooleanLiteral(b)) => i64::from(b) == i,
437+
// Other than the above cases, two single-valued disjoint types cannot compare
438+
// equal.
439+
_ => !(left_ty.is_single_valued(db) && right_ty.is_single_valued(db)),
440+
}
441+
}
442+
443+
// Return `true` if `lhs_ty` consists only of `LiteralString` and types that cannot
444+
// compare equal to `rhs_ty`.
445+
fn can_narrow_to_rhs<'db>(
446+
db: &'db dyn Db,
447+
lhs_ty: Type<'db>,
448+
rhs_ty: Type<'db>,
449+
) -> bool {
450+
match lhs_ty {
451+
Type::Union(union) => union
452+
.elements(db)
453+
.iter()
454+
.all(|ty| can_narrow_to_rhs(db, *ty, rhs_ty)),
455+
// Either `rhs_ty` is a string literal, in which case we can narrow to it (no
456+
// other string literal could compare equal to it), or it is not a string
457+
// literal, in which case (given that it is single-valued), LiteralString
458+
// cannot compare equal to it.
459+
Type::LiteralString => true,
460+
_ => !could_compare_equal(db, lhs_ty, rhs_ty),
461+
}
462+
}
463+
464+
// Filter `ty` to just the types that cannot be equal to `rhs_ty`.
465+
fn filter_to_cannot_be_equal<'db>(
466+
db: &'db dyn Db,
467+
ty: Type<'db>,
468+
rhs_ty: Type<'db>,
469+
) -> Type<'db> {
470+
match ty {
471+
Type::Union(union) => {
472+
union.map(db, |ty| filter_to_cannot_be_equal(db, *ty, rhs_ty))
473+
}
474+
// Treat `bool` as `Literal[True, False]`.
475+
Type::Instance(instance) if instance.class().is_known(db, KnownClass::Bool) => {
476+
UnionType::from_elements(
477+
db,
478+
[Type::BooleanLiteral(true), Type::BooleanLiteral(false)]
479+
.into_iter()
480+
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
481+
)
482+
}
483+
_ => {
484+
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
485+
ty
486+
} else {
487+
Type::Never
488+
}
489+
}
490+
}
491+
}
492+
Some(if can_narrow_to_rhs(self.db, lhs_ty, rhs_ty) {
493+
rhs_ty
494+
} else {
495+
filter_to_cannot_be_equal(self.db, lhs_ty, rhs_ty).negate(self.db)
496+
})
497+
} else {
498+
None
499+
}
500+
}
501+
502+
fn evaluate_expr_ne(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
503+
match (lhs_ty, rhs_ty) {
504+
(Type::Instance(instance), Type::IntLiteral(i))
505+
if instance.class().is_known(self.db, KnownClass::Bool) =>
506+
{
507+
if i == 0 {
508+
Some(Type::BooleanLiteral(false).negate(self.db))
509+
} else if i == 1 {
510+
Some(Type::BooleanLiteral(true).negate(self.db))
511+
} else {
512+
None
513+
}
514+
}
515+
(_, Type::BooleanLiteral(b)) => {
516+
if b {
517+
Some(
518+
UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(1)])
519+
.negate(self.db),
520+
)
521+
} else {
522+
Some(
523+
UnionType::from_elements(self.db, [rhs_ty, Type::IntLiteral(0)])
524+
.negate(self.db),
525+
)
526+
}
527+
}
528+
_ if rhs_ty.is_single_valued(self.db) => Some(rhs_ty.negate(self.db)),
529+
_ => None,
530+
}
531+
}
532+
397533
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
398534
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
399535
match rhs_ty {
@@ -435,17 +571,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
435571
}
436572
}
437573
ast::CmpOp::Is => Some(rhs_ty),
438-
ast::CmpOp::NotEq => {
439-
if rhs_ty.is_single_valued(self.db) {
440-
let ty = IntersectionBuilder::new(self.db)
441-
.add_negative(rhs_ty)
442-
.build();
443-
Some(ty)
444-
} else {
445-
None
446-
}
447-
}
448-
ast::CmpOp::Eq if lhs_ty.is_literal_string() => Some(rhs_ty),
574+
ast::CmpOp::Eq => self.evaluate_expr_eq(lhs_ty, rhs_ty),
575+
ast::CmpOp::NotEq => self.evaluate_expr_ne(lhs_ty, rhs_ty),
449576
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
450577
ast::CmpOp::NotIn => self
451578
.evaluate_expr_in(lhs_ty, rhs_ty)

0 commit comments

Comments
 (0)