Skip to content

Commit c1d65ea

Browse files
committed
Auto merge of rust-lang#96892 - oli-obk:🐌_obligation_cause_code_🐌, r=estebank
Clean up derived obligation creation r? `@estebank` working on fixing the perf regression from rust-lang#91030 (comment)
2 parents c1cfdd1 + 0cefa5f commit c1d65ea

File tree

11 files changed

+176
-219
lines changed

11 files changed

+176
-219
lines changed

compiler/rustc_infer/src/traits/mod.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,21 @@ impl<'tcx> PredicateObligation<'tcx> {
6969
}
7070
}
7171

72-
impl TraitObligation<'_> {
72+
impl<'tcx> TraitObligation<'tcx> {
7373
/// Returns `true` if the trait predicate is considered `const` in its ParamEnv.
7474
pub fn is_const(&self) -> bool {
7575
match (self.predicate.skip_binder().constness, self.param_env.constness()) {
7676
(ty::BoundConstness::ConstIfConst, hir::Constness::Const) => true,
7777
_ => false,
7878
}
7979
}
80+
81+
pub fn derived_cause(
82+
&self,
83+
variant: impl FnOnce(DerivedObligationCause<'tcx>) -> ObligationCauseCode<'tcx>,
84+
) -> ObligationCause<'tcx> {
85+
self.cause.clone().derived_cause(self.predicate, variant)
86+
}
8087
}
8188

8289
// `PredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.

compiler/rustc_middle/src/traits/mod.rs

+75-35
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ pub struct ObligationCause<'tcx> {
9797
/// information.
9898
pub body_id: hir::HirId,
9999

100-
/// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
101-
/// the time). `Some` otherwise.
102-
code: Option<Lrc<ObligationCauseCode<'tcx>>>,
100+
code: InternedObligationCauseCode<'tcx>,
103101
}
104102

105103
// This custom hash function speeds up hashing for `Obligation` deduplication
@@ -123,11 +121,7 @@ impl<'tcx> ObligationCause<'tcx> {
123121
body_id: hir::HirId,
124122
code: ObligationCauseCode<'tcx>,
125123
) -> ObligationCause<'tcx> {
126-
ObligationCause {
127-
span,
128-
body_id,
129-
code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) },
130-
}
124+
ObligationCause { span, body_id, code: code.into() }
131125
}
132126

133127
pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
@@ -136,15 +130,12 @@ impl<'tcx> ObligationCause<'tcx> {
136130

137131
#[inline(always)]
138132
pub fn dummy() -> ObligationCause<'tcx> {
139-
ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: None }
133+
ObligationCause::dummy_with_span(DUMMY_SP)
140134
}
141135

136+
#[inline(always)]
142137
pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
143-
ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: None }
144-
}
145-
146-
pub fn make_mut_code(&mut self) -> &mut ObligationCauseCode<'tcx> {
147-
Lrc::make_mut(self.code.get_or_insert_with(|| Lrc::new(MISC_OBLIGATION_CAUSE_CODE)))
138+
ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: Default::default() }
148139
}
149140

150141
pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
@@ -164,14 +155,37 @@ impl<'tcx> ObligationCause<'tcx> {
164155

165156
#[inline]
166157
pub fn code(&self) -> &ObligationCauseCode<'tcx> {
167-
self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
158+
&self.code
168159
}
169160

170-
pub fn clone_code(&self) -> Lrc<ObligationCauseCode<'tcx>> {
171-
match &self.code {
172-
Some(code) => code.clone(),
173-
None => Lrc::new(MISC_OBLIGATION_CAUSE_CODE),
174-
}
161+
pub fn map_code(
162+
&mut self,
163+
f: impl FnOnce(InternedObligationCauseCode<'tcx>) -> ObligationCauseCode<'tcx>,
164+
) {
165+
self.code = f(std::mem::take(&mut self.code)).into();
166+
}
167+
168+
pub fn derived_cause(
169+
mut self,
170+
parent_trait_pred: ty::PolyTraitPredicate<'tcx>,
171+
variant: impl FnOnce(DerivedObligationCause<'tcx>) -> ObligationCauseCode<'tcx>,
172+
) -> ObligationCause<'tcx> {
173+
/*!
174+
* Creates a cause for obligations that are derived from
175+
* `obligation` by a recursive search (e.g., for a builtin
176+
* bound, or eventually a `auto trait Foo`). If `obligation`
177+
* is itself a derived obligation, this is just a clone, but
178+
* otherwise we create a "derived obligation" cause so as to
179+
* keep track of the original root obligation for error
180+
* reporting.
181+
*/
182+
183+
// NOTE(flaper87): As of now, it keeps track of the whole error
184+
// chain. Ideally, we should have a way to configure this either
185+
// by using -Z verbose or just a CLI argument.
186+
self.code =
187+
variant(DerivedObligationCause { parent_trait_pred, parent_code: self.code }).into();
188+
self
175189
}
176190
}
177191

@@ -182,6 +196,30 @@ pub struct UnifyReceiverContext<'tcx> {
182196
pub substs: SubstsRef<'tcx>,
183197
}
184198

199+
#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift, Default)]
200+
pub struct InternedObligationCauseCode<'tcx> {
201+
/// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
202+
/// the time). `Some` otherwise.
203+
code: Option<Lrc<ObligationCauseCode<'tcx>>>,
204+
}
205+
206+
impl<'tcx> ObligationCauseCode<'tcx> {
207+
#[inline(always)]
208+
fn into(self) -> InternedObligationCauseCode<'tcx> {
209+
InternedObligationCauseCode {
210+
code: if let MISC_OBLIGATION_CAUSE_CODE = self { None } else { Some(Lrc::new(self)) },
211+
}
212+
}
213+
}
214+
215+
impl<'tcx> std::ops::Deref for InternedObligationCauseCode<'tcx> {
216+
type Target = ObligationCauseCode<'tcx>;
217+
218+
fn deref(&self) -> &Self::Target {
219+
self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
220+
}
221+
}
222+
185223
#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
186224
pub enum ObligationCauseCode<'tcx> {
187225
/// Not well classified or should be obvious from the span.
@@ -269,7 +307,7 @@ pub enum ObligationCauseCode<'tcx> {
269307
/// The node of the function call.
270308
call_hir_id: hir::HirId,
271309
/// The obligation introduced by this argument.
272-
parent_code: Lrc<ObligationCauseCode<'tcx>>,
310+
parent_code: InternedObligationCauseCode<'tcx>,
273311
},
274312

275313
/// Error derived when matching traits/impls; see ObligationCause for more details
@@ -404,25 +442,27 @@ pub struct ImplDerivedObligationCause<'tcx> {
404442
pub span: Span,
405443
}
406444

407-
impl ObligationCauseCode<'_> {
445+
impl<'tcx> ObligationCauseCode<'tcx> {
408446
// Return the base obligation, ignoring derived obligations.
409447
pub fn peel_derives(&self) -> &Self {
410448
let mut base_cause = self;
411-
loop {
412-
match base_cause {
413-
BuiltinDerivedObligation(DerivedObligationCause { parent_code, .. })
414-
| DerivedObligation(DerivedObligationCause { parent_code, .. })
415-
| FunctionArgumentObligation { parent_code, .. } => {
416-
base_cause = &parent_code;
417-
}
418-
ImplDerivedObligation(obligation_cause) => {
419-
base_cause = &*obligation_cause.derived.parent_code;
420-
}
421-
_ => break,
422-
}
449+
while let Some((parent_code, _)) = base_cause.parent() {
450+
base_cause = parent_code;
423451
}
424452
base_cause
425453
}
454+
455+
pub fn parent(&self) -> Option<(&Self, Option<ty::PolyTraitPredicate<'tcx>>)> {
456+
match self {
457+
FunctionArgumentObligation { parent_code, .. } => Some((parent_code, None)),
458+
BuiltinDerivedObligation(derived)
459+
| DerivedObligation(derived)
460+
| ImplDerivedObligation(box ImplDerivedObligationCause { derived, .. }) => {
461+
Some((&derived.parent_code, Some(derived.parent_trait_pred)))
462+
}
463+
_ => None,
464+
}
465+
}
426466
}
427467

428468
// `ObligationCauseCode` is used a lot. Make sure it doesn't unintentionally get bigger.
@@ -472,7 +512,7 @@ pub struct DerivedObligationCause<'tcx> {
472512
pub parent_trait_pred: ty::PolyTraitPredicate<'tcx>,
473513

474514
/// The parent trait had this cause.
475-
pub parent_code: Lrc<ObligationCauseCode<'tcx>>,
515+
pub parent_code: InternedObligationCauseCode<'tcx>,
476516
}
477517

478518
#[derive(Clone, Debug, TypeFoldable, Lift)]

compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs

+10-41
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ pub mod on_unimplemented;
22
pub mod suggestions;
33

44
use super::{
5-
DerivedObligationCause, EvaluationResult, FulfillmentContext, FulfillmentError,
6-
FulfillmentErrorCode, ImplDerivedObligationCause, MismatchedProjectionTypes, Obligation,
7-
ObligationCause, ObligationCauseCode, OnUnimplementedDirective, OnUnimplementedNote,
8-
OutputTypeParameterMismatch, Overflow, PredicateObligation, SelectionContext, SelectionError,
9-
TraitNotObjectSafe,
5+
EvaluationResult, FulfillmentContext, FulfillmentError, FulfillmentErrorCode,
6+
MismatchedProjectionTypes, Obligation, ObligationCause, ObligationCauseCode,
7+
OnUnimplementedDirective, OnUnimplementedNote, OutputTypeParameterMismatch, Overflow,
8+
PredicateObligation, SelectionContext, SelectionError, TraitNotObjectSafe,
109
};
1110

1211
use crate::infer::error_reporting::{TyCategory, TypeAnnotationNeeded as ErrorCode};
@@ -684,42 +683,12 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
684683
let mut code = obligation.cause.code();
685684
let mut trait_pred = trait_predicate;
686685
let mut peeled = false;
687-
loop {
688-
match &*code {
689-
ObligationCauseCode::FunctionArgumentObligation {
690-
parent_code,
691-
..
692-
} => {
693-
code = &parent_code;
694-
}
695-
ObligationCauseCode::ImplDerivedObligation(
696-
box ImplDerivedObligationCause {
697-
derived:
698-
DerivedObligationCause {
699-
parent_code,
700-
parent_trait_pred,
701-
},
702-
..
703-
},
704-
)
705-
| ObligationCauseCode::BuiltinDerivedObligation(
706-
DerivedObligationCause {
707-
parent_code,
708-
parent_trait_pred,
709-
},
710-
)
711-
| ObligationCauseCode::DerivedObligation(
712-
DerivedObligationCause {
713-
parent_code,
714-
parent_trait_pred,
715-
},
716-
) => {
717-
peeled = true;
718-
code = &parent_code;
719-
trait_pred = *parent_trait_pred;
720-
}
721-
_ => break,
722-
};
686+
while let Some((parent_code, parent_trait_pred)) = code.parent() {
687+
code = parent_code;
688+
if let Some(parent_trait_pred) = parent_trait_pred {
689+
trait_pred = parent_trait_pred;
690+
peeled = true;
691+
}
723692
}
724693
let def_id = trait_pred.def_id();
725694
// Mention *all* the `impl`s for the *top most* obligation, the

compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs

+12-30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::{
2-
DerivedObligationCause, EvaluationResult, ImplDerivedObligationCause, Obligation,
3-
ObligationCause, ObligationCauseCode, PredicateObligation, SelectionContext,
2+
EvaluationResult, Obligation, ObligationCause, ObligationCauseCode, PredicateObligation,
3+
SelectionContext,
44
};
55

66
use crate::autoderef::Autoderef;
@@ -623,28 +623,11 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
623623
let span = obligation.cause.span;
624624
let mut real_trait_pred = trait_pred;
625625
let mut code = obligation.cause.code();
626-
loop {
627-
match &code {
628-
ObligationCauseCode::FunctionArgumentObligation { parent_code, .. } => {
629-
code = &parent_code;
630-
}
631-
ObligationCauseCode::ImplDerivedObligation(box ImplDerivedObligationCause {
632-
derived: DerivedObligationCause { parent_code, parent_trait_pred },
633-
..
634-
})
635-
| ObligationCauseCode::BuiltinDerivedObligation(DerivedObligationCause {
636-
parent_code,
637-
parent_trait_pred,
638-
})
639-
| ObligationCauseCode::DerivedObligation(DerivedObligationCause {
640-
parent_code,
641-
parent_trait_pred,
642-
}) => {
643-
code = &parent_code;
644-
real_trait_pred = *parent_trait_pred;
645-
}
646-
_ => break,
647-
};
626+
while let Some((parent_code, parent_trait_pred)) = code.parent() {
627+
code = parent_code;
628+
if let Some(parent_trait_pred) = parent_trait_pred {
629+
real_trait_pred = parent_trait_pred;
630+
}
648631
let Some(real_ty) = real_trait_pred.self_ty().no_bound_vars() else {
649632
continue;
650633
};
@@ -1669,7 +1652,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
16691652
debug!("maybe_note_obligation_cause_for_async_await: code={:?}", code);
16701653
match code {
16711654
ObligationCauseCode::FunctionArgumentObligation { parent_code, .. } => {
1672-
next_code = Some(parent_code.as_ref());
1655+
next_code = Some(parent_code);
16731656
}
16741657
ObligationCauseCode::ImplDerivedObligation(cause) => {
16751658
let ty = cause.derived.parent_trait_pred.skip_binder().self_ty();
@@ -1700,7 +1683,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
17001683
_ => {}
17011684
}
17021685

1703-
next_code = Some(cause.derived.parent_code.as_ref());
1686+
next_code = Some(&cause.derived.parent_code);
17041687
}
17051688
ObligationCauseCode::DerivedObligation(derived_obligation)
17061689
| ObligationCauseCode::BuiltinDerivedObligation(derived_obligation) => {
@@ -1732,7 +1715,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
17321715
_ => {}
17331716
}
17341717

1735-
next_code = Some(derived_obligation.parent_code.as_ref());
1718+
next_code = Some(&derived_obligation.parent_code);
17361719
}
17371720
_ => break,
17381721
}
@@ -2382,8 +2365,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
23822365
let is_upvar_tys_infer_tuple = if !matches!(ty.kind(), ty::Tuple(..)) {
23832366
false
23842367
} else {
2385-
if let ObligationCauseCode::BuiltinDerivedObligation(ref data) =
2386-
*data.parent_code
2368+
if let ObligationCauseCode::BuiltinDerivedObligation(data) = &*data.parent_code
23872369
{
23882370
let parent_trait_ref =
23892371
self.resolve_vars_if_possible(data.parent_trait_pred);
@@ -2428,7 +2410,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
24282410
err,
24292411
&parent_predicate,
24302412
param_env,
2431-
&cause_code.peel_derives(),
2413+
cause_code.peel_derives(),
24322414
obligated_types,
24332415
seen_requirements,
24342416
)

0 commit comments

Comments
 (0)