Skip to content

Fold predicate fast path in canonicalizer and eager resolver #141442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,10 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
ct
}
}

fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p }
}
}

impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
Expand Down
59 changes: 51 additions & 8 deletions compiler/rustc_next_trait_solver/src/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,22 @@ use rustc_type_ir::data_structures::{HashMap, ensure_sufficient_stack};
use rustc_type_ir::inherent::*;
use rustc_type_ir::solve::{Goal, QueryInput};
use rustc_type_ir::{
self as ty, Canonical, CanonicalTyVarKind, CanonicalVarKind, InferCtxtLike, Interner,
TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
self as ty, Canonical, CanonicalTyVarKind, CanonicalVarKind, Flags, InferCtxtLike, Interner,
TypeFlags, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
};

use crate::delegate::SolverDelegate;

/// Does this have infer/placeholder/param, free regions or ReErased?
const NEEDS_CANONICAL: TypeFlags = TypeFlags::from_bits(
TypeFlags::HAS_INFER.bits()
| TypeFlags::HAS_PLACEHOLDER.bits()
| TypeFlags::HAS_PARAM.bits()
| TypeFlags::HAS_FREE_REGIONS.bits()
| TypeFlags::HAS_RE_ERASED.bits(),
)
.unwrap();

/// Whether we're canonicalizing a query input or the query response.
///
/// When canonicalizing an input we're in the context of the caller
Expand Down Expand Up @@ -79,7 +89,11 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
cache: Default::default(),
};

let value = value.fold_with(&mut canonicalizer);
let value = if value.has_type_flags(NEEDS_CANONICAL) {
value.fold_with(&mut canonicalizer)
} else {
value
};
assert!(!value.has_infer(), "unexpected infer in {value:?}");
assert!(!value.has_placeholders(), "unexpected placeholders in {value:?}");
let (max_universe, variables) = canonicalizer.finalize();
Expand Down Expand Up @@ -111,7 +125,14 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {

cache: Default::default(),
};
let param_env = input.goal.param_env.fold_with(&mut env_canonicalizer);

let param_env = input.goal.param_env;
let param_env = if param_env.has_type_flags(NEEDS_CANONICAL) {
param_env.fold_with(&mut env_canonicalizer)
} else {
param_env
};

debug_assert_eq!(env_canonicalizer.binder_index, ty::INNERMOST);
// Then canonicalize the rest of the input without keeping `'static`
// while *mostly* reusing the canonicalizer from above.
Expand All @@ -134,10 +155,22 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
cache: Default::default(),
};

let predicate = input.goal.predicate.fold_with(&mut rest_canonicalizer);
let predicate = input.goal.predicate;
let predicate = if predicate.has_type_flags(NEEDS_CANONICAL) {
predicate.fold_with(&mut rest_canonicalizer)
} else {
predicate
};
let goal = Goal { param_env, predicate };

let predefined_opaques_in_body = input.predefined_opaques_in_body;
let predefined_opaques_in_body =
input.predefined_opaques_in_body.fold_with(&mut rest_canonicalizer);
if input.predefined_opaques_in_body.has_type_flags(NEEDS_CANONICAL) {
predefined_opaques_in_body.fold_with(&mut rest_canonicalizer)
} else {
predefined_opaques_in_body
};

let value = QueryInput { goal, predefined_opaques_in_body };

assert!(!value.has_infer(), "unexpected infer in {value:?}");
Expand Down Expand Up @@ -387,7 +420,11 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
| ty::Alias(_, _)
| ty::Bound(_, _)
| ty::Error(_) => {
return ensure_sufficient_stack(|| t.super_fold_with(self));
return if t.has_type_flags(NEEDS_CANONICAL) {
ensure_sufficient_stack(|| t.super_fold_with(self))
} else {
t
};
}
};

Expand Down Expand Up @@ -522,11 +559,17 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
| ty::ConstKind::Unevaluated(_)
| ty::ConstKind::Value(_)
| ty::ConstKind::Error(_)
| ty::ConstKind::Expr(_) => return c.super_fold_with(self),
| ty::ConstKind::Expr(_) => {
return if c.has_type_flags(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c };
}
};

let var = self.get_or_insert_bound_var(c, kind);

Const::new_anon_bound(self.cx(), self.binder_index, var)
}

fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p }
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_next_trait_solver/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,8 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
}
}
}

fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
if p.has_infer() { p.super_fold_with(self) } else { p }
Copy link
Contributor

@lcnr lcnr May 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc #141451 we prolly want to also add a fn fold_clauses for this, gonna impl that (or a general fold_with_type_flags) tmrw

Copy link
Member Author

@compiler-errors compiler-errors May 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if that's necessary. fold for clause goes through fold_predicate by some .as_clause() + fold + .expect_clause() punning today, AFAICT.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clauses, avoid walking over 150 where-bounds in diesel :3

}
}
Loading