Skip to content

Suggest desugaring to return-position impl Future when an async fn in trait fails an auto trait bound #115864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1308,7 +1308,7 @@ impl<'hir> LoweringContext<'_, 'hir> {

fn lower_asyncness(&mut self, a: Async) -> hir::IsAsync {
match a {
Async::Yes { .. } => hir::IsAsync::Async,
Async::Yes { span, .. } => hir::IsAsync::Async(span),
Async::No => hir::IsAsync::NotAsync,
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_borrowck/src/diagnostics/region_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
if free_region.bound_region.is_named() {
// A named region that is actually named.
Some(RegionName { name, source: RegionNameSource::NamedFreeRegion(span) })
} else if let hir::IsAsync::Async = tcx.asyncness(self.mir_hir_id().owner) {
} else if tcx.asyncness(self.mir_hir_id().owner).is_async() {
// If we spuriously thought that the region is named, we should let the
// system generate a true name for error messages. Currently this can
// happen if we have an elided name in an async fn for example: the
Expand Down
18 changes: 9 additions & 9 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2853,13 +2853,13 @@ impl ImplicitSelfKind {
#[derive(Copy, Clone, PartialEq, Eq, Encodable, Decodable, Debug)]
#[derive(HashStable_Generic)]
pub enum IsAsync {
Async,
Async(Span),
NotAsync,
}

impl IsAsync {
pub fn is_async(self) -> bool {
self == IsAsync::Async
matches!(self, IsAsync::Async(_))
}
}

Expand Down Expand Up @@ -3296,7 +3296,7 @@ pub struct FnHeader {

impl FnHeader {
pub fn is_async(&self) -> bool {
matches!(&self.asyncness, IsAsync::Async)
matches!(&self.asyncness, IsAsync::Async(_))
}

pub fn is_const(&self) -> bool {
Expand Down Expand Up @@ -4091,10 +4091,10 @@ mod size_asserts {
static_assert_size!(GenericBound<'_>, 48);
static_assert_size!(Generics<'_>, 56);
static_assert_size!(Impl<'_>, 80);
static_assert_size!(ImplItem<'_>, 80);
static_assert_size!(ImplItemKind<'_>, 32);
static_assert_size!(Item<'_>, 80);
static_assert_size!(ItemKind<'_>, 48);
static_assert_size!(ImplItem<'_>, 88);
static_assert_size!(ImplItemKind<'_>, 40);
static_assert_size!(Item<'_>, 88);
static_assert_size!(ItemKind<'_>, 56);
static_assert_size!(Local<'_>, 64);
static_assert_size!(Param<'_>, 32);
static_assert_size!(Pat<'_>, 72);
Expand All @@ -4105,8 +4105,8 @@ mod size_asserts {
static_assert_size!(Res, 12);
static_assert_size!(Stmt<'_>, 32);
static_assert_size!(StmtKind<'_>, 16);
static_assert_size!(TraitItem<'_>, 80);
static_assert_size!(TraitItemKind<'_>, 40);
static_assert_size!(TraitItem<'_>, 88);
static_assert_size!(TraitItemKind<'_>, 48);
static_assert_size!(Ty<'_>, 48);
static_assert_size!(TyKind<'_>, 32);
// tidy-alphabetical-end
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ fn compare_asyncness<'tcx>(
trait_m: ty::AssocItem,
delay: bool,
) -> Result<(), ErrorGuaranteed> {
if tcx.asyncness(trait_m.def_id) == hir::IsAsync::Async {
if tcx.asyncness(trait_m.def_id).is_async() {
match tcx.fn_sig(impl_m.def_id).skip_binder().skip_binder().output().kind() {
ty::Alias(ty::Opaque, ..) => {
// allow both `async fn foo()` and `fn foo() -> impl Future`
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir_analysis/src/check/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fn check_main_fn_ty(tcx: TyCtxt<'_>, main_def_id: DefId) {
}

let main_asyncness = tcx.asyncness(main_def_id);
if let hir::IsAsync::Async = main_asyncness {
if main_asyncness.is_async() {
let asyncness_span = main_fn_asyncness_span(tcx, main_def_id);
tcx.sess.emit_err(errors::MainFunctionAsync { span: main_span, asyncness: asyncness_span });
error = true;
Expand Down Expand Up @@ -212,7 +212,7 @@ fn check_start_fn_ty(tcx: TyCtxt<'_>, start_def_id: DefId) {
});
error = true;
}
if let hir::IsAsync::Async = sig.header.asyncness {
if sig.header.asyncness.is_async() {
let span = tcx.def_span(it.owner_id);
tcx.sess.emit_err(errors::StartAsync { span: span });
error = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
&& let Some(generics) = self.tcx.hir().get_generics(self.tcx.local_parent(param_id))
&& let Some(param) = generics.params.iter().find(|p| p.def_id == param_id)
&& param.is_elided_lifetime()
&& let hir::IsAsync::NotAsync = self.tcx.asyncness(lifetime_ref.hir_id.owner.def_id)
&& !self.tcx.asyncness(lifetime_ref.hir_id.owner.def_id).is_async()
&& !self.tcx.features().anonymous_lifetime_in_impl_trait
{
let mut diag = rustc_session::parse::feature_err(
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_pretty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2304,7 +2304,7 @@ impl<'a> State<'a> {

match header.asyncness {
hir::IsAsync::NotAsync => {}
hir::IsAsync::Async => self.word_nbsp("async"),
hir::IsAsync::Async(_) => self.word_nbsp("async"),
}

self.print_unsafety(header.unsafety);
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -987,10 +987,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let bound_vars = self.tcx.late_bound_vars(fn_id);
let ty = self.tcx.erase_late_bound_regions(Binder::bind_with_vars(ty, bound_vars));
let ty = match self.tcx.asyncness(fn_id.owner) {
hir::IsAsync::Async => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
ty::Asyncness::Yes => self.get_impl_future_output_ty(ty).unwrap_or_else(|| {
span_bug!(fn_decl.output.span(), "failed to get output type of async function")
}),
hir::IsAsync::NotAsync => ty,
ty::Asyncness::No => ty,
};
let ty = self.normalize(expr.span, ty);
if self.can_coerce(found, ty) {
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_lint/src/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ use crate::{
},
EarlyContext, EarlyLintPass, LateContext, LateLintPass, Level, LintContext,
};
use hir::IsAsync;
use rustc_ast::attr;
use rustc_ast::tokenstream::{TokenStream, TokenTree};
use rustc_ast::visit::{FnCtxt, FnKind};
Expand Down Expand Up @@ -1294,7 +1293,7 @@ impl<'tcx> LateLintPass<'tcx> for UngatedAsyncFnTrackCaller {
span: Span,
def_id: LocalDefId,
) {
if fn_kind.asyncness() == IsAsync::Async
if fn_kind.asyncness().is_async()
&& !cx.tcx.features().async_fn_track_caller
// Now, check if the function has the `#[track_caller]` attribute
&& let Some(attr) = cx.tcx.get_attr(def_id, sym::track_caller)
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_metadata/src/rmeta/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ define_tables! {
coerce_unsized_info: Table<DefIndex, LazyValue<ty::adjustment::CoerceUnsizedInfo>>,
mir_const_qualif: Table<DefIndex, LazyValue<mir::ConstQualifs>>,
rendered_const: Table<DefIndex, LazyValue<String>>,
asyncness: Table<DefIndex, hir::IsAsync>,
asyncness: Table<DefIndex, ty::Asyncness>,
fn_arg_names: Table<DefIndex, LazyArray<Ident>>,
generator_kind: Table<DefIndex, LazyValue<hir::GeneratorKind>>,
trait_def: Table<DefIndex, LazyValue<ty::TraitDef>>,
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_metadata/src/rmeta/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ fixed_size_enum! {
}

fixed_size_enum! {
hir::IsAsync {
( NotAsync )
( Async )
ty::Asyncness {
( Yes )
( No )
}
}

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/query/erase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ trivial! {
rustc_middle::ty::adjustment::CoerceUnsizedInfo,
rustc_middle::ty::AssocItem,
rustc_middle::ty::AssocItemContainer,
rustc_middle::ty::Asyncness,
rustc_middle::ty::BoundVariableKind,
rustc_middle::ty::DeducedParamAttrs,
rustc_middle::ty::Destructor,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ rustc_queries! {
separate_provide_extern
}

query asyncness(key: DefId) -> hir::IsAsync {
query asyncness(key: DefId) -> ty::Asyncness {
desc { |tcx| "checking if the function is async: `{}`", tcx.def_path_str(key) }
separate_provide_extern
}
Expand Down
13 changes: 13 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,19 @@ impl fmt::Display for ImplPolarity {
}
}

#[derive(Copy, Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable, Debug)]
#[derive(TypeFoldable, TypeVisitable)]
pub enum Asyncness {
Yes,
No,
}

impl Asyncness {
pub fn is_async(self) -> bool {
matches!(self, Asyncness::Yes)
}
}

#[derive(Clone, Debug, PartialEq, Eq, Copy, Hash, Encodable, Decodable, HashStable)]
pub enum Visibility<Id = LocalDefId> {
/// Visible everywhere (including in other crates).
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/ty/parameterized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ trivially_parameterized_over_tcx! {
crate::middle::resolve_bound_vars::ObjectLifetimeDefault,
crate::mir::ConstQualifs,
ty::AssocItemContainer,
ty::Asyncness,
ty::DeducedParamAttrs,
ty::Generics,
ty::ImplPolarity,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
}

self.explain_hrtb_projection(&mut err, trait_predicate, obligation.param_env, &obligation.cause);
self.suggest_desugaring_async_fn_in_trait(&mut err, trait_ref);

// Return early if the trait is Debug or Display and the invocation
// originates within a standard library macro, because the output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
hir::Node::Item(hir::Item { kind: hir::ItemKind::Fn(sig, _, body_id), .. }) => {
self.describe_generator(*body_id).or_else(|| {
Some(match sig.header {
hir::FnHeader { asyncness: hir::IsAsync::Async, .. } => "an async function",
hir::FnHeader { asyncness: hir::IsAsync::Async(_), .. } => {
"an async function"
}
_ => "a function",
})
})
Expand All @@ -118,7 +120,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
..
}) => self.describe_generator(*body_id).or_else(|| {
Some(match sig.header {
hir::FnHeader { asyncness: hir::IsAsync::Async, .. } => "an async method",
hir::FnHeader { asyncness: hir::IsAsync::Async(_), .. } => "an async method",
_ => "a method",
})
}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,12 @@ pub trait TypeErrCtxtExt<'tcx> {
param_env: ty::ParamEnv<'tcx>,
cause: &ObligationCause<'tcx>,
);

fn suggest_desugaring_async_fn_in_trait(
&self,
err: &mut Diagnostic,
trait_ref: ty::PolyTraitRef<'tcx>,
);
}

fn predicate_constraint(generics: &hir::Generics<'_>, pred: ty::Predicate<'_>) -> (Span, String) {
Expand Down Expand Up @@ -4100,6 +4106,136 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
});
}
}

fn suggest_desugaring_async_fn_in_trait(
&self,
err: &mut Diagnostic,
trait_ref: ty::PolyTraitRef<'tcx>,
) {
// Don't suggest if RTN is active -- we should prefer a where-clause bound instead.
if self.tcx.features().return_type_notation {
return;
}

let trait_def_id = trait_ref.def_id();

// Only suggest specifying auto traits
if !self.tcx.trait_is_auto(trait_def_id) {
return;
}

// Look for an RPITIT
let ty::Alias(ty::Projection, alias_ty) = trait_ref.self_ty().skip_binder().kind() else {
return;
};
let Some(ty::ImplTraitInTraitData::Trait { fn_def_id, opaque_def_id }) =
self.tcx.opt_rpitit_info(alias_ty.def_id)
else {
return;
};

let auto_trait = self.tcx.def_path_str(trait_def_id);
// ... which is a local function
let Some(fn_def_id) = fn_def_id.as_local() else {
// If it's not local, we can at least mention that the method is async, if it is.
if self.tcx.asyncness(fn_def_id).is_async() {
err.span_note(
self.tcx.def_span(fn_def_id),
format!(
"`{}::{}` is an `async fn` in trait, which does not \
automatically imply that its future is `{auto_trait}`",
alias_ty.trait_ref(self.tcx),
self.tcx.item_name(fn_def_id)
),
);
}
return;
};
let Some(hir::Node::TraitItem(item)) = self.tcx.hir().find_by_def_id(fn_def_id) else {
return;
};

// ... whose signature is `async` (i.e. this is an AFIT)
let (sig, body) = item.expect_fn();
let hir::IsAsync::Async(async_span) = sig.header.asyncness else {
return;
};
let Ok(async_span) =
self.tcx.sess.source_map().span_extend_while(async_span, |c| c.is_whitespace())
else {
return;
};
let hir::FnRetTy::Return(hir::Ty { kind: hir::TyKind::OpaqueDef(def, ..), .. }) =
sig.decl.output
else {
// This should never happen, but let's not ICE.
return;
};

// Check that this is *not* a nested `impl Future` RPIT in an async fn
// (i.e. `async fn foo() -> impl Future`)
if def.owner_id.to_def_id() != opaque_def_id {
return;
}

let future = self.tcx.hir().item(*def).expect_opaque_ty();
let Some(hir::GenericBound::LangItemTrait(_, _, _, generics)) = future.bounds.get(0) else {
// `async fn` should always lower to a lang item bound... but don't ICE.
return;
};
let Some(hir::TypeBindingKind::Equality { term: hir::Term::Ty(future_output_ty) }) =
generics.bindings.get(0).map(|binding| binding.kind)
else {
// Also should never happen.
return;
};

let function_name = self.tcx.def_path_str(fn_def_id);

let mut sugg = if future_output_ty.span.is_empty() {
vec![
(async_span, String::new()),
(
future_output_ty.span,
format!(" -> impl std::future::Future<Output = ()> + {auto_trait}"),
),
]
} else {
vec![
(
future_output_ty.span.shrink_to_lo(),
"impl std::future::Future<Output = ".to_owned(),
),
(future_output_ty.span.shrink_to_hi(), format!("> + {auto_trait}")),
(async_span, String::new()),
]
};

// If there's a body, we also need to wrap it in `async {}`
if let hir::TraitFn::Provided(body) = body {
let body = self.tcx.hir().body(*body);
let body_span = body.value.span;
let body_span_without_braces =
body_span.with_lo(body_span.lo() + BytePos(1)).with_hi(body_span.hi() - BytePos(1));
if body_span_without_braces.is_empty() {
sugg.push((body_span_without_braces, " async {} ".to_owned()));
} else {
sugg.extend([
(body_span_without_braces.shrink_to_lo(), "async {".to_owned()),
(body_span_without_braces.shrink_to_hi(), "} ".to_owned()),
]);
}
}

err.multipart_suggestion(
format!(
"`{auto_trait}` can be made part of the associated future's \
guarantees for all implementations of `{function_name}`"
),
sugg,
Applicability::MachineApplicable,
);
}
}

/// Add a hint to add a missing borrow or remove an unnecessary one.
Expand Down
Loading