Skip to content

-Z trait-solver=next: Deduplicate region constraints in query responses #111172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use rustc_span::DUMMY_SP;
use std::iter;
use std::ops::Deref;

mod dedup_solver;

impl<'tcx> EvalCtxt<'_, 'tcx> {
/// Canonicalizes the goal remembering the original values
/// for each bound variable.
Expand Down Expand Up @@ -140,7 +142,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
// Cannot use `take_registered_region_obligations` as we may compute the response
// inside of a `probe` whenever we have multiple choices inside of the solver.
let region_obligations = self.infcx.inner.borrow().region_obligations().to_owned();
let region_constraints = self.infcx.with_region_constraints(|region_constraints| {
let mut region_constraints = self.infcx.with_region_constraints(|region_constraints| {
make_query_region_constraints(
self.tcx(),
region_obligations
Expand All @@ -149,6 +151,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
region_constraints,
)
});
dedup_solver::Deduper::dedup(self.infcx, self.max_input_universe, &mut region_constraints);

let mut opaque_types = self.infcx.clone_opaque_types_for_query_response();
// Only return opaque type keys for newly-defined opaques
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use crate::infer::canonical::QueryRegionConstraints;
use crate::infer::region_constraints::MemberConstraint;
use rustc_infer::infer::InferCtxt;
use rustc_middle::ty;
use ty::{subst::GenericArg, Region, UniverseIndex};

use rustc_data_structures::fx::{FxHashMap, FxIndexMap, FxIndexSet};
use rustc_index::IndexVec;
use std::hash::Hash;

mod constraint_walker;
mod solver;
use constraint_walker::{DedupWalker, DedupableIndexer};
use solver::{ConstraintIndex, DedupSolver, VarIndex};

pub struct Deduper<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
max_nameable_universe: UniverseIndex,

var_indexer: DedupableIndexer<'tcx>,
constraint_vars: IndexVec<ConstraintIndex, Vec<VarIndex>>,
/// Constraints that are identical except for the value of their variables are grouped into the same clique
constraint_cliques: FxIndexMap<ConstraintType<'tcx>, Vec<ConstraintIndex>>,
/// Maps a constraint index (the index inside constraint_vars) back to its index in outlives
indx_to_outlives: FxHashMap<ConstraintIndex, usize>,
/// Maps a constraint index (the index inside constraint_vars) back to its index in member_constraints
indx_to_members: FxHashMap<ConstraintIndex, usize>,
}
#[derive(Debug, PartialEq, Eq, Hash)]
enum ConstraintType<'tcx> {
Outlives(Outlives<'tcx>),
Member(MemberConstraint<'tcx>),
}
pub type Outlives<'tcx> = ty::OutlivesPredicate<GenericArg<'tcx>, Region<'tcx>>;

impl<'a, 'tcx> Deduper<'a, 'tcx> {
pub fn dedup(
infcx: &'a InferCtxt<'tcx>,
max_nameable_universe: UniverseIndex,
constraints: &mut QueryRegionConstraints<'tcx>,
) {
let mut deduper = Self {
infcx,
max_nameable_universe,

var_indexer: DedupableIndexer::new(),
constraint_vars: IndexVec::default(),
constraint_cliques: FxIndexMap::default(),
indx_to_outlives: FxHashMap::default(),
indx_to_members: FxHashMap::default(),
};
deduper.dedup_internal(constraints);
}
fn dedup_internal(&mut self, constraints: &mut QueryRegionConstraints<'tcx>) {
fn dedup_exact<T: Clone + Hash + Eq>(input: &mut Vec<T>) {
*input = FxIndexSet::<T>::from_iter(input.clone()).into_iter().collect();
}
dedup_exact(&mut constraints.outlives);
dedup_exact(&mut constraints.member_constraints);

self.lower_constraints_into_solver(constraints);
let constraint_vars = std::mem::take(&mut self.constraint_vars);
let constraint_cliques =
std::mem::take(&mut self.constraint_cliques).into_iter().map(|x| x.1).collect();
let var_universes = std::mem::take(&mut self.var_indexer.var_universes)
.into_iter()
.map(|(var, uni)| (VarIndex::from(var), uni.index()))
.collect();
let removed = DedupSolver::dedup(constraint_vars, constraint_cliques, var_universes)
.removed_constraints;

let mut removed_outlives =
removed.iter().filter_map(|x| self.indx_to_outlives.get(x)).collect::<Vec<_>>();
let mut removed_members =
removed.iter().filter_map(|x| self.indx_to_members.get(x)).collect::<Vec<_>>();
removed_outlives.sort();
removed_members.sort();

for removed_outlive in removed_outlives.into_iter().rev() {
constraints.outlives.swap_remove(*removed_outlive);
}
for removed_member in removed_members.into_iter().rev() {
constraints.member_constraints.swap_remove(*removed_member);
}
}
fn lower_constraints_into_solver(&mut self, constraints: &QueryRegionConstraints<'tcx>) {
for (outlives_indx, outlives) in constraints.outlives.iter().enumerate() {
let (erased, vars) = DedupWalker::erase_dedupables(
self.infcx,
&mut self.var_indexer,
self.max_nameable_universe,
outlives.0.clone(),
);
self.insert_constraint(vars, ConstraintType::Outlives(erased), outlives_indx);
}
for (member_indx, member) in constraints.member_constraints.iter().enumerate() {
let (erased, vars) = DedupWalker::erase_dedupables(
self.infcx,
&mut self.var_indexer,
self.max_nameable_universe,
member.clone(),
);
self.insert_constraint(vars, ConstraintType::Member(erased), member_indx);
}
}
fn insert_constraint(
&mut self,
vars: Vec<usize>,
erased: ConstraintType<'tcx>,
original_indx: usize,
) {
if vars.is_empty() {
return;
}
let constraint_indx = self.constraint_vars.next_index();
match erased {
ConstraintType::Outlives(_) => {
self.indx_to_outlives.insert(constraint_indx, original_indx)
}
ConstraintType::Member(_) => {
self.indx_to_members.insert(constraint_indx, original_indx)
}
};
self.constraint_vars.push(vars.into_iter().map(VarIndex::from).collect());
self.constraint_cliques.entry(erased).or_insert_with(Vec::new).push(constraint_indx);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_infer::infer::InferCtxt;
use rustc_middle::ty;
use rustc_middle::ty::{
Const, GenericArg, Region, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
};

pub struct DedupWalker<'me, 'tcx> {
infcx: &'me InferCtxt<'tcx>,
var_indexer: &'me mut DedupableIndexer<'tcx>,
max_nameable_universe: ty::UniverseIndex,

vars_present: Vec<usize>,
}
pub struct DedupableIndexer<'tcx> {
vars: FxIndexSet<GenericArg<'tcx>>,
pub var_universes: FxIndexMap<usize, ty::UniverseIndex>,
}

impl<'me, 'tcx> DedupWalker<'me, 'tcx> {
pub fn erase_dedupables<T: TypeFoldable<TyCtxt<'tcx>>>(
infcx: &'me InferCtxt<'tcx>,
var_indexer: &'me mut DedupableIndexer<'tcx>,
max_nameable_universe: ty::UniverseIndex,
value: T,
) -> (T, Vec<usize>) {
let mut dedup_walker =
Self { infcx, var_indexer, max_nameable_universe, vars_present: Vec::new() };
let folded = value.fold_with(&mut dedup_walker);
(folded, dedup_walker.vars_present)
}
}
impl<'tcx> DedupableIndexer<'tcx> {
pub fn new() -> Self {
Self { vars: FxIndexSet::default(), var_universes: FxIndexMap::default() }
}
fn lookup(&mut self, var: GenericArg<'tcx>, universe: ty::UniverseIndex) -> usize {
let var_indx = self.vars.get_index_of(&var).unwrap_or_else(|| self.vars.insert_full(var).0);
self.var_universes.insert(var_indx, universe);
var_indx
}
}

impl<'tcx> TypeFolder<TyCtxt<'tcx>> for DedupWalker<'_, 'tcx> {
fn interner(&self) -> TyCtxt<'tcx> {
self.infcx.tcx
}

fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
&mut self,
t: ty::Binder<'tcx, T>,
) -> ty::Binder<'tcx, T> {
t.super_fold_with(self)
}

fn fold_region(&mut self, region: Region<'tcx>) -> Region<'tcx> {
let universe = match *region {
ty::ReVar(..) | ty::RePlaceholder(..) => self.infcx.universe_of_region(region),
_ => return region,
};
if self.max_nameable_universe.can_name(universe) {
return region;
}
let var_id = self.var_indexer.lookup(GenericArg::from(region), universe);
self.vars_present.push(var_id);
// dummy value
self.interner().mk_re_placeholder(ty::Placeholder {
universe: ty::UniverseIndex::from(self.max_nameable_universe.index() + 1),
bound: ty::BoundRegion {
var: ty::BoundVar::from_usize(0),
kind: ty::BoundRegionKind::BrEnv,
},
})
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
let universe = match *ty.kind() {
ty::Placeholder(p) => p.universe,
/*
ty::Infer(ty::InferTy::TyVar(vid)) => {
if let Err(uni) = self.infcx.probe_ty_var(vid) {
uni
} else {
return ty;
}
}
*/
_ => return ty,
};
if self.max_nameable_universe.can_name(universe) {
return ty;
}
let var_id = self.var_indexer.lookup(GenericArg::from(ty), universe);
self.vars_present.push(var_id);
// dummy value
self.interner().mk_ty_from_kind(ty::Placeholder(ty::Placeholder {
universe: ty::UniverseIndex::from(self.max_nameable_universe.index() + 1),
bound: ty::BoundTy { var: ty::BoundVar::from_usize(0), kind: ty::BoundTyKind::Anon },
}))
}

fn fold_const(&mut self, ct: Const<'tcx>) -> Const<'tcx> {
let new_ty = self.fold_ty(ct.ty());
let universe = match ct.kind() {
/*
ty::ConstKind::Infer(ty::InferConst::Var(vid)) => {
if let Err(uni) = self.infcx.probe_const_var(vid) { Some(uni) } else { None }
}
*/
ty::ConstKind::Placeholder(p) => Some(p.universe),
_ => None,
};
let new_const_kind = if let Some(uni) = universe {
if self.max_nameable_universe.can_name(uni) {
ct.kind()
} else {
let var_id = self.var_indexer.lookup(GenericArg::from(ct), uni);
self.vars_present.push(var_id);
// dummy value
ty::ConstKind::Placeholder(ty::Placeholder {
universe: ty::UniverseIndex::from(self.max_nameable_universe.index() + 1),
bound: ty::BoundVar::from_usize(0),
})
}
} else {
ct.kind()
};
self.infcx.tcx.mk_const(new_const_kind, new_ty)
}
}
Loading