Skip to content

refactor fudge_inference, handle effect vars #131911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,13 @@ impl<'tcx> InferCtxt<'tcx> {
ty::Const::new_var(self.tcx, vid)
}

fn next_effect_var(&self) -> ty::Const<'tcx> {
let effect_vid =
self.inner.borrow_mut().effect_unification_table().new_key(EffectVarValue::Unknown).vid;

ty::Const::new_infer(self.tcx, ty::InferConst::EffectVar(effect_vid))
}

pub fn next_int_var(&self) -> Ty<'tcx> {
let next_int_var_id =
self.inner.borrow_mut().int_unification_table().new_key(ty::IntVarValue::Unknown);
Expand Down Expand Up @@ -1001,15 +1008,13 @@ impl<'tcx> InferCtxt<'tcx> {
}

pub fn var_for_effect(&self, param: &ty::GenericParamDef) -> GenericArg<'tcx> {
let effect_vid =
self.inner.borrow_mut().effect_unification_table().new_key(EffectVarValue::Unknown).vid;
let ty = self
.tcx
.type_of(param.def_id)
.no_bound_vars()
.expect("const parameter types cannot be generic");
debug_assert_eq!(self.tcx.types.bool, ty);
ty::Const::new_infer(self.tcx, ty::InferConst::EffectVar(effect_vid)).into()
self.next_effect_var().into()
}

/// Given a set of generics defined on a type or impl, returns the generic parameters mapping
Expand Down
267 changes: 146 additions & 121 deletions compiler/rustc_infer/src/infer/snapshot/fudge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use rustc_data_structures::{snapshot_vec as sv, unify as ut};
use rustc_middle::infer::unify_key::{ConstVariableValue, ConstVidKey};
use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_middle::ty::{self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid};
use rustc_type_ir::EffectVid;
use rustc_type_ir::visit::TypeVisitableExt;
use tracing::instrument;
use ut::UnifyKey;

use super::VariableLengths;
use crate::infer::type_variable::TypeVariableOrigin;
use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};

Expand Down Expand Up @@ -40,26 +43,7 @@ fn const_vars_since_snapshot<'tcx>(
)
}

struct VariableLengths {
type_var_len: usize,
const_var_len: usize,
int_var_len: usize,
float_var_len: usize,
region_constraints_len: usize,
}

impl<'tcx> InferCtxt<'tcx> {
fn variable_lengths(&self) -> VariableLengths {
let mut inner = self.inner.borrow_mut();
VariableLengths {
type_var_len: inner.type_variables().num_vars(),
const_var_len: inner.const_unification_table().len(),
int_var_len: inner.int_unification_table().len(),
float_var_len: inner.float_unification_table().len(),
region_constraints_len: inner.unwrap_region_constraints().num_region_vars(),
}
}

/// This rather funky routine is used while processing expected
/// types. What happens here is that we want to propagate a
/// coercion through the return type of a fn to its
Expand Down Expand Up @@ -106,78 +90,94 @@ impl<'tcx> InferCtxt<'tcx> {
T: TypeFoldable<TyCtxt<'tcx>>,
{
let variable_lengths = self.variable_lengths();
let (mut fudger, value) = self.probe(|_| {
match f() {
Ok(value) => {
let value = self.resolve_vars_if_possible(value);

// At this point, `value` could in principle refer
// to inference variables that have been created during
// the snapshot. Once we exit `probe()`, those are
// going to be popped, so we will have to
// eliminate any references to them.

let mut inner = self.inner.borrow_mut();
let type_vars =
inner.type_variables().vars_since_snapshot(variable_lengths.type_var_len);
let int_vars = vars_since_snapshot(
&inner.int_unification_table(),
variable_lengths.int_var_len,
);
let float_vars = vars_since_snapshot(
&inner.float_unification_table(),
variable_lengths.float_var_len,
);
let region_vars = inner
.unwrap_region_constraints()
.vars_since_snapshot(variable_lengths.region_constraints_len);
let const_vars = const_vars_since_snapshot(
&mut inner.const_unification_table(),
variable_lengths.const_var_len,
);

let fudger = InferenceFudger {
infcx: self,
type_vars,
int_vars,
float_vars,
region_vars,
const_vars,
};

Ok((fudger, value))
}
Err(e) => Err(e),
}
let (snapshot_vars, value) = self.probe(|_| {
let value = f()?;
// At this point, `value` could in principle refer
// to inference variables that have been created during
// the snapshot. Once we exit `probe()`, those are
// going to be popped, so we will have to
// eliminate any references to them.
let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
})?;

// At this point, we need to replace any of the now-popped
// type/region variables that appear in `value` with a fresh
// variable of the appropriate kind. We can't do this during
// the probe because they would just get popped then too. =)
Ok(self.fudge_inference(snapshot_vars, value))
}

fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
&self,
snapshot_vars: SnapshotVarData,
value: T,
) -> T {
// Micro-optimization: if no variables have been created, then
// `value` can't refer to any of them. =) So we can just return it.
if fudger.type_vars.0.is_empty()
&& fudger.int_vars.is_empty()
&& fudger.float_vars.is_empty()
&& fudger.region_vars.0.is_empty()
&& fudger.const_vars.0.is_empty()
{
Ok(value)
if snapshot_vars.is_empty() {
value
} else {
Ok(value.fold_with(&mut fudger))
value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
}
}
}

struct InferenceFudger<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
struct SnapshotVarData {
region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
int_vars: Range<IntVid>,
float_vars: Range<FloatVid>,
region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
effect_vars: Range<EffectVid>,
}

impl SnapshotVarData {
fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
let mut inner = infcx.inner.borrow_mut();
let region_vars = inner
.unwrap_region_constraints()
.vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
let int_vars =
vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
let float_vars =
vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);

let const_vars = const_vars_since_snapshot(
&mut inner.const_unification_table(),
vars_pre_snapshot.const_var_len,
);
let effect_vars = vars_since_snapshot(
&inner.effect_unification_table(),
vars_pre_snapshot.effect_var_len,
);
let effect_vars = effect_vars.start.vid..effect_vars.end.vid;

SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars, effect_vars }
}

fn is_empty(&self) -> bool {
let SnapshotVarData {
region_vars,
type_vars,
int_vars,
float_vars,
const_vars,
effect_vars,
} = self;
region_vars.0.is_empty()
&& type_vars.0.is_empty()
&& int_vars.is_empty()
&& float_vars.is_empty()
&& const_vars.0.is_empty()
&& effect_vars.is_empty()
}
}

struct InferenceFudger<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
snapshot_vars: SnapshotVarData,
}

impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
Expand All @@ -186,68 +186,93 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
match *ty.kind() {
ty::Infer(ty::InferTy::TyVar(vid)) => {
if self.type_vars.0.contains(&vid) {
// This variable was created during the fudging.
// Recreate it with a fresh variable here.
let idx = vid.as_usize() - self.type_vars.0.start.as_usize();
let origin = self.type_vars.1[idx];
self.infcx.next_ty_var_with_origin(origin)
} else {
// This variable was created before the
// "fudging". Since we refresh all type
// variables to their binding anyhow, we know
// that it is unbound, so we can just return
// it.
debug_assert!(
self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
);
ty
if let &ty::Infer(infer_ty) = ty.kind() {
match infer_ty {
ty::TyVar(vid) => {
if self.snapshot_vars.type_vars.0.contains(&vid) {
// This variable was created during the fudging.
// Recreate it with a fresh variable here.
let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
let origin = self.snapshot_vars.type_vars.1[idx];
self.infcx.next_ty_var_with_origin(origin)
} else {
// This variable was created before the
// "fudging". Since we refresh all type
// variables to their binding anyhow, we know
// that it is unbound, so we can just return
// it.
debug_assert!(
self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
);
ty
}
}
}
ty::Infer(ty::InferTy::IntVar(vid)) => {
if self.int_vars.contains(&vid) {
self.infcx.next_int_var()
} else {
ty
ty::IntVar(vid) => {
if self.snapshot_vars.int_vars.contains(&vid) {
self.infcx.next_int_var()
} else {
ty
}
}
}
ty::Infer(ty::InferTy::FloatVar(vid)) => {
if self.float_vars.contains(&vid) {
self.infcx.next_float_var()
} else {
ty
ty::FloatVar(vid) => {
if self.snapshot_vars.float_vars.contains(&vid) {
self.infcx.next_float_var()
} else {
ty
}
}
ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
unreachable!("unexpected fresh infcx var")
}
}
_ => ty.super_fold_with(self),
} else if ty.has_infer() {
ty.super_fold_with(self)
} else {
ty
}
}

fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
if let ty::ReVar(vid) = *r
&& self.region_vars.0.contains(&vid)
{
let idx = vid.index() - self.region_vars.0.start.index();
let origin = self.region_vars.1[idx];
return self.infcx.next_region_var(origin);
if let ty::ReVar(vid) = r.kind() {
if self.snapshot_vars.region_vars.0.contains(&vid) {
let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
let origin = self.snapshot_vars.region_vars.1[idx];
self.infcx.next_region_var(origin)
} else {
r
}
} else {
r
}
r
}

fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
if let ty::ConstKind::Infer(ty::InferConst::Var(vid)) = ct.kind() {
if self.const_vars.0.contains(&vid) {
// This variable was created during the fudging.
// Recreate it with a fresh variable here.
let idx = vid.index() - self.const_vars.0.start.index();
let origin = self.const_vars.1[idx];
self.infcx.next_const_var_with_origin(origin)
} else {
ct
if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
match infer_ct {
ty::InferConst::Var(vid) => {
if self.snapshot_vars.const_vars.0.contains(&vid) {
let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
let origin = self.snapshot_vars.const_vars.1[idx];
self.infcx.next_const_var_with_origin(origin)
} else {
ct
}
}
ty::InferConst::EffectVar(vid) => {
if self.snapshot_vars.effect_vars.contains(&vid) {
self.infcx.next_effect_var()
} else {
ct
}
}
ty::InferConst::Fresh(_) => {
unreachable!("unexpected fresh infcx var")
}
}
} else {
} else if ct.has_infer() {
ct.super_fold_with(self)
} else {
ct
}
}
}
21 changes: 21 additions & 0 deletions compiler/rustc_infer/src/infer/snapshot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,28 @@ pub struct CombinedSnapshot<'tcx> {
universe: ty::UniverseIndex,
}

struct VariableLengths {
region_constraints_len: usize,
type_var_len: usize,
int_var_len: usize,
float_var_len: usize,
const_var_len: usize,
effect_var_len: usize,
}

impl<'tcx> InferCtxt<'tcx> {
fn variable_lengths(&self) -> VariableLengths {
let mut inner = self.inner.borrow_mut();
VariableLengths {
region_constraints_len: inner.unwrap_region_constraints().num_region_vars(),
type_var_len: inner.type_variables().num_vars(),
int_var_len: inner.int_unification_table().len(),
float_var_len: inner.float_unification_table().len(),
const_var_len: inner.const_unification_table().len(),
effect_var_len: inner.effect_unification_table().len(),
}
}

pub fn in_snapshot(&self) -> bool {
UndoLogs::<UndoLog<'tcx>>::in_snapshot(&self.inner.borrow_mut().undo_log)
}
Expand Down
Loading