Prune unreachable variants of coroutines #135471

Closed
167 changes: 167 additions & 0 deletions compiler/rustc_mir_transform/src/simplify.rs
@@ -27,6 +27,10 @@
//! naively generate still contains the `_a = ()` write in the unreachable block "after" the
//! return.

use rustc_abi::{FieldIdx, VariantIdx};
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
use rustc_index::bit_set::DenseBitSet;
use rustc_index::{Idx, IndexSlice, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
@@ -68,6 +72,7 @@ impl SimplifyCfg {

pub(super) fn simplify_cfg(body: &mut Body<'_>) {
CfgSimplifier::new(body).simplify();
remove_dead_coroutine_switch_variants(body);
remove_dead_blocks(body);

// FIXME: Should probably be moved into some kind of pass manager
@@ -292,6 +297,168 @@ pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>)
}
}

const SELF_LOCAL: Local = Local::from_u32(1);
const FIELD_ZERO: FieldIdx = FieldIdx::from_u32(0);

pub(super) fn remove_dead_coroutine_switch_variants(body: &mut Body<'_>) {
let Some(coroutine_layout) = body.coroutine_layout_raw() else {
// Not a coroutine; no coroutine variants to remove.
return;
};

let bb0 = &body.basic_blocks[START_BLOCK];

let is_pinned = match body.coroutine_kind().unwrap() {
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => false,
CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
| CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
| CoroutineKind::Coroutine(_) => true,
};
// This is essentially our off-brand `Underefer`. This stores the set of locals
// that we have determined to contain references to the coroutine discriminant.
// If the self type is not pinned, this is just going to be `_1`. However, if
// the self type is pinned, the derefer will emit statements of the form:
// _x = CopyForDeref (_1.0);
// We'll store the local for `_x` so that we can later detect discriminant stores
// of the form:
// Discriminant((*_x)) = ...
// which correspond to reachable variants of the coroutine.
let mut discr_locals = if is_pinned {
let Some(stmt) = bb0.statements.get(0) else {
// The coroutine body may have been turned into a single `unreachable`.
return;
};
// We match `CopyForDeref` (which is what gets emitted from the state transform
// pass), but also we match *regular* `Copy`, which is what GVN may optimize it to.
let StatementKind::Assign(box (
place,
Rvalue::Use(Operand::Copy(deref_place)) | Rvalue::CopyForDeref(deref_place),
)) = &stmt.kind
else {
panic!("The first statement of a coroutine is not a self deref");
};
let PlaceRef { local: SELF_LOCAL, projection: &[PlaceElem::Field(FIELD_ZERO, _)] } =
deref_place.as_ref()
else {
panic!("The first statement of a coroutine is not a self deref");
};
FxHashSet::from_iter([place.as_local().unwrap()])
} else {
FxHashSet::from_iter([SELF_LOCAL])
};

// The starting block of all coroutines is a switch for the coroutine variants.
// This is preceded by a read of the discriminant. If we don't find this, then
// we must have optimized away the switch, so bail.
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(discr_local))) =
&bb0.statements[if is_pinned { 1 } else { 0 }].kind
Contributor:

If we optimize out too much, we may not even have a statement left.

Member (author):

Yeah, but we already check that there's at least 1 statement in https://github.com/rust-lang/rust/pull/135471/files#diff-2b899097018634da368ec3c772523e10def01aef54b950e07c125a49e374a3e3R329. I don't know if there will ever be a case where that succeeds but this fails.
else {
// The following statement is not a discriminant read. We may have
// optimized it out, so bail gracefully.
return;
};
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = (*discr_local).as_ref()
else {
// We expect the discriminant to have read `&mut self`,
// so we expect the place to be a deref. If we didn't, then
// it may have been optimized out, so bail gracefully.
return;
};
if !discr_locals.contains(&deref_local) {
// The place being read isn't `_1` (self) or a `Derefer`-inserted local.
// It may have been optimized out, so bail gracefully.
return;
}
let TerminatorKind::SwitchInt { discr: Operand::Move(place), targets } = &bb0.terminator().kind
else {
// When panic=abort, we may end up folding away the other variants of the
// coroutine, and end up with this `SwitchInt` getting replaced. In this
// case, there's no need to do this optimization, so bail gracefully.
return;
};
if place != discr_place {
// Make sure we don't try to match on some other `SwitchInt`; we should be
// matching on the discriminant we just read.
return;
}
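
// At this point, bb0 is known to start with the prelude emitted by the
// coroutine state transform (possibly with `CopyForDeref` turned into a
// plain copy by later passes):
//     (pinned only) _x = CopyForDeref((_1.0: &mut Self));
//     _y = discriminant((*_x));      // `(*_1)` when not pinned
//     switchInt(move _y) -> [...];
// and `targets` is the switch over the coroutine's discriminant.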

let mut visited = DenseBitSet::new_empty(body.basic_blocks.len());
let mut worklist = vec![];
let mut visited_variants = DenseBitSet::new_empty(coroutine_layout.variant_fields.len());

// Insert unresumed (initial), returned, panicked variants.
// We treat these as always reachable.
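// (Coroutine layouts reserve these variants: 0 = Unresumed, 1 = Returned,
// 2 = Panicked; suspend points start at variant index 3.)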
visited_variants.insert(VariantIdx::from_usize(0));
visited_variants.insert(VariantIdx::from_usize(1));
visited_variants.insert(VariantIdx::from_usize(2));
worklist.push(targets.target_for_value(0));
worklist.push(targets.target_for_value(1));
worklist.push(targets.target_for_value(2));

// Walk all of the reachable variant blocks.
while let Some(block) = worklist.pop() {
if !visited.insert(block) {
continue;
}

let data = &body.basic_blocks[block];
for stmt in &data.statements {
match &stmt.kind {
// If we see a `SetDiscriminant` statement for our coroutine,
// mark that variant as reachable and add it to the worklist.
StatementKind::SetDiscriminant { place, variant_index } => {
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } =
(**place).as_ref()
else {
continue;
};
if !discr_locals.contains(&deref_local) {
continue;
}
visited_variants.insert(*variant_index);
worklist.push(targets.target_for_value(variant_index.as_u32().into()));
}
// The derefer may have inserted a local to access the variant.
// Make sure we keep track of it here.
StatementKind::Assign(box (place, Rvalue::CopyForDeref(deref_place))) => {
Contributor:

This could be a Rvalue::Operand too.

Member (author):

Rvalue::Operand(RValue::Copy(_)) would mean that GVN has passed by, and I think GVN would've collapsed all of them into this one: https://github.com/rust-lang/rust/pull/135471/files#diff-2b899097018634da368ec3c772523e10def01aef54b950e07c125a49e374a3e3R332

I can probably assert against it I guess :)

if !is_pinned {
continue;
}
let PlaceRef {
local: SELF_LOCAL,
projection: &[PlaceElem::Field(FIELD_ZERO, _)],
} = deref_place.as_ref()
else {
continue;
};
discr_locals.insert(place.as_local().unwrap());
}
_ => {}
}
}

// Also walk all the successors of this block.
if let Some(term) = &data.terminator {
worklist.extend(term.successors());
}
}

// Filter out the variants that are unreachable.
let TerminatorKind::SwitchInt { targets, .. } =
&mut body.basic_blocks.as_mut()[START_BLOCK].terminator_mut().kind
else {
unreachable!();
};
*targets = SwitchTargets::new(
targets
.iter()
.filter(|(idx, _)| visited_variants.contains(VariantIdx::from_u32(*idx as u32))),
targets.otherwise(),
);
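
// The `otherwise` target (the `unreachable` arm of the coroutine switch) is
// kept as-is; only the per-value targets of variants never marked reachable
// are dropped. Any blocks left unreferenced as a result are removed by
// `remove_dead_blocks`, which `simplify_cfg` runs right after this pass.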

// FIXME: We could remove dead variant fields from the coroutine layout, too.
}
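
A hedged sketch of the kind of source this pass targets (hypothetical; not necessarily the PR's actual coroutine_dead_variants.rs test): when the only `.await` sits behind a branch the optimizer can prove dead, the suspend variant of the outer state machine is never stored, so its arm in bb0's switch can be pruned.

async fn inner() {}

async fn outer() {
    // Hypothetical: const folding removes this branch, so `outer`'s
    // Suspend0 variant is never set and its resume arm is unreachable.
    if false {
        inner().await;
    }
}

fn main() {
    // Building (not polling) the future is enough to instantiate the MIR.
    let _fut = outer();
}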

pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
// CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
@@ -0,0 +1,158 @@
- // MIR for `outer::{closure#0}` before SimplifyCfg-final
+ // MIR for `outer::{closure#0}` after SimplifyCfg-final
/* coroutine_layout = CoroutineLayout {
field_tys: {
_0: CoroutineSavedTy {
ty: Coroutine(
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}),
[
(),
std::future::ResumeTy,
(),
(),
CoroutineWitness(
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}),
[],
),
(),
],
),
source_info: SourceInfo {
span: $DIR/coroutine_dead_variants.rs:13:9: 13:22 (#16),
scope: scope[0],
},
ignore_for_traits: false,
},
},
variant_fields: {
Unresumed(0): [],
Returned (1): [],
Panicked (2): [],
Suspend0 (3): [_0],
},
storage_conflicts: BitMatrix(1x1) {
(_0, _0),
},
} */

fn outer::{closure#0}(_1: Pin<&mut {async fn body of outer()}>, _2: &mut Context<'_>) -> Poll<()> {
debug _task_context => _2;
let mut _0: std::task::Poll<()>;
let mut _3: {async fn body of inner()};
let mut _4: {async fn body of inner()};
let mut _5: std::task::Poll<()>;
let mut _6: std::pin::Pin<&mut {async fn body of inner()}>;
let mut _7: &mut {async fn body of inner()};
let mut _8: &mut std::task::Context<'_>;
let mut _9: isize;
let mut _11: ();
let mut _12: &mut std::task::Context<'_>;
let mut _13: u32;
let mut _14: &mut {async fn body of outer()};
scope 1 {
debug __awaitee => (((*(_1.0: &mut {async fn body of outer()})) as variant#3).0: {async fn body of inner()});
let _10: ();
scope 2 {
debug result => const ();
}
}

bb0: {
_14 = copy (_1.0: &mut {async fn body of outer()});
_13 = discriminant((*_14));
- switchInt(move _13) -> [0: bb1, 1: bb15, 3: bb14, otherwise: bb8];
+ switchInt(move _13) -> [0: bb2, 1: bb4, otherwise: bb1];
}

bb1: {
- nop;
- goto -> bb12;
- }
-
- bb2: {
- StorageLive(_3);
- StorageLive(_4);
- _4 = inner() -> [return: bb3, unwind unreachable];
- }
-
- bb3: {
- _3 = <{async fn body of inner()} as IntoFuture>::into_future(move _4) -> [return: bb4, unwind unreachable];
- }
-
- bb4: {
- StorageDead(_4);
- (((*_14) as variant#3).0: {async fn body of inner()}) = move _3;
- goto -> bb5;
- }
-
- bb5: {
- StorageLive(_5);
- StorageLive(_6);
- _7 = &mut (((*_14) as variant#3).0: {async fn body of inner()});
- _6 = Pin::<&mut {async fn body of inner()}>::new_unchecked(copy _7) -> [return: bb6, unwind unreachable];
- }
-
- bb6: {
- nop;
- _5 = <{async fn body of inner()} as Future>::poll(move _6, copy _2) -> [return: bb7, unwind unreachable];
- }
-
- bb7: {
- StorageDead(_6);
- _9 = discriminant(_5);
- switchInt(move _9) -> [0: bb10, 1: bb9, otherwise: bb8];
- }
-
- bb8: {
unreachable;
}

- bb9: {
- StorageDead(_5);
- _0 = const Poll::<()>::Pending;
- StorageDead(_3);
- discriminant((*_14)) = 3;
- return;
- }
-
- bb10: {
- StorageLive(_10);
- nop;
- StorageDead(_10);
- StorageDead(_5);
- drop((((*_14) as variant#3).0: {async fn body of inner()})) -> [return: bb11, unwind unreachable];
- }
-
- bb11: {
- StorageDead(_3);
+ bb2: {
_11 = const ();
- goto -> bb13;
+ goto -> bb3;
}

- bb12: {
- _11 = const ();
- goto -> bb13;
- }
-
- bb13: {
+ bb3: {
_0 = Poll::<()>::Ready(const ());
discriminant((*_14)) = 1;
return;
}

- bb14: {
- StorageLive(_3);
- nop;
- goto -> bb5;
- }
-
- bb15: {
- assert(const false, "`async fn` resumed after completion") -> [success: bb15, unwind unreachable];
+ bb4: {
+ assert(const false, "`async fn` resumed after completion") -> [success: bb4, unwind unreachable];
}
}
