Skip to content

Commit ed70264

Browse files
committed
Auto merge of #141442 - compiler-errors:fast-path-pred, r=<try>
Fold predicate fast path in canonicalizer and eager resolver r? lcnr
2 parents 52bf0cf + f96b8d5 commit ed70264

File tree

4 files changed

+64
-9
lines changed

4 files changed

+64
-9
lines changed

compiler/rustc_infer/src/infer/canonical/canonicalizer.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,10 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
515515
ct
516516
}
517517
}
518+
519+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
520+
if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p }
521+
}
518522
}
519523

520524
impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {

compiler/rustc_next_trait_solver/src/canonicalizer.rs

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ use rustc_type_ir::data_structures::{HashMap, ensure_sufficient_stack};
44
use rustc_type_ir::inherent::*;
55
use rustc_type_ir::solve::{Goal, QueryInput};
66
use rustc_type_ir::{
7-
self as ty, Canonical, CanonicalTyVarKind, CanonicalVarInfo, CanonicalVarKind, InferCtxtLike,
8-
Interner, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
7+
self as ty, Canonical, CanonicalTyVarKind, CanonicalVarInfo, CanonicalVarKind, Flags,
8+
InferCtxtLike, Interner, TypeFlags, TypeFoldable, TypeFolder, TypeSuperFoldable,
9+
TypeVisitableExt,
910
};
1011

1112
use crate::delegate::SolverDelegate;
@@ -79,7 +80,11 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
7980
cache: Default::default(),
8081
};
8182

82-
let value = value.fold_with(&mut canonicalizer);
83+
let value = if value.has_type_flags(TypeFlags::NEEDS_CANONICALIZATION_NEXT_SOLVER) {
84+
value.fold_with(&mut canonicalizer)
85+
} else {
86+
value
87+
};
8388
assert!(!value.has_infer(), "unexpected infer in {value:?}");
8489
assert!(!value.has_placeholders(), "unexpected placeholders in {value:?}");
8590
let (max_universe, variables) = canonicalizer.finalize();
@@ -111,7 +116,14 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
111116

112117
cache: Default::default(),
113118
};
114-
let param_env = input.goal.param_env.fold_with(&mut env_canonicalizer);
119+
120+
let param_env = input.goal.param_env;
121+
let param_env = if param_env.has_type_flags(TypeFlags::NEEDS_CANONICALIZATION_NEXT_SOLVER) {
122+
param_env.fold_with(&mut env_canonicalizer)
123+
} else {
124+
param_env
125+
};
126+
115127
debug_assert_eq!(env_canonicalizer.binder_index, ty::INNERMOST);
116128
// Then canonicalize the rest of the input without keeping `'static`
117129
// while *mostly* reusing the canonicalizer from above.
@@ -134,10 +146,24 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
134146
cache: Default::default(),
135147
};
136148

137-
let predicate = input.goal.predicate.fold_with(&mut rest_canonicalizer);
149+
let predicate = input.goal.predicate;
150+
let predicate = if predicate.has_type_flags(TypeFlags::NEEDS_CANONICALIZATION_NEXT_SOLVER) {
151+
predicate.fold_with(&mut rest_canonicalizer)
152+
} else {
153+
predicate
154+
};
138155
let goal = Goal { param_env, predicate };
139-
let predefined_opaques_in_body =
140-
input.predefined_opaques_in_body.fold_with(&mut rest_canonicalizer);
156+
157+
let predefined_opaques_in_body = input.predefined_opaques_in_body;
158+
let predefined_opaques_in_body = if input
159+
.predefined_opaques_in_body
160+
.has_type_flags(TypeFlags::NEEDS_CANONICALIZATION_NEXT_SOLVER)
161+
{
162+
predefined_opaques_in_body.fold_with(&mut rest_canonicalizer)
163+
} else {
164+
predefined_opaques_in_body
165+
};
166+
141167
let value = QueryInput { goal, predefined_opaques_in_body };
142168

143169
assert!(!value.has_infer(), "unexpected infer in {value:?}");
@@ -387,7 +413,11 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
387413
| ty::Alias(_, _)
388414
| ty::Bound(_, _)
389415
| ty::Error(_) => {
390-
return ensure_sufficient_stack(|| t.super_fold_with(self));
416+
return if t.has_type_flags(TypeFlags::NEEDS_CANONICALIZATION_NEXT_SOLVER) {
417+
ensure_sufficient_stack(|| t.super_fold_with(self))
418+
} else {
419+
t
420+
};
391421
}
392422
};
393423

@@ -522,11 +552,25 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
522552
| ty::ConstKind::Unevaluated(_)
523553
| ty::ConstKind::Value(_)
524554
| ty::ConstKind::Error(_)
525-
| ty::ConstKind::Expr(_) => return c.super_fold_with(self),
555+
| ty::ConstKind::Expr(_) => {
556+
return if c.has_type_flags(TypeFlags::NEEDS_CANONICALIZATION_NEXT_SOLVER) {
557+
c.super_fold_with(self)
558+
} else {
559+
c
560+
};
561+
}
526562
};
527563

528564
let var = self.get_or_insert_bound_var(c, CanonicalVarInfo { kind });
529565

530566
Const::new_anon_bound(self.cx(), self.binder_index, var)
531567
}
568+
569+
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
570+
if p.flags().intersects(TypeFlags::NEEDS_CANONICALIZATION_NEXT_SOLVER) {
571+
p.super_fold_with(self)
572+
} else {
573+
p
574+
}
575+
}
532576
}

compiler/rustc_next_trait_solver/src/resolve.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,8 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
8686
}
8787
}
8888
}
89+
90+
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
91+
if p.has_infer() { p.super_fold_with(self) } else { p }
92+
}
8993
}

compiler/rustc_type_ir/src/flags.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ bitflags::bitflags! {
99
/// over the type itself.
1010
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
1111
pub struct TypeFlags: u32 {
12+
// TODUwU
13+
const NEEDS_CANONICALIZATION_NEXT_SOLVER = TypeFlags::HAS_INFER.bits() | TypeFlags::HAS_PLACEHOLDER.bits() | TypeFlags::HAS_FREE_REGIONS.bits() | TypeFlags::HAS_PARAM.bits() | TypeFlags::HAS_RE_ERASED.bits();
14+
1215
// Does this have parameters? Used to determine whether instantiation is
1316
// required.
1417
/// Does this have `Param`?

0 commit comments

Comments
 (0)