Skip to content

Commit 60c0ad8

Browse files
committed
Avoid GenFuture shim when compiling async constructs
Previously, async constructs would be lowered to "normal" generators, with an additional `from_generator` / `GenFuture` shim in between to convert from `Generator` to `Future`. The compiler will now special-case these generators internally so that async constructs will *directly* implement `Future` without the need to go through the `from_generator` / `GenFuture` shim. The primary motivation for this change was hiding this implementation detail in stack traces and debuginfo, but it can in theory also help the optimizer as there is less abstractions to see through.
1 parent 91385d5 commit 60c0ad8

File tree

14 files changed

+209
-110
lines changed

14 files changed

+209
-110
lines changed

compiler/rustc_ast_lowering/src/expr.rs

+32-33
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
136136
self.arena.alloc_from_iter(arms.iter().map(|x| self.lower_arm(x))),
137137
hir::MatchSource::Normal,
138138
),
139-
ExprKind::Async(capture_clause, closure_node_id, ref block) => self
140-
.make_async_expr(
139+
ExprKind::Async(capture_clause, closure_node_id, ref block) => {
140+
return self.make_async_expr(
141141
capture_clause,
142142
closure_node_id,
143143
None,
144144
block.span,
145145
hir::AsyncGeneratorKind::Block,
146146
|this| this.with_new_scopes(|this| this.lower_block_expr(block)),
147-
),
147+
);
148+
}
148149
ExprKind::Await(ref expr) => {
149150
let dot_await_span = if expr.span.hi() < e.span.hi() {
150151
let span_with_whitespace = self
@@ -563,14 +564,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
563564
}
564565
}
565566

566-
/// Lower an `async` construct to a generator that is then wrapped so it implements `Future`.
567+
/// Lower an `async` construct to a generator that implements `Future`.
567568
///
568569
/// This results in:
569570
///
570571
/// ```text
571-
/// std::future::from_generator(static move? |_task_context| -> <ret_ty> {
572+
/// static move? |_task_context| -> <ret_ty> {
572573
/// <body>
573-
/// })
574+
/// }
574575
/// ```
575576
pub(super) fn make_async_expr(
576577
&mut self,
@@ -580,19 +581,33 @@ impl<'hir> LoweringContext<'_, 'hir> {
580581
span: Span,
581582
async_gen_kind: hir::AsyncGeneratorKind,
582583
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
583-
) -> hir::ExprKind<'hir> {
584+
) -> hir::Expr<'hir> {
584585
let output = match ret_ty {
585586
Some(ty) => hir::FnRetTy::Return(
586587
self.lower_ty(&ty, &ImplTraitContext::Disallowed(ImplTraitPosition::AsyncBlock)),
587588
),
588589
None => hir::FnRetTy::DefaultReturn(self.lower_span(span)),
589590
};
590591

591-
// Resume argument type. We let the compiler infer this to simplify the lowering. It is
592-
// fully constrained by `future::from_generator`.
592+
// Resume argument type, which should be `&mut Context<'_>`
593+
let context_lifetime = self.arena.alloc(hir::Lifetime {
594+
hir_id: self.next_id(),
595+
span: self.lower_span(span),
596+
name: hir::LifetimeName::Infer,
597+
});
598+
let context_path =
599+
hir::QPath::LangItem(hir::LangItem::Context, self.lower_span(span), None);
600+
let context_ty = hir::MutTy {
601+
ty: self.arena.alloc(hir::Ty {
602+
hir_id: self.next_id(),
603+
kind: hir::TyKind::Path(context_path),
604+
span: self.lower_span(span),
605+
}),
606+
mutbl: hir::Mutability::Mut,
607+
};
593608
let input_ty = hir::Ty {
594609
hir_id: self.next_id(),
595-
kind: hir::TyKind::Infer,
610+
kind: hir::TyKind::Rptr(context_lifetime, context_ty),
596611
span: self.lower_span(span),
597612
};
598613

@@ -642,24 +657,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
642657

643658
hir::ExprKind::Closure(c)
644659
};
660+
645661
let generator = hir::Expr {
646662
hir_id: self.lower_node_id(closure_node_id),
647663
kind: generator_kind,
648664
span: self.lower_span(span),
649665
};
650666

651-
// `future::from_generator`:
652-
let unstable_span =
653-
self.mark_span_with_reason(DesugaringKind::Async, span, self.allow_gen_future.clone());
654-
let gen_future = self.expr_lang_item_path(
655-
unstable_span,
656-
hir::LangItem::FromGenerator,
657-
AttrVec::new(),
658-
None,
659-
);
660-
661-
// `future::from_generator(generator)`:
662-
hir::ExprKind::Call(self.arena.alloc(gen_future), arena_vec![self; generator])
667+
generator
663668
}
664669

665670
/// Desugar `<expr>.await` into:
@@ -668,7 +673,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
668673
/// mut __awaitee => loop {
669674
/// match unsafe { ::std::future::Future::poll(
670675
/// <::std::pin::Pin>::new_unchecked(&mut __awaitee),
671-
/// ::std::future::get_context(task_context),
676+
/// task_context,
672677
/// ) } {
673678
/// ::std::task::Poll::Ready(result) => break result,
674679
/// ::std::task::Poll::Pending => {}
@@ -709,7 +714,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
709714
// unsafe {
710715
// ::std::future::Future::poll(
711716
// ::std::pin::Pin::new_unchecked(&mut __awaitee),
712-
// ::std::future::get_context(task_context),
717+
// task_context,
713718
// )
714719
// }
715720
let poll_expr = {
@@ -727,16 +732,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
727732
arena_vec![self; ref_mut_awaitee],
728733
Some(expr_hir_id),
729734
);
730-
let get_context = self.expr_call_lang_item_fn_mut(
731-
gen_future_span,
732-
hir::LangItem::GetContext,
733-
arena_vec![self; task_context],
734-
Some(expr_hir_id),
735-
);
736735
let call = self.expr_call_lang_item_fn(
737736
span,
738737
hir::LangItem::FuturePoll,
739-
arena_vec![self; new_unchecked, get_context],
738+
arena_vec![self; new_unchecked, task_context],
740739
Some(expr_hir_id),
741740
);
742741
self.arena.alloc(self.expr_unsafe(call))
@@ -962,7 +961,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
962961
}
963962

964963
// Transform `async |x: u8| -> X { ... }` into
965-
// `|x: u8| future_from_generator(|| -> X { ... })`.
964+
// `|x: u8| || -> X { ... }`.
966965
let body_id = this.lower_fn_body(&outer_decl, |this| {
967966
let async_ret_ty =
968967
if let FnRetTy::Ty(ty) = &decl.output { Some(ty.clone()) } else { None };
@@ -974,7 +973,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
974973
hir::AsyncGeneratorKind::Closure,
975974
|this| this.with_new_scopes(|this| this.lower_expr_mut(body)),
976975
);
977-
this.expr(fn_decl_span, async_body, AttrVec::new())
976+
async_body
978977
});
979978
body_id
980979
});

compiler/rustc_ast_lowering/src/item.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -1248,10 +1248,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
12481248
},
12491249
);
12501250

1251-
(
1252-
this.arena.alloc_from_iter(parameters),
1253-
this.expr(body.span, async_expr, AttrVec::new()),
1254-
)
1251+
(this.arena.alloc_from_iter(parameters), async_expr)
12551252
})
12561253
}
12571254

compiler/rustc_hir/src/lang_items.rs

+2
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,12 @@ language_item_table! {
270270
TryTraitBranch, sym::branch, branch_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
271271
TryTraitFromYeet, sym::from_yeet, from_yeet_fn, Target::Fn, GenericRequirement::None;
272272

273+
Poll, sym::Poll, poll, Target::Enum, GenericRequirement::None;
273274
PollReady, sym::Ready, poll_ready_variant, Target::Variant, GenericRequirement::None;
274275
PollPending, sym::Pending, poll_pending_variant, Target::Variant, GenericRequirement::None;
275276

276277
FromGenerator, sym::from_generator, from_generator_fn, Target::Fn, GenericRequirement::None;
278+
Context, sym::Context, context, Target::Struct, GenericRequirement::None;
277279
GetContext, sym::get_context, get_context_fn, Target::Fn, GenericRequirement::None;
278280

279281
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;

compiler/rustc_hir_typeck/src/check.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,15 @@ pub(super) fn check_fn<'a, 'tcx>(
5454

5555
fn_maybe_err(tcx, span, fn_sig.abi);
5656

57-
if body.generator_kind.is_some() && can_be_generator.is_some() {
58-
let yield_ty = fcx
59-
.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::TypeInference, span });
60-
fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);
57+
if let Some(kind) = body.generator_kind && can_be_generator.is_some() {
58+
let yield_ty = if kind == hir::GeneratorKind::Gen {
59+
let yield_ty = fcx
60+
.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::TypeInference, span });
61+
fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);
62+
yield_ty
63+
} else {
64+
tcx.mk_unit()
65+
};
6166

6267
// Resume type defaults to `()` if the generator has no argument.
6368
let resume_ty = fn_sig.inputs().get(0).copied().unwrap_or_else(|| tcx.mk_unit());

compiler/rustc_middle/src/traits/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,7 @@ impl<'tcx, N> ImplSource<'tcx, N> {
736736
generator_def_id: c.generator_def_id,
737737
substs: c.substs,
738738
nested: c.nested.into_iter().map(f).collect(),
739+
is_async: c.is_async,
739740
}),
740741
ImplSource::FnPointer(p) => ImplSource::FnPointer(ImplSourceFnPointerData {
741742
fn_ty: p.fn_ty,
@@ -794,6 +795,7 @@ pub struct ImplSourceGeneratorData<'tcx, N> {
794795
/// Nested obligations. This can be non-empty if the generator
795796
/// signature contains associated types.
796797
pub nested: Vec<N>,
798+
pub is_async: bool,
797799
}
798800

799801
#[derive(Clone, PartialEq, Eq, TyEncodable, TyDecodable, HashStable, Lift)]

compiler/rustc_middle/src/traits/select.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,11 @@ pub enum SelectionCandidate<'tcx> {
127127
/// generated for an `||` expression.
128128
ClosureCandidate,
129129

130-
/// Implementation of a `Generator` trait by one of the anonymous types
130+
/// Implementation of a `Generator` / `Future` trait by one of the anonymous types
131131
/// generated for a generator.
132-
GeneratorCandidate,
132+
GeneratorCandidate {
133+
is_async: bool,
134+
},
133135

134136
/// Implementation of a `Fn`-family trait by one of the anonymous
135137
/// types generated for a fn pointer type (e.g., `fn(int) -> int`)

compiler/rustc_mir_transform/src/generator.rs

+68-30
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ use crate::deref_separator::deref_finder;
5353
use crate::simplify;
5454
use crate::util::expand_aggregate;
5555
use crate::MirPass;
56+
use hir::GeneratorKind;
5657
use rustc_data_structures::fx::FxHashMap;
5758
use rustc_hir as hir;
5859
use rustc_hir::lang_items::LangItem;
@@ -215,6 +216,7 @@ struct SuspensionPoint<'tcx> {
215216

216217
struct TransformVisitor<'tcx> {
217218
tcx: TyCtxt<'tcx>,
219+
is_async_kind: bool,
218220
state_adt_ref: AdtDef<'tcx>,
219221
state_substs: SubstsRef<'tcx>,
220222

@@ -239,28 +241,30 @@ struct TransformVisitor<'tcx> {
239241
}
240242

241243
impl<'tcx> TransformVisitor<'tcx> {
242-
// Make a GeneratorState variant assignment. `core::ops::GeneratorState` only has single
243-
// element tuple variants, so we can just write to the downcasted first field and then set the
244+
// Make a `GeneratorState` or `Poll` variant assignment.
245+
//
246+
// `core::ops::GeneratorState` only has single element tuple variants,
247+
// so we can just write to the downcasted first field and then set the
244248
// discriminant to the appropriate variant.
245-
fn make_state(
246-
&self,
247-
idx: VariantIdx,
248-
val: Operand<'tcx>,
249-
source_info: SourceInfo,
250-
) -> impl Iterator<Item = Statement<'tcx>> {
249+
fn make_state(&self, idx: VariantIdx) -> (AggregateKind<'tcx>, Option<Ty<'tcx>>) {
251250
let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_substs, None, None);
251+
252+
// `Poll::Pending`
253+
if self.is_async_kind && idx == VariantIdx::new(1) {
254+
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
255+
256+
return (kind, None);
257+
}
258+
259+
// else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
252260
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
261+
253262
let ty = self
254263
.tcx
255264
.bound_type_of(self.state_adt_ref.variant(idx).fields[0].did)
256265
.subst(self.tcx, self.state_substs);
257-
expand_aggregate(
258-
Place::return_place(),
259-
std::iter::once((val, ty)),
260-
kind,
261-
source_info,
262-
self.tcx,
263-
)
266+
267+
(kind, Some(ty))
264268
}
265269

266270
// Create a Place referencing a generator struct field
@@ -331,22 +335,44 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
331335
});
332336

333337
let ret_val = match data.terminator().kind {
334-
TerminatorKind::Return => Some((
335-
VariantIdx::new(1),
336-
None,
337-
Operand::Move(Place::from(self.new_ret_local)),
338-
None,
339-
)),
338+
TerminatorKind::Return => {
339+
Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None))
340+
}
340341
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
341-
Some((VariantIdx::new(0), Some((resume, resume_arg)), value.clone(), drop))
342+
Some((false, Some((resume, resume_arg)), value.clone(), drop))
342343
}
343344
_ => None,
344345
};
345346

346-
if let Some((state_idx, resume, v, drop)) = ret_val {
347+
if let Some((is_return, resume, v, drop)) = ret_val {
347348
let source_info = data.terminator().source_info;
348349
// We must assign the value first in case it gets declared dead below
349-
data.statements.extend(self.make_state(state_idx, v, source_info));
350+
let state_idx = VariantIdx::new(match (is_return, self.is_async_kind) {
351+
(true, false) => 1, // GeneratorState::Complete
352+
(false, false) => 0, // GeneratorState::Yielded
353+
(true, true) => 0, // Poll::Ready
354+
(false, true) => 1, // Poll::Pending
355+
});
356+
let (kind, ty) = self.make_state(state_idx);
357+
if let Some(ty) = ty {
358+
data.statements.extend(expand_aggregate(
359+
Place::return_place(),
360+
std::iter::once((v, ty)),
361+
kind,
362+
source_info,
363+
self.tcx,
364+
));
365+
} else {
366+
// TODO: assert `val` is nil
367+
data.statements.extend(expand_aggregate(
368+
Place::return_place(),
369+
std::iter::empty(),
370+
kind,
371+
source_info,
372+
self.tcx,
373+
));
374+
}
375+
350376
let state = if let Some((resume, mut resume_arg)) = resume {
351377
// Yield
352378
let state = RESERVED_VARIANTS + self.suspension_points.len();
@@ -1268,10 +1294,20 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
12681294
}
12691295
};
12701296

1271-
// Compute GeneratorState<yield_ty, return_ty>
1272-
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
1273-
let state_adt_ref = tcx.adt_def(state_did);
1274-
let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]);
1297+
let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen;
1298+
let (state_adt_ref, state_substs) = if is_async_kind {
1299+
// Compute Poll<return_ty>
1300+
let state_did = tcx.require_lang_item(LangItem::Poll, None);
1301+
let state_adt_ref = tcx.adt_def(state_did);
1302+
let state_substs = tcx.intern_substs(&[body.return_ty().into()]);
1303+
(state_adt_ref, state_substs)
1304+
} else {
1305+
// Compute GeneratorState<yield_ty, return_ty>
1306+
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
1307+
let state_adt_ref = tcx.adt_def(state_did);
1308+
let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]);
1309+
(state_adt_ref, state_substs)
1310+
};
12751311
let ret_ty = tcx.mk_adt(state_adt_ref, state_substs);
12761312

12771313
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
@@ -1327,9 +1363,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
13271363
// Run the transformation which converts Places from Local to generator struct
13281364
// accesses for locals in `remap`.
13291365
// It also rewrites `return x` and `yield y` as writing a new generator state and returning
1330-
// GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively.
1366+
// either GeneratorState::Complete(x) and GeneratorState::Yielded(y),
1367+
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
13311368
let mut transform = TransformVisitor {
13321369
tcx,
1370+
is_async_kind,
13331371
state_adt_ref,
13341372
state_substs,
13351373
remap,
@@ -1367,7 +1405,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
13671405

13681406
body.generator.as_mut().unwrap().generator_drop = Some(drop_shim);
13691407

1370-
// Create the Generator::resume function
1408+
// Create the Generator::resume / Future::poll function
13711409
create_generator_resume_function(tcx, transform, body, can_return);
13721410

13731411
// Run derefer to fix Derefs that are not in the first place

0 commit comments

Comments
 (0)