Skip to content

Commit bb4d374

Browse files
committed
Auto merge of #115864 - compiler-errors:rpitit-sugg, r=<try>
Suggest desugaring to return-position `impl Future` when an `async fn` in trait fails an auto trait bound First commit allows us to store the span of the `async` keyword in HIR. Second commit implements a suggestion to desugar an `async fn` to a return-position `impl Future` in trait to slightly improve the `Send` situation being discussed in #115822. This suggestion is only made when `#![feature(return_type_notation)]` is not enabled -- if it is, we should instead suggest an appropriate where-clause bound.
2 parents e7f9f48 + fae8c23 commit bb4d374

File tree

27 files changed

+276
-38
lines changed

27 files changed

+276
-38
lines changed

compiler/rustc_ast_lowering/src/item.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1308,7 +1308,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
13081308

13091309
fn lower_asyncness(&mut self, a: Async) -> hir::IsAsync {
13101310
match a {
1311-
Async::Yes { .. } => hir::IsAsync::Async,
1311+
Async::Yes { span, .. } => hir::IsAsync::Async(span),
13121312
Async::No => hir::IsAsync::NotAsync,
13131313
}
13141314
}

compiler/rustc_borrowck/src/diagnostics/region_name.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
302302
if free_region.bound_region.is_named() {
303303
// A named region that is actually named.
304304
Some(RegionName { name, source: RegionNameSource::NamedFreeRegion(span) })
305-
} else if let hir::IsAsync::Async = tcx.asyncness(self.mir_hir_id().owner) {
305+
} else if tcx.asyncness(self.mir_hir_id().owner).is_async() {
306306
// If we spuriously thought that the region is named, we should let the
307307
// system generate a true name for error messages. Currently this can
308308
// happen if we have an elided name in an async fn for example: the

compiler/rustc_hir/src/hir.rs

+9-9
Original file line numberDiff line numberDiff line change
@@ -2853,13 +2853,13 @@ impl ImplicitSelfKind {
28532853
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
28542854
#[derive(HashStable_Generic)]
28552855
pub enum IsAsync {
2856-
Async,
2856+
Async(Span),
28572857
NotAsync,
28582858
}
28592859

28602860
impl IsAsync {
28612861
pub fn is_async(self) -> bool {
2862-
self == IsAsync::Async
2862+
matches!(self, IsAsync::Async(_))
28632863
}
28642864
}
28652865

@@ -3296,7 +3296,7 @@ pub struct FnHeader {
32963296

32973297
impl FnHeader {
32983298
pub fn is_async(&self) -> bool {
3299-
matches!(&self.asyncness, IsAsync::Async)
3299+
matches!(&self.asyncness, IsAsync::Async(_))
33003300
}
33013301

33023302
pub fn is_const(&self) -> bool {
@@ -4091,10 +4091,10 @@ mod size_asserts {
40914091
static_assert_size!(GenericBound<'_>, 48);
40924092
static_assert_size!(Generics<'_>, 56);
40934093
static_assert_size!(Impl<'_>, 80);
4094-
static_assert_size!(ImplItem<'_>, 80);
4095-
static_assert_size!(ImplItemKind<'_>, 32);
4096-
static_assert_size!(Item<'_>, 80);
4097-
static_assert_size!(ItemKind<'_>, 48);
4094+
static_assert_size!(ImplItem<'_>, 88);
4095+
static_assert_size!(ImplItemKind<'_>, 40);
4096+
static_assert_size!(Item<'_>, 88);
4097+
static_assert_size!(ItemKind<'_>, 56);
40984098
static_assert_size!(Local<'_>, 64);
40994099
static_assert_size!(Param<'_>, 32);
41004100
static_assert_size!(Pat<'_>, 72);
@@ -4105,8 +4105,8 @@ mod size_asserts {
41054105
static_assert_size!(Res, 12);
41064106
static_assert_size!(Stmt<'_>, 32);
41074107
static_assert_size!(StmtKind<'_>, 16);
4108-
static_assert_size!(TraitItem<'_>, 80);
4109-
static_assert_size!(TraitItemKind<'_>, 40);
4108+
static_assert_size!(TraitItem<'_>, 88);
4109+
static_assert_size!(TraitItemKind<'_>, 48);
41104110
static_assert_size!(Ty<'_>, 48);
41114111
static_assert_size!(TyKind<'_>, 32);
41124112
// tidy-alphabetical-end

compiler/rustc_hir_analysis/src/check/compare_impl_item.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ fn compare_asyncness<'tcx>(
595595
trait_m: ty::AssocItem,
596596
delay: bool,
597597
) -> Result<(), ErrorGuaranteed> {
598-
if tcx.asyncness(trait_m.def_id) == hir::IsAsync::Async {
598+
if tcx.asyncness(trait_m.def_id).is_async() {
599599
match tcx.fn_sig(impl_m.def_id).skip_binder().skip_binder().output().kind() {
600600
ty::Alias(ty::Opaque, ..) => {
601601
// allow both `async fn foo()` and `fn foo() -> impl Future`

compiler/rustc_hir_analysis/src/check/entry.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ fn check_main_fn_ty(tcx: TyCtxt<'_>, main_def_id: DefId) {
112112
}
113113

114114
let main_asyncness = tcx.asyncness(main_def_id);
115-
if let hir::IsAsync::Async = main_asyncness {
115+
if main_asyncness.is_async() {
116116
let asyncness_span = main_fn_asyncness_span(tcx, main_def_id);
117117
tcx.sess.emit_err(errors::MainFunctionAsync { span: main_span, asyncness: asyncness_span });
118118
error = true;
@@ -212,7 +212,7 @@ fn check_start_fn_ty(tcx: TyCtxt<'_>, start_def_id: DefId) {
212212
});
213213
error = true;
214214
}
215-
if let hir::IsAsync::Async = sig.header.asyncness {
215+
if sig.header.asyncness.is_async() {
216216
let span = tcx.def_span(it.owner_id);
217217
tcx.sess.emit_err(errors::StartAsync { span: span });
218218
error = true;

compiler/rustc_hir_analysis/src/collect/resolve_bound_vars.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
12061206
&& let Some(generics) = self.tcx.hir().get_generics(self.tcx.local_parent(param_id))
12071207
&& let Some(param) = generics.params.iter().find(|p| p.def_id == param_id)
12081208
&& param.is_elided_lifetime()
1209-
&& let hir::IsAsync::NotAsync = self.tcx.asyncness(lifetime_ref.hir_id.owner.def_id)
1209+
&& !self.tcx.asyncness(lifetime_ref.hir_id.owner.def_id).is_async()
12101210
&& !self.tcx.features().anonymous_lifetime_in_impl_trait
12111211
{
12121212
let mut diag = rustc_session::parse::feature_err(

compiler/rustc_hir_pretty/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2304,7 +2304,7 @@ impl<'a> State<'a> {
23042304

23052305
match header.asyncness {
23062306
hir::IsAsync::NotAsync => {}
2307-
hir::IsAsync::Async => self.word_nbsp("async"),
2307+
hir::IsAsync::Async(_) => self.word_nbsp("async"),
23082308
}
23092309

23102310
self.print_unsafety(header.unsafety);

compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -987,10 +987,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
987987
let bound_vars = self.tcx.late_bound_vars(fn_id);
988988
let ty = self.tcx.erase_late_bound_regions(Binder::bind_with_vars(ty, bound_vars));
989989
let ty = match self.tcx.asyncness(fn_id.owner) {
990-
hir::IsAsync::Async => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
990+
ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
991991
span_bug!(fn_decl.output.span(), "failed to get output type of async function")
992992
}),
993-
hir::IsAsync::NotAsync => ty,
993+
ty::Asyncness::No => ty,
994994
};
995995
let ty = self.normalize(expr.span, ty);
996996
if self.can_coerce(found, ty) {

compiler/rustc_lint/src/builtin.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ use crate::{
4141
},
4242
EarlyContext, EarlyLintPass, LateContext, LateLintPass, Level, LintContext,
4343
};
44-
use hir::IsAsync;
4544
use rustc_ast::attr;
4645
use rustc_ast::tokenstream::{TokenStream, TokenTree};
4746
use rustc_ast::visit::{FnCtxt, FnKind};
@@ -1294,7 +1293,7 @@ impl<'tcx> LateLintPass<'tcx> for UngatedAsyncFnTrackCaller {
12941293
span: Span,
12951294
def_id: LocalDefId,
12961295
) {
1297-
if fn_kind.asyncness() == IsAsync::Async
1296+
if fn_kind.asyncness().is_async()
12981297
&& !cx.tcx.features().async_fn_track_caller
12991298
// Now, check if the function has the `#[track_caller]` attribute
13001299
&& let Some(attr) = cx.tcx.get_attr(def_id, sym::track_caller)

compiler/rustc_metadata/src/rmeta/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ define_tables! {
439439
coerce_unsized_info: Table<DefIndex, LazyValue<ty::adjustment::CoerceUnsizedInfo>>,
440440
mir_const_qualif: Table<DefIndex, LazyValue<mir::ConstQualifs>>,
441441
rendered_const: Table<DefIndex, LazyValue<String>>,
442-
asyncness: Table<DefIndex, hir::IsAsync>,
442+
asyncness: Table<DefIndex, ty::Asyncness>,
443443
fn_arg_names: Table<DefIndex, LazyArray<Ident>>,
444444
generator_kind: Table<DefIndex, LazyValue<hir::GeneratorKind>>,
445445
trait_def: Table<DefIndex, LazyValue<ty::TraitDef>>,

compiler/rustc_metadata/src/rmeta/table.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,9 @@ fixed_size_enum! {
206206
}
207207

208208
fixed_size_enum! {
209-
hir::IsAsync {
210-
( NotAsync )
211-
( Async )
209+
ty::Asyncness {
210+
( Yes )
211+
( No )
212212
}
213213
}
214214

compiler/rustc_middle/src/query/erase.rs

+1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ trivial! {
270270
rustc_middle::ty::adjustment::CoerceUnsizedInfo,
271271
rustc_middle::ty::AssocItem,
272272
rustc_middle::ty::AssocItemContainer,
273+
rustc_middle::ty::Asyncness,
273274
rustc_middle::ty::BoundVariableKind,
274275
rustc_middle::ty::DeducedParamAttrs,
275276
rustc_middle::ty::Destructor,

compiler/rustc_middle/src/query/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ rustc_queries! {
731731
separate_provide_extern
732732
}
733733

734-
query asyncness(key: DefId) -> hir::IsAsync {
734+
query asyncness(key: DefId) -> ty::Asyncness {
735735
desc { |tcx| "checking if the function is async: `{}`", tcx.def_path_str(key) }
736736
separate_provide_extern
737737
}

compiler/rustc_middle/src/ty/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,19 @@ impl fmt::Display for ImplPolarity {
280280
}
281281
}
282282

283+
#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable, Debug)]
284+
#[derive(TypeFoldable, TypeVisitable)]
285+
pub enum Asyncness {
286+
Yes,
287+
No,
288+
}
289+
290+
impl Asyncness {
291+
pub fn is_async(self) -> bool {
292+
matches!(self, Asyncness::Yes)
293+
}
294+
}
295+
283296
#[derive(Clone, Debug, PartialEq, Eq, Copy, Hash, Encodable, Decodable, HashStable)]
284297
pub enum Visibility<Id = LocalDefId> {
285298
/// Visible everywhere (including in other crates).

compiler/rustc_middle/src/ty/parameterized.rs

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ trivially_parameterized_over_tcx! {
6262
crate::middle::resolve_bound_vars::ObjectLifetimeDefault,
6363
crate::mir::ConstQualifs,
6464
ty::AssocItemContainer,
65+
ty::Asyncness,
6566
ty::DeducedParamAttrs,
6667
ty::Generics,
6768
ty::ImplPolarity,

compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,8 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
986986
}
987987
}
988988

989+
self.suggest_desugaring_async_fn_in_trait(&mut err, trait_ref);
990+
989991
// Return early if the trait is Debug or Display and the invocation
990992
// originates within a standard library macro, because the output
991993
// is otherwise overwhelming and unhelpful (see #85844 for an

compiler/rustc_trait_selection/src/traits/error_reporting/on_unimplemented.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
103103
hir::Node::Item(hir::Item { kind: hir::ItemKind::Fn(sig, _, body_id), .. }) => {
104104
self.describe_generator(*body_id).or_else(|| {
105105
Some(match sig.header {
106-
hir::FnHeader { asyncness: hir::IsAsync::Async, .. } => "an async function",
106+
hir::FnHeader { asyncness: hir::IsAsync::Async(_), .. } => {
107+
"an async function"
108+
}
107109
_ => "a function",
108110
})
109111
})
@@ -117,7 +119,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
117119
..
118120
}) => self.describe_generator(*body_id).or_else(|| {
119121
Some(match sig.header {
120-
hir::FnHeader { asyncness: hir::IsAsync::Async, .. } => "an async method",
122+
hir::FnHeader { asyncness: hir::IsAsync::Async(_), .. } => "an async method",
121123
_ => "a method",
122124
})
123125
}),

compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs

+124
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,12 @@ pub trait TypeErrCtxtExt<'tcx> {
406406
candidate_impls: &[ImplCandidate<'tcx>],
407407
span: Span,
408408
);
409+
410+
fn suggest_desugaring_async_fn_in_trait(
411+
&self,
412+
err: &mut Diagnostic,
413+
trait_ref: ty::PolyTraitRef<'tcx>,
414+
);
409415
}
410416

411417
fn predicate_constraint(generics: &hir::Generics<'_>, pred: ty::Predicate<'_>) -> (Span, String) {
@@ -4027,6 +4033,124 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
40274033
}
40284034
}
40294035
}
4036+
4037+
fn suggest_desugaring_async_fn_in_trait(
4038+
&self,
4039+
err: &mut Diagnostic,
4040+
trait_ref: ty::PolyTraitRef<'tcx>,
4041+
) {
4042+
// Don't suggest if RTN is active -- we should prefer a where-clause bound instead.
4043+
if self.tcx.features().return_type_notation {
4044+
return;
4045+
}
4046+
4047+
let trait_def_id = trait_ref.def_id();
4048+
4049+
// Only suggest specifying auto traits
4050+
if !self.tcx.trait_is_auto(trait_def_id) {
4051+
return;
4052+
}
4053+
4054+
// Look for an RPITIT
4055+
let ty::Alias(ty::Projection, alias_ty) = trait_ref.self_ty().skip_binder().kind() else {
4056+
return;
4057+
};
4058+
let Some(ty::ImplTraitInTraitData::Trait { fn_def_id, opaque_def_id }) =
4059+
self.tcx.opt_rpitit_info(alias_ty.def_id)
4060+
else {
4061+
return;
4062+
};
4063+
4064+
// ... which is a local function
4065+
let Some(fn_def_id) = fn_def_id.as_local() else {
4066+
return;
4067+
};
4068+
let Some(hir::Node::TraitItem(item)) = self.tcx.hir().find_by_def_id(fn_def_id) else {
4069+
return;
4070+
};
4071+
4072+
// ... whose signature is `async` (i.e. this is an AFIT)
4073+
let (sig, body) = item.expect_fn();
4074+
let hir::IsAsync::Async(async_span) = sig.header.asyncness else {
4075+
return;
4076+
};
4077+
let Ok(async_span) =
4078+
self.tcx.sess.source_map().span_extend_while(async_span, |c| c.is_whitespace())
4079+
else {
4080+
return;
4081+
};
4082+
let hir::FnRetTy::Return(hir::Ty { kind: hir::TyKind::OpaqueDef(def, ..), .. }) =
4083+
sig.decl.output
4084+
else {
4085+
// This should never happen, but let's not ICE.
4086+
return;
4087+
};
4088+
4089+
// Check that this is *not* a nested `impl Future` RPIT in an async fn
4090+
// (i.e. `async fn foo() -> impl Future`)
4091+
if def.owner_id.to_def_id() != opaque_def_id {
4092+
return;
4093+
}
4094+
4095+
let future = self.tcx.hir().item(*def).expect_opaque_ty();
4096+
let Some(hir::GenericBound::LangItemTrait(_, _, _, generics)) = future.bounds.get(0) else {
4097+
// `async fn` should always lower to a lang item bound... but don't ICE.
4098+
return;
4099+
};
4100+
let Some(hir::TypeBindingKind::Equality { term: hir::Term::Ty(future_output_ty) }) =
4101+
generics.bindings.get(0).map(|binding| binding.kind)
4102+
else {
4103+
// Also should never happen.
4104+
return;
4105+
};
4106+
4107+
let function_name = self.tcx.def_path_str(fn_def_id);
4108+
let auto_trait = self.tcx.def_path_str(trait_def_id);
4109+
4110+
let mut sugg = if future_output_ty.span.is_empty() {
4111+
vec![
4112+
(async_span, String::new()),
4113+
(
4114+
future_output_ty.span,
4115+
format!(" -> impl std::future::Future<Output = ()> + {auto_trait}"),
4116+
),
4117+
]
4118+
} else {
4119+
vec![
4120+
(
4121+
future_output_ty.span.shrink_to_lo(),
4122+
"impl std::future::Future<Output = ".to_owned(),
4123+
),
4124+
(future_output_ty.span.shrink_to_hi(), format!("> + {auto_trait}")),
4125+
(async_span, String::new()),
4126+
]
4127+
};
4128+
4129+
// If there's a body, we also need to wrap it in `async {}`
4130+
if let hir::TraitFn::Provided(body) = body {
4131+
let body = self.tcx.hir().body(*body);
4132+
let body_span = body.value.span;
4133+
let body_span_without_braces =
4134+
body_span.with_lo(body_span.lo() + BytePos(1)).with_hi(body_span.hi() - BytePos(1));
4135+
if body_span_without_braces.is_empty() {
4136+
sugg.push((body_span_without_braces, " async {} ".to_owned()));
4137+
} else {
4138+
sugg.extend([
4139+
(body_span_without_braces.shrink_to_lo(), "async {".to_owned()),
4140+
(body_span_without_braces.shrink_to_hi(), "} ".to_owned()),
4141+
]);
4142+
}
4143+
}
4144+
4145+
err.multipart_suggestion(
4146+
format!(
4147+
"`{auto_trait}` can be made part of the associated future's \
4148+
guarantees for all implementations of `{function_name}`"
4149+
),
4150+
sugg,
4151+
Applicability::MachineApplicable,
4152+
);
4153+
}
40304154
}
40314155

40324156
/// Add a hint to add a missing borrow or remove an unnecessary one.

compiler/rustc_ty_utils/src/ty.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,12 @@ fn issue33140_self_ty(tcx: TyCtxt<'_>, def_id: DefId) -> Option<EarlyBinder<Ty<'
299299
}
300300

301301
/// Check if a function is async.
302-
fn asyncness(tcx: TyCtxt<'_>, def_id: LocalDefId) -> hir::IsAsync {
302+
fn asyncness(tcx: TyCtxt<'_>, def_id: LocalDefId) -> ty::Asyncness {
303303
let node = tcx.hir().get_by_def_id(def_id);
304-
node.fn_sig().map_or(hir::IsAsync::NotAsync, |sig| sig.header.asyncness)
304+
node.fn_sig().map_or(ty::Asyncness::No, |sig| match sig.header.asyncness {
305+
hir::IsAsync::Async(_) => ty::Asyncness::Yes,
306+
hir::IsAsync::NotAsync => ty::Asyncness::No,
307+
})
305308
}
306309

307310
fn unsizing_params_for_adt<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> BitSet<u32> {

0 commit comments

Comments
 (0)