Skip to content

Commit 7ceaf52

Browse files
Support specializtion for RPITITs
1 parent b42a811 commit 7ceaf52

File tree

13 files changed

+163
-114
lines changed

13 files changed

+163
-114
lines changed

compiler/rustc_hir_analysis/src/check/compare_impl_item.rs

+21-7
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub(super) fn compare_impl_method<'tcx>(
5050
compare_generic_param_kinds(tcx, impl_m, trait_m, false)?;
5151
compare_number_of_method_arguments(tcx, impl_m, trait_m)?;
5252
compare_synthetic_generics(tcx, impl_m, trait_m)?;
53-
compare_asyncness(tcx, impl_m, trait_m)?;
53+
compare_asyncness(tcx, impl_m, trait_m, false)?;
5454
compare_method_predicate_entailment(
5555
tcx,
5656
impl_m,
@@ -191,6 +191,11 @@ fn compare_method_predicate_entailment<'tcx>(
191191
.map(|(predicate, _)| predicate),
192192
);
193193

194+
// Additionally, we are allowed to assume that we can project RPITITs to their
195+
// associated hidden types within method signatures. This is to allow us to support
196+
// specialization with `impl Trait` in traits.
197+
hybrid_preds.predicates.extend(tcx.additional_method_assumptions(impl_m_def_id));
198+
194199
// Construct trait parameter environment and then shift it into the placeholder viewpoint.
195200
// The key step here is to update the caller_bounds's predicates to be
196201
// the new hybrid bounds we computed.
@@ -526,6 +531,7 @@ fn compare_asyncness<'tcx>(
526531
tcx: TyCtxt<'tcx>,
527532
impl_m: ty::AssocItem,
528533
trait_m: ty::AssocItem,
534+
delay: bool,
529535
) -> Result<(), ErrorGuaranteed> {
530536
if tcx.asyncness(trait_m.def_id) == hir::IsAsync::Async {
531537
match tcx.fn_sig(impl_m.def_id).skip_binder().skip_binder().output().kind() {
@@ -536,11 +542,14 @@ fn compare_asyncness<'tcx>(
536542
// We don't know if it's ok, but at least it's already an error.
537543
}
538544
_ => {
539-
return Err(tcx.sess.emit_err(crate::errors::AsyncTraitImplShouldBeAsync {
540-
span: tcx.def_span(impl_m.def_id),
541-
method_name: trait_m.name,
542-
trait_item_span: tcx.hir().span_if_local(trait_m.def_id),
543-
}));
545+
return Err(tcx
546+
.sess
547+
.create_err(crate::errors::AsyncTraitImplShouldBeAsync {
548+
span: tcx.def_span(impl_m.def_id),
549+
method_name: trait_m.name,
550+
trait_item_span: tcx.hir().span_if_local(trait_m.def_id),
551+
})
552+
.emit_unless(delay));
544553
}
545554
};
546555
}
@@ -590,10 +599,15 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
590599
let trait_m = tcx.opt_associated_item(impl_m.trait_item_def_id.unwrap()).unwrap();
591600
let impl_trait_ref =
592601
tcx.impl_trait_ref(impl_m.impl_container(tcx).unwrap()).unwrap().subst_identity();
593-
let param_env = tcx.param_env(def_id);
602+
603+
// We use the RPITIT values computed in this method to construct the param-env,
604+
// so to avoid cycles, we do computations in this function without assuming anything
605+
// about RPITIT projection.
606+
let param_env = tcx.param_env_no_assumptions(def_id);
594607

595608
// First, check a few of the same things as `compare_impl_method`,
596609
// just so we don't ICE during substitution later.
610+
compare_asyncness(tcx, impl_m, trait_m, true)?;
597611
compare_number_of_generics(tcx, impl_m, trait_m, true)?;
598612
compare_generic_param_kinds(tcx, impl_m, trait_m, true)?;
599613
check_region_bounds_on_impl_item(tcx, impl_m, trait_m, true)?;

compiler/rustc_metadata/src/rmeta/encoder.rs

+1-29
Original file line numberDiff line numberDiff line change
@@ -1101,34 +1101,6 @@ fn should_encode_const(def_kind: DefKind) -> bool {
11011101
}
11021102
}
11031103

1104-
fn should_encode_trait_impl_trait_tys(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
1105-
if tcx.def_kind(def_id) != DefKind::AssocFn {
1106-
return false;
1107-
}
1108-
1109-
let Some(item) = tcx.opt_associated_item(def_id) else { return false; };
1110-
if item.container != ty::AssocItemContainer::ImplContainer {
1111-
return false;
1112-
}
1113-
1114-
let Some(trait_item_def_id) = item.trait_item_def_id else { return false; };
1115-
1116-
// FIXME(RPITIT): This does a somewhat manual walk through the signature
1117-
// of the trait fn to look for any RPITITs, but that's kinda doing a lot
1118-
// of work. We can probably remove this when we refactor RPITITs to be
1119-
// associated types.
1120-
tcx.fn_sig(trait_item_def_id).subst_identity().skip_binder().output().walk().any(|arg| {
1121-
if let ty::GenericArgKind::Type(ty) = arg.unpack()
1122-
&& let ty::Alias(ty::Projection, data) = ty.kind()
1123-
&& tcx.def_kind(data.def_id) == DefKind::ImplTraitPlaceholder
1124-
{
1125-
true
1126-
} else {
1127-
false
1128-
}
1129-
})
1130-
}
1131-
11321104
// Return `false` to avoid encoding impl trait in trait, while we don't use the query.
11331105
fn should_encode_fn_impl_trait_in_trait<'tcx>(_tcx: TyCtxt<'tcx>, _def_id: DefId) -> bool {
11341106
false
@@ -1211,7 +1183,7 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
12111183
if let DefKind::Enum | DefKind::Struct | DefKind::Union = def_kind {
12121184
self.encode_info_for_adt(def_id);
12131185
}
1214-
if should_encode_trait_impl_trait_tys(tcx, def_id)
1186+
if tcx.impl_method_has_trait_impl_trait_tys(def_id)
12151187
&& let Ok(table) = self.tcx.collect_return_position_impl_trait_in_trait_tys(def_id)
12161188
{
12171189
record!(self.tables.trait_impl_trait_tys[def_id] <- table);

compiler/rustc_middle/src/query/mod.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,14 @@ rustc_queries! {
13131313
desc { |tcx| "computing normalized predicates of `{}`", tcx.def_path_str(def_id) }
13141314
}
13151315

1316+
query param_env_no_assumptions(def_id: DefId) -> ty::ParamEnv<'tcx> {
1317+
desc { |tcx| "computing normalized predicates of `{}`", tcx.def_path_str(def_id) }
1318+
}
1319+
1320+
query additional_method_assumptions(def_id: DefId) -> &'tcx ty::List<ty::Predicate<'tcx>> {
1321+
desc { |tcx| "computing additional predicate assumptions for the body of `{}`", tcx.def_path_str(def_id) }
1322+
}
1323+
13161324
/// Like `param_env`, but returns the `ParamEnv` in `Reveal::All` mode.
13171325
/// Prefer this over `tcx.param_env(def_id).with_reveal_all_normalized(tcx)`,
13181326
/// as this method is more efficient.

compiler/rustc_middle/src/ty/mod.rs

+28
Original file line numberDiff line numberDiff line change
@@ -2544,6 +2544,34 @@ impl<'tcx> TyCtxt<'tcx> {
25442544
}
25452545
def_id
25462546
}
2547+
2548+
pub fn impl_method_has_trait_impl_trait_tys(self, def_id: DefId) -> bool {
2549+
if self.def_kind(def_id) != DefKind::AssocFn {
2550+
return false;
2551+
}
2552+
2553+
let Some(item) = self.opt_associated_item(def_id) else { return false; };
2554+
if item.container != ty::AssocItemContainer::ImplContainer {
2555+
return false;
2556+
}
2557+
2558+
let Some(trait_item_def_id) = item.trait_item_def_id else { return false; };
2559+
2560+
// FIXME(RPITIT): This does a somewhat manual walk through the signature
2561+
// of the trait fn to look for any RPITITs, but that's kinda doing a lot
2562+
// of work. We can probably remove this when we refactor RPITITs to be
2563+
// associated types.
2564+
self.fn_sig(trait_item_def_id).subst_identity().skip_binder().output().walk().any(|arg| {
2565+
if let ty::GenericArgKind::Type(ty) = arg.unpack()
2566+
&& let ty::Alias(ty::Projection, data) = ty.kind()
2567+
&& self.def_kind(data.def_id) == DefKind::ImplTraitPlaceholder
2568+
{
2569+
true
2570+
} else {
2571+
false
2572+
}
2573+
})
2574+
}
25472575
}
25482576

25492577
/// Yields the parent function's `LocalDefId` if `def_id` is an `impl Trait` definition.

compiler/rustc_middle/src/ty/subst.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ impl<'tcx> InternalSubsts<'tcx> {
468468
target_substs: SubstsRef<'tcx>,
469469
) -> SubstsRef<'tcx> {
470470
let defs = tcx.generics_of(source_ancestor);
471-
tcx.mk_substs(target_substs.iter().chain(self.iter().skip(defs.params.len())))
471+
tcx.mk_substs(target_substs.iter().chain(self.iter().skip(defs.count())))
472472
}
473473

474474
pub fn truncate_to(&self, tcx: TyCtxt<'tcx>, generics: &ty::Generics) -> SubstsRef<'tcx> {

compiler/rustc_trait_selection/src/traits/project.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1246,13 +1246,13 @@ fn project<'cx, 'tcx>(
12461246

12471247
let mut candidates = ProjectionCandidateSet::None;
12481248

1249-
assemble_candidate_for_impl_trait_in_trait(selcx, obligation, &mut candidates);
1250-
12511249
// Make sure that the following procedures are kept in order. ParamEnv
12521250
// needs to be first because it has highest priority, and Select checks
12531251
// the return value of push_candidate which assumes it's ran at last.
12541252
assemble_candidates_from_param_env(selcx, obligation, &mut candidates);
12551253

1254+
assemble_candidate_for_impl_trait_in_trait(selcx, obligation, &mut candidates);
1255+
12561256
assemble_candidates_from_trait_def(selcx, obligation, &mut candidates);
12571257

12581258
assemble_candidates_from_object_ty(selcx, obligation, &mut candidates);

compiler/rustc_ty_utils/src/ty.rs

+73-16
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ fn adt_sized_constraint(tcx: TyCtxt<'_>, def_id: DefId) -> &[Ty<'_>] {
116116
}
117117

118118
/// See `ParamEnv` struct definition for details.
119-
fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
119+
fn param_env(tcx: TyCtxt<'_>, def_id: DefId, add_assumptions: bool) -> ty::ParamEnv<'_> {
120120
// Compute the bounds on Self and the type parameters.
121121
let ty::InstantiatedPredicates { mut predicates, .. } =
122122
tcx.predicates_of(def_id).instantiate_identity(tcx);
@@ -138,17 +138,8 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
138138
predicates.extend(environment);
139139
}
140140

141-
if tcx.def_kind(def_id) == DefKind::AssocFn
142-
&& tcx.associated_item(def_id).container == ty::AssocItemContainer::TraitContainer
143-
{
144-
let sig = tcx.fn_sig(def_id).subst_identity();
145-
sig.visit_with(&mut ImplTraitInTraitFinder {
146-
tcx,
147-
fn_def_id: def_id,
148-
bound_vars: sig.bound_vars(),
149-
predicates: &mut predicates,
150-
seen: FxHashSet::default(),
151-
});
141+
if add_assumptions && tcx.def_kind(def_id) == DefKind::AssocFn {
142+
predicates.extend(tcx.additional_method_assumptions(def_id))
152143
}
153144

154145
let local_did = def_id.as_local();
@@ -237,19 +228,83 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
237228
traits::normalize_param_env_or_error(tcx, unnormalized_env, cause)
238229
}
239230

231+
fn additional_method_assumptions<'tcx>(
232+
tcx: TyCtxt<'tcx>,
233+
def_id: DefId,
234+
) -> &'tcx ty::List<Predicate<'tcx>> {
235+
let assoc_item = tcx.associated_item(def_id);
236+
let mut predicates = vec![];
237+
238+
match assoc_item.container {
239+
ty::AssocItemContainer::TraitContainer => {
240+
let sig = tcx.fn_sig(def_id).subst_identity();
241+
sig.visit_with(&mut ImplTraitInTraitFinder {
242+
tcx,
243+
fn_def_id: def_id,
244+
bound_vars: sig.bound_vars(),
245+
predicates: &mut predicates,
246+
seen: FxHashSet::default(),
247+
hidden_ty: |alias_ty| tcx.mk_alias(ty::Opaque, alias_ty),
248+
});
249+
}
250+
ty::AssocItemContainer::ImplContainer => {
251+
if tcx.impl_method_has_trait_impl_trait_tys(def_id)
252+
&& let Ok(table)
253+
= tcx.collect_return_position_impl_trait_in_trait_tys(def_id)
254+
{
255+
let impl_def_id = assoc_item.container_id(tcx);
256+
let trait_to_impl_substs =
257+
tcx.impl_trait_ref(impl_def_id).unwrap().subst_identity().substs;
258+
// Create mapping from impl to placeholder.
259+
let impl_to_placeholder_substs = ty::InternalSubsts::identity_for_item(tcx, def_id);
260+
// Create mapping from trait to placeholder.
261+
let trait_to_placeholder_substs =
262+
impl_to_placeholder_substs.rebase_onto(tcx, impl_def_id, trait_to_impl_substs);
263+
264+
let trait_fn_def_id = assoc_item.trait_item_def_id.unwrap();
265+
let trait_fn_sig =
266+
tcx.fn_sig(trait_fn_def_id).subst(tcx, trait_to_placeholder_substs);
267+
trait_fn_sig.visit_with(&mut ImplTraitInTraitFinder {
268+
tcx,
269+
fn_def_id: trait_fn_def_id,
270+
bound_vars: trait_fn_sig.bound_vars(),
271+
predicates: &mut predicates,
272+
seen: FxHashSet::default(),
273+
hidden_ty: |alias_ty| {
274+
EarlyBinder(*table.get(&alias_ty.def_id).unwrap()).subst(
275+
tcx,
276+
alias_ty.substs.rebase_onto(
277+
tcx,
278+
trait_fn_def_id,
279+
impl_to_placeholder_substs,
280+
),
281+
)
282+
},
283+
});
284+
}
285+
}
286+
}
287+
288+
tcx.intern_predicates(&predicates)
289+
}
290+
240291
/// Walk through a function type, gathering all RPITITs and installing a
241292
/// `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))` predicate into the
242293
/// predicates list. This allows us to observe that an RPITIT projects to
243294
/// its corresponding opaque within the body of a default-body trait method.
244-
struct ImplTraitInTraitFinder<'a, 'tcx> {
295+
struct ImplTraitInTraitFinder<'a, 'tcx, F: Fn(ty::AliasTy<'tcx>) -> Ty<'tcx>> {
245296
tcx: TyCtxt<'tcx>,
246297
predicates: &'a mut Vec<Predicate<'tcx>>,
247298
fn_def_id: DefId,
248299
bound_vars: &'tcx ty::List<ty::BoundVariableKind>,
249300
seen: FxHashSet<DefId>,
301+
hidden_ty: F,
250302
}
251303

252-
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx> {
304+
impl<'tcx, F> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx, F>
305+
where
306+
F: Fn(ty::AliasTy<'tcx>) -> Ty<'tcx>,
307+
{
253308
fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
254309
if let ty::Alias(ty::Projection, alias_ty) = *ty.kind()
255310
&& self.tcx.def_kind(alias_ty.def_id) == DefKind::ImplTraitPlaceholder
@@ -260,7 +315,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx> {
260315
ty::Binder::bind_with_vars(
261316
ty::ProjectionPredicate {
262317
projection_ty: alias_ty,
263-
term: self.tcx.mk_alias(ty::Opaque, alias_ty).into(),
318+
term: (self.hidden_ty)(alias_ty).into(),
264319
},
265320
self.bound_vars,
266321
)
@@ -514,7 +569,9 @@ pub fn provide(providers: &mut ty::query::Providers) {
514569
*providers = ty::query::Providers {
515570
asyncness,
516571
adt_sized_constraint,
517-
param_env,
572+
param_env: |tcx, def_id| param_env(tcx, def_id, true),
573+
param_env_no_assumptions: |tcx, def_id| param_env(tcx, def_id, false),
574+
additional_method_assumptions,
518575
param_env_reveal_all_normalized,
519576
instance_def_size_estimate,
520577
issue33140_self_ty,

tests/ui/async-await/in-trait/dont-project-to-specializable-projection.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// edition: 2021
2-
// known-bug: #108309
2+
// check-pass
33

44
#![feature(async_fn_in_trait)]
5+
//~^ WARN the feature `async_fn_in_trait` is incomplete
56
#![feature(min_specialization)]
67

78
struct MyStruct;

tests/ui/async-await/in-trait/dont-project-to-specializable-projection.stderr

+1-16
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,5 @@ LL | #![feature(async_fn_in_trait)]
77
= note: see issue #91611 <https://github.com/rust-lang/rust/issues/91611> for more information
88
= note: `#[warn(incomplete_features)]` on by default
99

10-
error[E0053]: method `foo` has an incompatible type for trait
11-
--> $DIR/dont-project-to-specializable-projection.rs:14:35
12-
|
13-
LL | default async fn foo(_: T) -> &'static str {
14-
| ^^^^^^^^^^^^ expected associated type, found future
15-
|
16-
note: type in trait
17-
--> $DIR/dont-project-to-specializable-projection.rs:10:27
18-
|
19-
LL | async fn foo(_: T) -> &'static str;
20-
| ^^^^^^^^^^^^
21-
= note: expected signature `fn(_) -> impl Future<Output = &'static str>`
22-
found signature `fn(_) -> impl Future<Output = &'static str>`
23-
24-
error: aborting due to previous error; 1 warning emitted
10+
warning: 1 warning emitted
2511

26-
For more information about this error, try `rustc --explain E0053`.

tests/ui/impl-trait/in-trait/method-signature-matches.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ trait TooMuch {
2727

2828
impl TooMuch for () {
2929
fn calm_down_please(_: (), _: (), _: ()) {}
30-
//~^ ERROR method `calm_down_please` has 3 parameters but the declaration in trait `TooMuch::calm_down_please` has 0
30+
//~^ ERROR method `calm_down_please` has an incompatible type for trait
3131
}
3232

3333
trait TooLittle {
@@ -36,7 +36,7 @@ trait TooLittle {
3636

3737
impl TooLittle for () {
3838
fn come_on_a_little_more_effort() {}
39-
//~^ ERROR method `come_on_a_little_more_effort` has 0 parameters but the declaration in trait `TooLittle::come_on_a_little_more_effort` has 3
39+
//~^ ERROR method `come_on_a_little_more_effort` has an incompatible type for trait
4040
}
4141

4242
trait Lifetimes {

0 commit comments

Comments
 (0)