Skip to content

Commit 03878c6

Browse files
committed
hir typeck: look into nested goals
uses a `ProofTreeVisitor` to look into nested goals when looking at the pending obligations during hir typeck. Used by closure signature inference, coercion, and for async functions.
1 parent 662eadb commit 03878c6

23 files changed

+632
-317
lines changed

compiler/rustc_hir_typeck/src/closure.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
342342
ty::Infer(ty::TyVar(vid)) => self.deduce_closure_signature_from_predicates(
343343
Ty::new_var(self.tcx, self.root_var(vid)),
344344
closure_kind,
345-
self.obligations_for_self_ty(vid).map(|obl| (obl.predicate, obl.cause.span)),
345+
self.obligations_for_self_ty(vid)
346+
.into_iter()
347+
.map(|obl| (obl.predicate, obl.cause.span)),
346348
),
347349
ty::FnPtr(sig) => match closure_kind {
348350
hir::ClosureKind::Closure => {
@@ -889,7 +891,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
889891

890892
let output_ty = match *ret_ty.kind() {
891893
ty::Infer(ty::TyVar(ret_vid)) => {
892-
self.obligations_for_self_ty(ret_vid).find_map(|obligation| {
894+
self.obligations_for_self_ty(ret_vid).into_iter().find_map(|obligation| {
893895
get_future_output(obligation.predicate, obligation.cause.span)
894896
})?
895897
}

compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
637637

638638
pub(crate) fn type_var_is_sized(&self, self_ty: ty::TyVid) -> bool {
639639
let sized_did = self.tcx.lang_items().sized_trait();
640-
self.obligations_for_self_ty(self_ty).any(|obligation| {
640+
self.obligations_for_self_ty(self_ty).into_iter().any(|obligation| {
641641
match obligation.predicate.kind().skip_binder() {
642642
ty::PredicateKind::Clause(ty::ClauseKind::Trait(data)) => {
643643
Some(data.def_id()) == sized_did
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,140 @@
1+
//! A utility module to inspect currently ambiguous obligations in the current context.
2+
use crate::rustc_middle::ty::TypeVisitableExt;
13
use crate::FnCtxt;
4+
use rustc_infer::traits::solve::Goal;
5+
use rustc_infer::traits::{self, ObligationCause};
26
use rustc_middle::ty::{self, Ty};
3-
use rustc_infer::traits;
4-
use rustc_data_structures::captures::Captures;
7+
use rustc_span::Span;
8+
use rustc_trait_selection::solve::inspect::ProofTreeInferCtxtExt;
9+
use rustc_trait_selection::solve::inspect::{InspectConfig, InspectGoal, ProofTreeVisitor};
510

611
impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
12+
/// Returns a list of all obligations whose self type has been unified
13+
/// with the unconstrained type `self_ty`.
714
#[instrument(skip(self), level = "debug")]
8-
pub(crate) fn obligations_for_self_ty<'b>(
9-
&'b self,
15+
pub(crate) fn obligations_for_self_ty(
16+
&self,
1017
self_ty: ty::TyVid,
11-
) -> impl DoubleEndedIterator<Item = traits::PredicateObligation<'tcx>> + Captures<'tcx> + 'b
12-
{
13-
let ty_var_root = self.root_var(self_ty);
14-
trace!("pending_obligations = {:#?}", self.fulfillment_cx.borrow().pending_obligations());
15-
16-
self.fulfillment_cx.borrow().pending_obligations().into_iter().filter_map(
17-
move |obligation| match &obligation.predicate.kind().skip_binder() {
18-
ty::PredicateKind::Clause(ty::ClauseKind::Projection(data))
19-
if self.self_type_matches_expected_vid(
20-
data.projection_ty.self_ty(),
21-
ty_var_root,
22-
) =>
23-
{
24-
Some(obligation)
25-
}
26-
ty::PredicateKind::Clause(ty::ClauseKind::Trait(data))
27-
if self.self_type_matches_expected_vid(data.self_ty(), ty_var_root) =>
28-
{
29-
Some(obligation)
30-
}
18+
) -> Vec<traits::PredicateObligation<'tcx>> {
19+
if self.next_trait_solver() {
20+
self.obligations_for_self_ty_next(self_ty)
21+
} else {
22+
let ty_var_root = self.root_var(self_ty);
23+
let mut obligations = self.fulfillment_cx.borrow().pending_obligations();
24+
trace!("pending_obligations = {:#?}", obligations);
25+
obligations
26+
.retain(|obligation| self.predicate_has_self_ty(obligation.predicate, ty_var_root));
27+
obligations
28+
}
29+
}
3130

32-
ty::PredicateKind::Clause(ty::ClauseKind::Trait(..))
33-
| ty::PredicateKind::Clause(ty::ClauseKind::Projection(..))
34-
| ty::PredicateKind::Clause(ty::ClauseKind::ConstArgHasType(..))
35-
| ty::PredicateKind::Subtype(..)
36-
| ty::PredicateKind::Coerce(..)
37-
| ty::PredicateKind::Clause(ty::ClauseKind::RegionOutlives(..))
38-
| ty::PredicateKind::Clause(ty::ClauseKind::TypeOutlives(..))
39-
| ty::PredicateKind::Clause(ty::ClauseKind::WellFormed(..))
40-
| ty::PredicateKind::ObjectSafe(..)
41-
| ty::PredicateKind::NormalizesTo(..)
42-
| ty::PredicateKind::AliasRelate(..)
43-
| ty::PredicateKind::Clause(ty::ClauseKind::ConstEvaluatable(..))
44-
| ty::PredicateKind::ConstEquate(..)
45-
| ty::PredicateKind::Ambiguous => None,
46-
},
47-
)
31+
#[instrument(level = "debug", skip(self), ret)]
32+
fn predicate_has_self_ty(
33+
&self,
34+
predicate: ty::Predicate<'tcx>,
35+
expected_vid: ty::TyVid,
36+
) -> bool {
37+
match predicate.kind().skip_binder() {
38+
ty::PredicateKind::Clause(ty::ClauseKind::Trait(data)) => {
39+
self.type_matches_expected_vid(expected_vid, data.self_ty())
40+
}
41+
ty::PredicateKind::Clause(ty::ClauseKind::Projection(data)) => {
42+
self.type_matches_expected_vid(expected_vid, data.projection_ty.self_ty())
43+
}
44+
ty::PredicateKind::Clause(ty::ClauseKind::ConstArgHasType(..))
45+
| ty::PredicateKind::Subtype(..)
46+
| ty::PredicateKind::Coerce(..)
47+
| ty::PredicateKind::Clause(ty::ClauseKind::RegionOutlives(..))
48+
| ty::PredicateKind::Clause(ty::ClauseKind::TypeOutlives(..))
49+
| ty::PredicateKind::Clause(ty::ClauseKind::WellFormed(..))
50+
| ty::PredicateKind::ObjectSafe(..)
51+
| ty::PredicateKind::NormalizesTo(..)
52+
| ty::PredicateKind::AliasRelate(..)
53+
| ty::PredicateKind::Clause(ty::ClauseKind::ConstEvaluatable(..))
54+
| ty::PredicateKind::ConstEquate(..)
55+
| ty::PredicateKind::Ambiguous => false,
56+
}
4857
}
4958

5059
#[instrument(level = "debug", skip(self), ret)]
51-
fn self_type_matches_expected_vid(&self, self_ty: Ty<'tcx>, expected_vid: ty::TyVid) -> bool {
52-
let self_ty = self.shallow_resolve(self_ty);
53-
debug!(?self_ty);
60+
fn type_matches_expected_vid(&self, expected_vid: ty::TyVid, ty: Ty<'tcx>) -> bool {
61+
let ty = self.shallow_resolve(ty);
62+
debug!(?ty);
5463

55-
match *self_ty.kind() {
64+
match *ty.kind() {
5665
ty::Infer(ty::TyVar(found_vid)) => {
57-
let found_vid = self.root_var(found_vid);
58-
debug!("self_type_matches_expected_vid - found_vid={:?}", found_vid);
59-
expected_vid == found_vid
66+
self.root_var(expected_vid) == self.root_var(found_vid)
6067
}
6168
_ => false,
6269
}
6370
}
64-
}
71+
72+
pub(crate) fn obligations_for_self_ty_next(
73+
&self,
74+
self_ty: ty::TyVid,
75+
) -> Vec<traits::PredicateObligation<'tcx>> {
76+
let obligations = self.fulfillment_cx.borrow().pending_obligations();
77+
debug!(?obligations);
78+
let mut obligations_for_self_ty = vec![];
79+
for obligation in obligations {
80+
let mut visitor = NestedObligationsForSelfTy {
81+
fcx: self,
82+
self_ty,
83+
obligations_for_self_ty: &mut obligations_for_self_ty,
84+
root_cause: &obligation.cause,
85+
};
86+
87+
let goal = Goal::new(self.tcx, obligation.param_env, obligation.predicate);
88+
self.visit_proof_tree(goal, &mut visitor);
89+
}
90+
91+
obligations_for_self_ty.retain_mut(|obligation| {
92+
obligation.predicate = self.resolve_vars_if_possible(obligation.predicate);
93+
!obligation.predicate.has_placeholders()
94+
});
95+
obligations_for_self_ty
96+
}
97+
}
98+
99+
struct NestedObligationsForSelfTy<'a, 'tcx> {
100+
fcx: &'a FnCtxt<'a, 'tcx>,
101+
self_ty: ty::TyVid,
102+
root_cause: &'a ObligationCause<'tcx>,
103+
obligations_for_self_ty: &'a mut Vec<traits::PredicateObligation<'tcx>>,
104+
}
105+
106+
impl<'a, 'tcx> ProofTreeVisitor<'tcx> for NestedObligationsForSelfTy<'a, 'tcx> {
107+
type Result = ();
108+
109+
fn span(&self) -> Span {
110+
self.root_cause.span
111+
}
112+
113+
fn config(&self) -> InspectConfig {
114+
// Using an intentionally low depth to minimize the chance of future
115+
// breaking changes in case we adapt the approach later on. This also
116+
// avoids any hangs for exponentially growing proof trees.
117+
InspectConfig { max_depth: 5 }
118+
}
119+
120+
fn visit_goal(&mut self, inspect_goal: &InspectGoal<'_, 'tcx>) {
121+
let tcx = self.fcx.tcx;
122+
let goal = inspect_goal.goal();
123+
if self.fcx.predicate_has_self_ty(goal.predicate, self.self_ty) {
124+
self.obligations_for_self_ty.push(traits::Obligation::new(
125+
tcx,
126+
self.root_cause.clone(),
127+
goal.param_env,
128+
goal.predicate,
129+
));
130+
}
131+
132+
// If there's a unique way to prove a given goal, recurse into
133+
// that candidate. This means that for `impl<F: FnOnce(u32)> Trait<F> for () {}`
134+
// and a `(): Trait<?0>` goal we recurse into the impl and look at
135+
// the nested `?0: FnOnce(u32)` goal.
136+
if let Some(candidate) = inspect_goal.unique_applicable_candidate() {
137+
candidate.visit_nested_no_probe(self)
138+
}
139+
}
140+
}

compiler/rustc_middle/src/traits/solve/inspect.rs

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ pub struct Probe<'tcx> {
102102
/// What happened inside of this probe in chronological order.
103103
pub steps: Vec<ProbeStep<'tcx>>,
104104
pub kind: ProbeKind<'tcx>,
105+
pub final_state: CanonicalState<'tcx, ()>,
105106
}
106107

107108
impl Debug for Probe<'_> {

compiler/rustc_trait_selection/src/solve/assembly/mod.rs

+15-11
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Code shared by trait and projection goals for candidate assembly.
22
3-
use super::{EvalCtxt, SolverMode};
43
use crate::solve::GoalSource;
4+
use crate::solve::{inspect, EvalCtxt, SolverMode};
55
use crate::traits::coherence;
66
use rustc_hir::def_id::DefId;
77
use rustc_infer::traits::query::NoSolution;
@@ -16,6 +16,7 @@ use rustc_middle::ty::{fast_reject, TypeFoldable};
1616
use rustc_middle::ty::{ToPredicate, TypeVisitableExt};
1717
use rustc_span::{ErrorGuaranteed, DUMMY_SP};
1818
use std::fmt::Debug;
19+
use std::mem;
1920

2021
pub(super) mod structural_traits;
2122

@@ -315,20 +316,17 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
315316
}
316317

317318
fn forced_ambiguity(&mut self, cause: MaybeCause) -> Vec<Candidate<'tcx>> {
318-
let source = CandidateSource::BuiltinImpl(BuiltinImplSource::Misc);
319-
let certainty = Certainty::Maybe(cause);
320319
// This may fail if `try_evaluate_added_goals` overflows because it
321320
// fails to reach a fixpoint but ends up getting an error after
322321
// running for some additional step.
323322
//
324-
// FIXME: Add a test for this. It seems to be necessary for typenum but
325-
// is incredibly hard to minimize as it may rely on being inside of a
326-
// trait solver cycle.
327-
let result = self.evaluate_added_goals_and_make_canonical_response(certainty);
328-
let mut dummy_probe = self.inspect.new_probe();
329-
dummy_probe.probe_kind(ProbeKind::TraitCandidate { source, result });
330-
self.inspect.finish_probe(dummy_probe);
331-
if let Ok(result) = result { vec![Candidate { source, result }] } else { vec![] }
323+
// cc trait-system-refactor-initiative#105
324+
let source = CandidateSource::BuiltinImpl(BuiltinImplSource::Misc);
325+
let certainty = Certainty::Maybe(cause);
326+
let result = self
327+
.probe_trait_candidate(source)
328+
.enter(|this| this.evaluate_added_goals_and_make_canonical_response(certainty));
329+
if let Ok(cand) = result { vec![cand] } else { vec![] }
332330
}
333331

334332
#[instrument(level = "debug", skip_all)]
@@ -813,6 +811,11 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
813811
goal: Goal<'tcx, G>,
814812
candidates: &mut Vec<Candidate<'tcx>>,
815813
) {
814+
// HACK: We temporarily remove the `ProofTreeBuilder` to
815+
// avoid adding `Trait` candidates to the candidates used
816+
// to prove the current goal.
817+
let inspect = mem::replace(&mut self.inspect, inspect::ProofTreeBuilder::new_noop());
818+
816819
let tcx = self.tcx();
817820
let trait_goal: Goal<'tcx, ty::TraitPredicate<'tcx>> =
818821
goal.with(tcx, goal.predicate.trait_ref(tcx));
@@ -846,6 +849,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
846849
}
847850
}
848851
}
852+
self.inspect = inspect;
849853
}
850854

851855
/// If there are multiple ways to prove a trait or projection goal, we have

0 commit comments

Comments
 (0)