Skip to content

Commit 169955f

Browse files
Properly drain pending obligations for coroutines
1 parent 67df5b9 commit 169955f

File tree

19 files changed

+241
-63
lines changed

19 files changed

+241
-63
lines changed

compiler/rustc_hir_typeck/src/closure.rs

+3
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
163163
// Resume type defaults to `()` if the coroutine has no argument.
164164
let resume_ty = liberated_sig.inputs().get(0).copied().unwrap_or(tcx.types.unit);
165165

166+
// TODO: In the new solver, we can just instantiate this eagerly
167+
// with the witness. This will ensure that goals that don't need
168+
// to stall on interior types will get processed eagerly.
166169
let interior = self.next_ty_var(expr_span);
167170
self.deferred_coroutine_interiors.borrow_mut().push((expr_def_id, interior));
168171

compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -659,10 +659,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
659659
obligations.extend(ok.obligations);
660660
}
661661

662-
// FIXME: Use a real visitor for unstalled obligations in the new solver.
663662
if !coroutines.is_empty() {
664-
obligations
665-
.extend(self.fulfillment_cx.borrow_mut().drain_unstalled_obligations(&self.infcx));
663+
obligations.extend(
664+
self.fulfillment_cx
665+
.borrow_mut()
666+
.drain_stalled_obligations_for_coroutines(&self.infcx),
667+
);
666668
}
667669

668670
self.typeck_results

compiler/rustc_hir_typeck/src/typeck_root_ctxt.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl<'tcx> TypeckRootCtxt<'tcx> {
8484
let hir_owner = tcx.local_def_id_to_hir_id(def_id).owner;
8585

8686
let infcx =
87-
tcx.infer_ctxt().ignoring_regions().build(TypingMode::analysis_in_body(tcx, def_id));
87+
tcx.infer_ctxt().ignoring_regions().build(TypingMode::typeck_for_body(tcx, def_id));
8888
let typeck_results = RefCell::new(ty::TypeckResults::new(hir_owner));
8989

9090
TypeckRootCtxt {

compiler/rustc_infer/src/infer/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,7 @@ impl<'tcx> InferCtxt<'tcx> {
967967
pub fn can_define_opaque_ty(&self, id: impl Into<DefId>) -> bool {
968968
debug_assert!(!self.next_trait_solver());
969969
match self.typing_mode() {
970-
TypingMode::Analysis { defining_opaque_types }
970+
TypingMode::Analysis { defining_opaque_types, stalled_generators: _ }
971971
| TypingMode::Borrowck { defining_opaque_types } => {
972972
id.into().as_local().is_some_and(|def_id| defining_opaque_types.contains(&def_id))
973973
}
@@ -1262,7 +1262,7 @@ impl<'tcx> InferCtxt<'tcx> {
12621262
// to handle them without proper canonicalization. This means we may cause cycle
12631263
// errors and fail to reveal opaques while inside of bodies. We should rename this
12641264
// function and require explicit comments on all use-sites in the future.
1265-
ty::TypingMode::Analysis { defining_opaque_types: _ }
1265+
ty::TypingMode::Analysis { defining_opaque_types: _, stalled_generators: _ }
12661266
| ty::TypingMode::Borrowck { defining_opaque_types: _ } => {
12671267
TypingMode::non_body_analysis()
12681268
}

compiler/rustc_infer/src/traits/engine.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub trait TraitEngine<'tcx, E: 'tcx>: 'tcx {
9494
/// Among all pending obligations, collect those are stalled on a inference variable which has
9595
/// changed since the last call to `select_where_possible`. Those obligations are marked as
9696
/// successful and returned.
97-
fn drain_unstalled_obligations(
97+
fn drain_stalled_obligations_for_coroutines(
9898
&mut self,
9999
infcx: &InferCtxt<'tcx>,
100100
) -> PredicateObligations<'tcx>;

compiler/rustc_middle/src/query/mod.rs

+9
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,15 @@ rustc_queries! {
387387
}
388388
}
389389

390+
query stalled_generators_within(
391+
key: LocalDefId
392+
) -> &'tcx ty::List<LocalDefId> {
393+
desc {
394+
|tcx| "computing the opaque types defined by `{}`",
395+
tcx.def_path_str(key.to_def_id())
396+
}
397+
}
398+
390399
/// Returns the explicitly user-written *bounds* on the associated or opaque type given by `DefId`
391400
/// that must be proven true at definition site (and which can be assumed at usage sites).
392401
///

compiler/rustc_middle/src/query/plumbing.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,11 @@ macro_rules! define_callbacks {
366366

367367
pub type Storage<'tcx> = <$($K)* as keys::Key>::Cache<Erase<$V>>;
368368

369-
// Ensure that keys grow no larger than 80 bytes by accident.
369+
// Ensure that keys grow no larger than 96 bytes by accident.
370370
// Increase this limit if necessary, but do try to keep the size low if possible
371371
#[cfg(target_pointer_width = "64")]
372372
const _: () = {
373-
if size_of::<Key<'static>>() > 88 {
373+
if size_of::<Key<'static>>() > 96 {
374374
panic!("{}", concat!(
375375
"the query `",
376376
stringify!($name),

compiler/rustc_middle/src/ty/context.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
106106
) -> Self::PredefinedOpaques {
107107
self.mk_predefined_opaques_in_body(data)
108108
}
109-
type DefiningOpaqueTypes = &'tcx ty::List<LocalDefId>;
109+
type LocalDefIds = &'tcx ty::List<LocalDefId>;
110110
type CanonicalVars = CanonicalVarInfos<'tcx>;
111111
fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars {
112112
self.mk_canonical_var_infos(infos)
@@ -674,9 +674,13 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
674674
self.anonymize_bound_vars(binder)
675675
}
676676

677-
fn opaque_types_defined_by(self, defining_anchor: LocalDefId) -> Self::DefiningOpaqueTypes {
677+
fn opaque_types_defined_by(self, defining_anchor: LocalDefId) -> Self::LocalDefIds {
678678
self.opaque_types_defined_by(defining_anchor)
679679
}
680+
681+
fn stalled_generators_within(self, defining_anchor: Self::LocalDefId) -> Self::LocalDefIds {
682+
self.stalled_generators_within(defining_anchor)
683+
}
680684
}
681685

682686
macro_rules! bidirectional_lang_item_map {

compiler/rustc_next_trait_solver/src/solve/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,10 @@ where
329329
TypingMode::Coherence | TypingMode::PostAnalysis => false,
330330
// During analysis, opaques are rigid unless they may be defined by
331331
// the current body.
332-
TypingMode::Analysis { defining_opaque_types: non_rigid_opaques }
332+
TypingMode::Analysis {
333+
defining_opaque_types: non_rigid_opaques,
334+
stalled_generators: _,
335+
}
333336
| TypingMode::Borrowck { defining_opaque_types: non_rigid_opaques }
334337
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: non_rigid_opaques } => {
335338
!def_id.as_local().is_some_and(|def_id| non_rigid_opaques.contains(&def_id))

compiler/rustc_next_trait_solver/src/solve/normalizes_to/opaque_types.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ where
3333
);
3434
self.evaluate_added_goals_and_make_canonical_response(Certainty::AMBIGUOUS)
3535
}
36-
TypingMode::Analysis { defining_opaque_types } => {
36+
TypingMode::Analysis { defining_opaque_types, stalled_generators: _ } => {
3737
let Some(def_id) = opaque_ty
3838
.def_id
3939
.as_local()

compiler/rustc_next_trait_solver/src/solve/trait_goals.rs

+15
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,21 @@ where
208208
}
209209
}
210210

211+
if let ty::CoroutineWitness(def_id, _) = goal.predicate.self_ty().kind() {
212+
match ecx.typing_mode() {
213+
TypingMode::Analysis { stalled_generators, defining_opaque_types: _ } => {
214+
if def_id.as_local().is_some_and(|def_id| stalled_generators.contains(&def_id))
215+
{
216+
return ecx.forced_ambiguity(MaybeCause::Ambiguity);
217+
}
218+
}
219+
TypingMode::Coherence
220+
| TypingMode::PostAnalysis
221+
| TypingMode::Borrowck { defining_opaque_types: _ }
222+
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ } => {}
223+
}
224+
}
225+
211226
ecx.probe_and_evaluate_goal_for_constituent_tys(
212227
CandidateSource::BuiltinImpl(BuiltinImplSource::Misc),
213228
goal,

compiler/rustc_trait_selection/src/solve/fulfill.rs

+87-13
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
use std::marker::PhantomData;
22
use std::mem;
3+
use std::ops::ControlFlow;
34

45
use rustc_data_structures::thinvec::ExtractIf;
6+
use rustc_hir::def_id::LocalDefId;
57
use rustc_infer::infer::InferCtxt;
68
use rustc_infer::traits::query::NoSolution;
79
use rustc_infer::traits::{
810
FromSolverError, PredicateObligation, PredicateObligations, TraitEngine,
911
};
12+
use rustc_middle::ty::{
13+
self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable, TypeVisitor, TypingMode,
14+
};
1015
use rustc_next_trait_solver::solve::{GenerateProofTree, HasChanged, SolverDelegateEvalExt as _};
16+
use rustc_span::Span;
1117
use tracing::instrument;
1218

1319
use self::derive_errors::*;
1420
use super::Certainty;
1521
use super::delegate::SolverDelegate;
22+
use super::inspect::{self, ProofTreeInferCtxtExt};
1623
use crate::traits::{FulfillmentError, ScrubbedTraitError};
1724

1825
mod derive_errors;
@@ -39,7 +46,7 @@ pub struct FulfillmentCtxt<'tcx, E: 'tcx> {
3946
_errors: PhantomData<E>,
4047
}
4148

42-
#[derive(Default)]
49+
#[derive(Default, Debug)]
4350
struct ObligationStorage<'tcx> {
4451
/// Obligations which resulted in an overflow in fulfillment itself.
4552
///
@@ -55,20 +62,23 @@ impl<'tcx> ObligationStorage<'tcx> {
5562
self.pending.push(obligation);
5663
}
5764

65+
fn has_pending_obligations(&self) -> bool {
66+
!self.pending.is_empty() || !self.overflowed.is_empty()
67+
}
68+
5869
fn clone_pending(&self) -> PredicateObligations<'tcx> {
5970
let mut obligations = self.pending.clone();
6071
obligations.extend(self.overflowed.iter().cloned());
6172
obligations
6273
}
6374

64-
fn take_pending(&mut self) -> PredicateObligations<'tcx> {
65-
let mut obligations = mem::take(&mut self.pending);
66-
obligations.append(&mut self.overflowed);
67-
obligations
68-
}
69-
70-
fn unstalled_for_select(&mut self) -> impl Iterator<Item = PredicateObligation<'tcx>> + 'tcx {
71-
mem::take(&mut self.pending).into_iter()
75+
fn drain_pending(
76+
&mut self,
77+
cond: impl Fn(&PredicateObligation<'tcx>) -> bool,
78+
) -> PredicateObligations<'tcx> {
79+
let (unstalled, pending) = mem::take(&mut self.pending).into_iter().partition(cond);
80+
self.pending = pending;
81+
unstalled
7282
}
7383

7484
fn on_fulfillment_overflow(&mut self, infcx: &InferCtxt<'tcx>) {
@@ -160,7 +170,7 @@ where
160170
}
161171

162172
let mut has_changed = false;
163-
for obligation in self.obligations.unstalled_for_select() {
173+
for obligation in self.obligations.drain_pending(|_| true) {
164174
let goal = obligation.as_goal();
165175
let result = <&SolverDelegate<'tcx>>::from(infcx)
166176
.evaluate_root_goal(goal, GenerateProofTree::No, obligation.cause.span)
@@ -196,15 +206,79 @@ where
196206
}
197207

198208
fn has_pending_obligations(&self) -> bool {
199-
!self.obligations.pending.is_empty() || !self.obligations.overflowed.is_empty()
209+
self.obligations.has_pending_obligations()
200210
}
201211

202212
fn pending_obligations(&self) -> PredicateObligations<'tcx> {
203213
self.obligations.clone_pending()
204214
}
205215

206-
fn drain_unstalled_obligations(&mut self, _: &InferCtxt<'tcx>) -> PredicateObligations<'tcx> {
207-
self.obligations.take_pending()
216+
fn drain_stalled_obligations_for_coroutines(
217+
&mut self,
218+
infcx: &InferCtxt<'tcx>,
219+
) -> PredicateObligations<'tcx> {
220+
self.obligations.drain_pending(|obl| {
221+
let stalled_generators = match infcx.typing_mode() {
222+
TypingMode::Analysis { defining_opaque_types: _, stalled_generators } => {
223+
stalled_generators
224+
}
225+
TypingMode::Coherence
226+
| TypingMode::Borrowck { defining_opaque_types: _ }
227+
| TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ }
228+
| TypingMode::PostAnalysis => return false,
229+
};
230+
231+
if stalled_generators.is_empty() {
232+
return false;
233+
}
234+
235+
infcx.probe(|_| {
236+
infcx
237+
.visit_proof_tree(
238+
obl.as_goal(),
239+
&mut StalledOnCoroutines { stalled_generators, span: obl.cause.span },
240+
)
241+
.is_break()
242+
})
243+
})
244+
}
245+
}
246+
247+
struct StalledOnCoroutines<'tcx> {
248+
stalled_generators: &'tcx ty::List<LocalDefId>,
249+
span: Span,
250+
// TODO: Cache
251+
}
252+
253+
impl<'tcx> inspect::ProofTreeVisitor<'tcx> for StalledOnCoroutines<'tcx> {
254+
type Result = ControlFlow<()>;
255+
256+
fn span(&self) -> rustc_span::Span {
257+
self.span
258+
}
259+
260+
fn visit_goal(&mut self, inspect_goal: &super::inspect::InspectGoal<'_, 'tcx>) -> Self::Result {
261+
inspect_goal.goal().predicate.visit_with(self)?;
262+
263+
if let Some(candidate) = inspect_goal.unique_applicable_candidate() {
264+
candidate.visit_nested_no_probe(self)
265+
} else {
266+
ControlFlow::Continue(())
267+
}
268+
}
269+
}
270+
271+
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for StalledOnCoroutines<'tcx> {
272+
type Result = ControlFlow<()>;
273+
274+
fn visit_ty(&mut self, ty: Ty<'tcx>) -> Self::Result {
275+
if let ty::CoroutineWitness(def_id, _) = *ty.kind()
276+
&& def_id.as_local().is_some_and(|def_id| self.stalled_generators.contains(&def_id))
277+
{
278+
return ControlFlow::Break(());
279+
}
280+
281+
ty.super_visit_with(self)
208282
}
209283
}
210284

compiler/rustc_trait_selection/src/solve/fulfill/derive_errors.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,16 @@ pub(super) fn fulfillment_error_for_stalled<'tcx>(
109109
false,
110110
),
111111
Ok((_, Certainty::Yes)) => {
112-
bug!("did not expect successful goal when collecting ambiguity errors")
112+
bug!(
113+
"did not expect successful goal when collecting ambiguity errors for `{:?}`",
114+
infcx.resolve_vars_if_possible(root_obligation.predicate),
115+
)
113116
}
114117
Err(_) => {
115-
bug!("did not expect selection error when collecting ambiguity errors")
118+
bug!(
119+
"did not expect selection error when collecting ambiguity errors for `{:?}`",
120+
infcx.resolve_vars_if_possible(root_obligation.predicate),
121+
)
116122
}
117123
}
118124
});

0 commit comments

Comments
 (0)