-
Notifications
You must be signed in to change notification settings - Fork 13.3k
Prune unreachable variants of coroutines #135471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,10 @@ | |
//! naively generate still contains the `_a = ()` write in the unreachable block "after" the | ||
//! return. | ||
|
||
use rustc_abi::{FieldIdx, VariantIdx}; | ||
use rustc_data_structures::fx::FxHashSet; | ||
use rustc_hir::{CoroutineDesugaring, CoroutineKind}; | ||
use rustc_index::bit_set::DenseBitSet; | ||
use rustc_index::{Idx, IndexSlice, IndexVec}; | ||
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; | ||
use rustc_middle::mir::*; | ||
|
@@ -68,6 +72,7 @@ impl SimplifyCfg { | |
|
||
pub(super) fn simplify_cfg(body: &mut Body<'_>) { | ||
CfgSimplifier::new(body).simplify(); | ||
remove_dead_coroutine_switch_variants(body); | ||
remove_dead_blocks(body); | ||
|
||
// FIXME: Should probably be moved into some kind of pass manager | ||
|
@@ -292,6 +297,168 @@ pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) | |
} | ||
} | ||
|
||
const SELF_LOCAL: Local = Local::from_u32(1); | ||
const FIELD_ZERO: FieldIdx = FieldIdx::from_u32(0); | ||
|
||
pub(super) fn remove_dead_coroutine_switch_variants(body: &mut Body<'_>) { | ||
let Some(coroutine_layout) = body.coroutine_layout_raw() else { | ||
// Not a coroutine; no coroutine variants to remove. | ||
return; | ||
}; | ||
|
||
let bb0 = &body.basic_blocks[START_BLOCK]; | ||
|
||
let is_pinned = match body.coroutine_kind().unwrap() { | ||
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => false, | ||
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) | ||
| CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) | ||
| CoroutineKind::Coroutine(_) => true, | ||
}; | ||
// This is essentially our off-brand `Underefer`. This stores the set of locals | ||
// that we have determined to contain references to the coroutine discriminant. | ||
// If the self type is not pinned, this is just going to be `_1`. However, if | ||
// the self type is pinned, the derefer will emit statements of the form: | ||
// _x = CopyForDeref (_1.0); | ||
// We'll store the local for `_x` so that we can later detect discriminant stores | ||
// of the form: | ||
// Discriminant((*_x)) = ... | ||
// which correspond to reachable variants of the coroutine. | ||
let mut discr_locals = if is_pinned { | ||
let Some(stmt) = bb0.statements.get(0) else { | ||
// The coroutine body may have been turned into a single `unreachable`. | ||
return; | ||
}; | ||
// We match `CopyForDeref` (which is what gets emitted from the state transform | ||
// pass), but also we match *regular* `Copy`, which is what GVN may optimize it to. | ||
let StatementKind::Assign(box ( | ||
place, | ||
Rvalue::Use(Operand::Copy(deref_place)) | Rvalue::CopyForDeref(deref_place), | ||
)) = &stmt.kind | ||
else { | ||
panic!("The first statement of a coroutine is not a self deref"); | ||
}; | ||
let PlaceRef { local: SELF_LOCAL, projection: &[PlaceElem::Field(FIELD_ZERO, _)] } = | ||
deref_place.as_ref() | ||
else { | ||
panic!("The first statement of a coroutine is not a self deref"); | ||
}; | ||
FxHashSet::from_iter([place.as_local().unwrap()]) | ||
} else { | ||
FxHashSet::from_iter([SELF_LOCAL]) | ||
}; | ||
|
||
// The starting block of all coroutines is a switch for the coroutine variants. | ||
// This is preceded by a read of the discriminant. If we don't find this, then | ||
// we must have optimized away the switch, so bail. | ||
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(discr_local))) = | ||
&bb0.statements[if is_pinned { 1 } else { 0 }].kind | ||
else { | ||
// The following statement is not a discriminant read. We may have | ||
// optimized it out, so bail gracefully. | ||
return; | ||
}; | ||
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = (*discr_local).as_ref() | ||
else { | ||
// We expect the discriminant to have read `&mut self`, | ||
// so we expect the place to be a deref. If we didn't, then | ||
// it may have been optimized out, so bail gracefully. | ||
return; | ||
}; | ||
if !discr_locals.contains(&deref_local) { | ||
// The place being read isn't `_1` (self) or a `Derefer`-inserted local. | ||
// It may have been optimized out, so bail gracefully. | ||
return; | ||
} | ||
let TerminatorKind::SwitchInt { discr: Operand::Move(place), targets } = &bb0.terminator().kind | ||
else { | ||
// When panic=abort, we may end up folding away the other variants of the | ||
// coroutine, and end up with ths `SwitchInt` getting replaced. In this | ||
// case, there's no need to do this optimization, so bail gracefully. | ||
return; | ||
}; | ||
if place != discr_place { | ||
// Make sure we don't try to match on some other `SwitchInt`; we should be | ||
// matching on the discriminant we just read. | ||
return; | ||
} | ||
|
||
let mut visited = DenseBitSet::new_empty(body.basic_blocks.len()); | ||
let mut worklist = vec![]; | ||
let mut visited_variants = DenseBitSet::new_empty(coroutine_layout.variant_fields.len()); | ||
|
||
// Insert unresumed (initial), returned, panicked variants. | ||
// We treat these as always reachable. | ||
visited_variants.insert(VariantIdx::from_usize(0)); | ||
visited_variants.insert(VariantIdx::from_usize(1)); | ||
visited_variants.insert(VariantIdx::from_usize(2)); | ||
worklist.push(targets.target_for_value(0)); | ||
worklist.push(targets.target_for_value(1)); | ||
worklist.push(targets.target_for_value(2)); | ||
|
||
// Walk all of the reachable variant blocks. | ||
while let Some(block) = worklist.pop() { | ||
if !visited.insert(block) { | ||
continue; | ||
} | ||
|
||
let data = &body.basic_blocks[block]; | ||
for stmt in &data.statements { | ||
match &stmt.kind { | ||
// If we see a `SetDiscriminant` statement for our coroutine, | ||
// mark that variant as reachable and add it to the worklist. | ||
StatementKind::SetDiscriminant { place, variant_index } => { | ||
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = | ||
(**place).as_ref() | ||
else { | ||
continue; | ||
}; | ||
if !discr_locals.contains(&deref_local) { | ||
continue; | ||
} | ||
visited_variants.insert(*variant_index); | ||
worklist.push(targets.target_for_value(variant_index.as_u32().into())); | ||
} | ||
// The derefer may have inserted a local to access the variant. | ||
// Make sure we keep track of it here. | ||
StatementKind::Assign(box (place, Rvalue::CopyForDeref(deref_place))) => { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be a Rvalue::Operand too. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I can probably assert against it I guess :) |
||
if !is_pinned { | ||
continue; | ||
} | ||
let PlaceRef { | ||
local: SELF_LOCAL, | ||
projection: &[PlaceElem::Field(FIELD_ZERO, _)], | ||
} = deref_place.as_ref() | ||
else { | ||
continue; | ||
}; | ||
discr_locals.insert(place.as_local().unwrap()); | ||
} | ||
_ => {} | ||
} | ||
} | ||
|
||
// Also walk all the successors of this block. | ||
if let Some(term) = &data.terminator { | ||
worklist.extend(term.successors()); | ||
} | ||
} | ||
|
||
// Filter out the variants that are unreachable. | ||
let TerminatorKind::SwitchInt { targets, .. } = | ||
&mut body.basic_blocks.as_mut()[START_BLOCK].terminator_mut().kind | ||
else { | ||
unreachable!(); | ||
}; | ||
*targets = SwitchTargets::new( | ||
targets | ||
.iter() | ||
.filter(|(idx, _)| visited_variants.contains(VariantIdx::from_u32(*idx as u32))), | ||
targets.otherwise(), | ||
); | ||
|
||
// FIXME: We could remove dead variant fields from the coroutine layout, too. | ||
} | ||
|
||
pub(super) fn remove_dead_blocks(body: &mut Body<'_>) { | ||
let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| { | ||
// CfgSimplifier::simplify leaves behind some unreachable basic blocks without a | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
- // MIR for `outer::{closure#0}` before SimplifyCfg-final | ||
+ // MIR for `outer::{closure#0}` after SimplifyCfg-final | ||
/* coroutine_layout = CoroutineLayout { | ||
field_tys: { | ||
_0: CoroutineSavedTy { | ||
ty: Coroutine( | ||
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}), | ||
[ | ||
(), | ||
std::future::ResumeTy, | ||
(), | ||
(), | ||
CoroutineWitness( | ||
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}), | ||
[], | ||
), | ||
(), | ||
], | ||
), | ||
source_info: SourceInfo { | ||
span: $DIR/coroutine_dead_variants.rs:13:9: 13:22 (#16), | ||
scope: scope[0], | ||
}, | ||
ignore_for_traits: false, | ||
}, | ||
}, | ||
variant_fields: { | ||
Unresumed(0): [], | ||
Returned (1): [], | ||
Panicked (2): [], | ||
Suspend0 (3): [_0], | ||
}, | ||
storage_conflicts: BitMatrix(1x1) { | ||
(_0, _0), | ||
}, | ||
} */ | ||
|
||
fn outer::{closure#0}(_1: Pin<&mut {async fn body of outer()}>, _2: &mut Context<'_>) -> Poll<()> { | ||
debug _task_context => _2; | ||
let mut _0: std::task::Poll<()>; | ||
let mut _3: {async fn body of inner()}; | ||
let mut _4: {async fn body of inner()}; | ||
let mut _5: std::task::Poll<()>; | ||
let mut _6: std::pin::Pin<&mut {async fn body of inner()}>; | ||
let mut _7: &mut {async fn body of inner()}; | ||
let mut _8: &mut std::task::Context<'_>; | ||
let mut _9: isize; | ||
let mut _11: (); | ||
let mut _12: &mut std::task::Context<'_>; | ||
let mut _13: u32; | ||
let mut _14: &mut {async fn body of outer()}; | ||
scope 1 { | ||
debug __awaitee => (((*(_1.0: &mut {async fn body of outer()})) as variant#3).0: {async fn body of inner()}); | ||
let _10: (); | ||
scope 2 { | ||
debug result => const (); | ||
} | ||
} | ||
|
||
bb0: { | ||
_14 = copy (_1.0: &mut {async fn body of outer()}); | ||
_13 = discriminant((*_14)); | ||
- switchInt(move _13) -> [0: bb1, 1: bb15, 3: bb14, otherwise: bb8]; | ||
+ switchInt(move _13) -> [0: bb2, 1: bb4, otherwise: bb1]; | ||
} | ||
|
||
bb1: { | ||
- nop; | ||
- goto -> bb12; | ||
- } | ||
- | ||
- bb2: { | ||
- StorageLive(_3); | ||
- StorageLive(_4); | ||
- _4 = inner() -> [return: bb3, unwind unreachable]; | ||
- } | ||
- | ||
- bb3: { | ||
- _3 = <{async fn body of inner()} as IntoFuture>::into_future(move _4) -> [return: bb4, unwind unreachable]; | ||
- } | ||
- | ||
- bb4: { | ||
- StorageDead(_4); | ||
- (((*_14) as variant#3).0: {async fn body of inner()}) = move _3; | ||
- goto -> bb5; | ||
- } | ||
- | ||
- bb5: { | ||
- StorageLive(_5); | ||
- StorageLive(_6); | ||
- _7 = &mut (((*_14) as variant#3).0: {async fn body of inner()}); | ||
- _6 = Pin::<&mut {async fn body of inner()}>::new_unchecked(copy _7) -> [return: bb6, unwind unreachable]; | ||
- } | ||
- | ||
- bb6: { | ||
- nop; | ||
- _5 = <{async fn body of inner()} as Future>::poll(move _6, copy _2) -> [return: bb7, unwind unreachable]; | ||
- } | ||
- | ||
- bb7: { | ||
- StorageDead(_6); | ||
- _9 = discriminant(_5); | ||
- switchInt(move _9) -> [0: bb10, 1: bb9, otherwise: bb8]; | ||
- } | ||
- | ||
- bb8: { | ||
unreachable; | ||
} | ||
|
||
- bb9: { | ||
- StorageDead(_5); | ||
- _0 = const Poll::<()>::Pending; | ||
- StorageDead(_3); | ||
- discriminant((*_14)) = 3; | ||
- return; | ||
- } | ||
- | ||
- bb10: { | ||
- StorageLive(_10); | ||
- nop; | ||
- StorageDead(_10); | ||
- StorageDead(_5); | ||
- drop((((*_14) as variant#3).0: {async fn body of inner()})) -> [return: bb11, unwind unreachable]; | ||
- } | ||
- | ||
- bb11: { | ||
- StorageDead(_3); | ||
+ bb2: { | ||
_11 = const (); | ||
- goto -> bb13; | ||
+ goto -> bb3; | ||
} | ||
|
||
- bb12: { | ||
- _11 = const (); | ||
- goto -> bb13; | ||
- } | ||
- | ||
- bb13: { | ||
+ bb3: { | ||
_0 = Poll::<()>::Ready(const ()); | ||
discriminant((*_14)) = 1; | ||
return; | ||
} | ||
|
||
- bb14: { | ||
- StorageLive(_3); | ||
- nop; | ||
- goto -> bb5; | ||
- } | ||
- | ||
- bb15: { | ||
- assert(const false, "`async fn` resumed after completion") -> [success: bb15, unwind unreachable]; | ||
+ bb4: { | ||
+ assert(const false, "`async fn` resumed after completion") -> [success: bb4, unwind unreachable]; | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we optimize out too much, we may not even have a statement left.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, but we already check that there's at least 1 statement in https://github.com/rust-lang/rust/pull/135471/files#diff-2b899097018634da368ec3c772523e10def01aef54b950e07c125a49e374a3e3R329. I don't know if there will ever be a case where that succeeds but this fails.