Skip to content

Commit a7a5903

Browse files
Prune unreachable variants of coroutines
1 parent 2ae9916 commit a7a5903

4 files changed

+515
-0
lines changed

compiler/rustc_mir_transform/src/simplify.rs

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
//! naively generate still contains the `_a = ()` write in the unreachable block "after" the
2828
//! return.
2929
30+
use rustc_abi::{FieldIdx, VariantIdx};
31+
use rustc_data_structures::fx::FxHashSet;
32+
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
33+
use rustc_index::bit_set::DenseBitSet;
3034
use rustc_index::{Idx, IndexSlice, IndexVec};
3135
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
3236
use rustc_middle::mir::*;
@@ -68,6 +72,7 @@ impl SimplifyCfg {
6872

6973
pub(super) fn simplify_cfg(body: &mut Body<'_>) {
7074
CfgSimplifier::new(body).simplify();
75+
remove_dead_coroutine_switch_variants(body);
7176
remove_dead_blocks(body);
7277

7378
// FIXME: Should probably be moved into some kind of pass manager
@@ -292,6 +297,149 @@ pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>)
292297
}
293298
}
294299

300+
const SELF_LOCAL: Local = Local::from_u32(1);
301+
const FIELD_ZERO: FieldIdx = FieldIdx::from_u32(0);
302+
303+
pub(super) fn remove_dead_coroutine_switch_variants(body: &mut Body<'_>) {
304+
let Some(coroutine_layout) = body.coroutine_layout_raw() else {
305+
// Not a coroutine; no coroutine variants to remove.
306+
return;
307+
};
308+
309+
let bb0 = &body.basic_blocks[START_BLOCK];
310+
311+
let is_pinned = match body.coroutine_kind().unwrap() {
312+
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => false,
313+
CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
314+
| CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
315+
| CoroutineKind::Coroutine(_) => true,
316+
};
317+
// FIXME: Explain how we read discriminants.
318+
let mut discr_locals = if is_pinned {
319+
// FIXME: Mention GVN quirk.
320+
let StatementKind::Assign(box (
321+
place,
322+
Rvalue::Use(Operand::Copy(deref_place)) | Rvalue::CopyForDeref(deref_place),
323+
)) = &bb0.statements[0].kind
324+
else {
325+
panic!("The first statement of a coroutine is not a self deref");
326+
};
327+
let PlaceRef { local: SELF_LOCAL, projection: &[PlaceElem::Field(FIELD_ZERO, _)] } =
328+
deref_place.as_ref()
329+
else {
330+
panic!("The first statement of a coroutine is not a self deref");
331+
};
332+
FxHashSet::from_iter([place.as_local().unwrap()])
333+
} else {
334+
FxHashSet::from_iter([SELF_LOCAL])
335+
};
336+
337+
// The starting block of all coroutines is a switch for the coroutine variants.
338+
// This is preceded by a read of the discriminant. If we don't find this, then
339+
// we must have optimized away the switch, so bail.
340+
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(discr_local))) =
341+
&bb0.statements[if is_pinned { 1 } else { 0 }].kind
342+
else {
343+
// The following statement is not a discriminant read.
344+
return;
345+
};
346+
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = (*discr_local).as_ref()
347+
else {
348+
// We expect the discriminant to have read `&mut self`,
349+
// so we expect the place to be a deref.
350+
return;
351+
};
352+
if !discr_locals.contains(&deref_local) {
353+
// The place being read isn't `_1` (self) or a `Derefer`-inserted local.
354+
return;
355+
}
356+
let TerminatorKind::SwitchInt { discr: Operand::Move(place), targets } = &bb0.terminator().kind
357+
else {
358+
// When panic=abort, we may end up folding away the other variants of the
359+
// coroutine, and end up with ths `SwitchInt` getting replaced.
360+
return;
361+
};
362+
if place != discr_place {
363+
// Make sure we don't try to match on some other `SwitchInt`; we should be
364+
// matching on the discriminant we just read.
365+
return;
366+
}
367+
368+
let mut visited = DenseBitSet::new_empty(body.basic_blocks.len());
369+
let mut worklist = vec![];
370+
let mut visited_variants = DenseBitSet::new_empty(coroutine_layout.variant_fields.len());
371+
372+
// Insert unresumed (initial), returned, panicked variants.
373+
// We treat these as always reachable.
374+
visited_variants.insert(VariantIdx::from_usize(0));
375+
visited_variants.insert(VariantIdx::from_usize(1));
376+
visited_variants.insert(VariantIdx::from_usize(2));
377+
worklist.push(targets.target_for_value(0));
378+
worklist.push(targets.target_for_value(1));
379+
worklist.push(targets.target_for_value(2));
380+
381+
// Walk all of the reachable variant blocks.
382+
while let Some(block) = worklist.pop() {
383+
if !visited.insert(block) {
384+
continue;
385+
}
386+
387+
let data = &body.basic_blocks[block];
388+
for stmt in &data.statements {
389+
match &stmt.kind {
390+
// If we see a `SetDiscriminant` statement for our coroutine,
391+
// mark that variant as reachable and add it to the worklist.
392+
StatementKind::SetDiscriminant { place, variant_index } => {
393+
let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } =
394+
(**place).as_ref()
395+
else {
396+
continue;
397+
};
398+
if !discr_locals.contains(&deref_local) {
399+
continue;
400+
}
401+
visited_variants.insert(*variant_index);
402+
worklist.push(targets.target_for_value(variant_index.as_u32().into()));
403+
}
404+
// The derefer may have inserted a local to access the variant.
405+
// Make sure we keep track of it here.
406+
StatementKind::Assign(box (place, Rvalue::CopyForDeref(deref_place))) => {
407+
if !is_pinned {
408+
continue;
409+
}
410+
let PlaceRef {
411+
local: SELF_LOCAL,
412+
projection: &[PlaceElem::Field(FIELD_ZERO, _)],
413+
} = deref_place.as_ref()
414+
else {
415+
continue;
416+
};
417+
discr_locals.insert(place.as_local().unwrap());
418+
}
419+
_ => {}
420+
}
421+
}
422+
423+
// Also walk all the successors of this block.
424+
if let Some(term) = &data.terminator {
425+
worklist.extend(term.successors());
426+
}
427+
}
428+
429+
let TerminatorKind::SwitchInt { targets, .. } =
430+
&mut body.basic_blocks.as_mut()[START_BLOCK].terminator_mut().kind
431+
else {
432+
unreachable!();
433+
};
434+
// Filter out the variants that are unreachable.
435+
*targets = SwitchTargets::new(
436+
targets
437+
.iter()
438+
.filter(|(idx, _)| visited_variants.contains(VariantIdx::from_u32(*idx as u32))),
439+
targets.otherwise(),
440+
);
441+
}
442+
295443
pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
296444
let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
297445
// CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
- // MIR for `outer::{closure#0}` before SimplifyCfg-final
2+
+ // MIR for `outer::{closure#0}` after SimplifyCfg-final
3+
/* coroutine_layout = CoroutineLayout {
4+
field_tys: {
5+
_0: CoroutineSavedTy {
6+
ty: Coroutine(
7+
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}),
8+
[
9+
(),
10+
std::future::ResumeTy,
11+
(),
12+
(),
13+
CoroutineWitness(
14+
DefId(0:5 ~ coroutine_dead_variants[61c4]::inner::{closure#0}),
15+
[],
16+
),
17+
(),
18+
],
19+
),
20+
source_info: SourceInfo {
21+
span: $DIR/coroutine_dead_variants.rs:13:9: 13:22 (#16),
22+
scope: scope[0],
23+
},
24+
ignore_for_traits: false,
25+
},
26+
},
27+
variant_fields: {
28+
Unresumed(0): [],
29+
Returned (1): [],
30+
Panicked (2): [],
31+
Suspend0 (3): [_0],
32+
},
33+
storage_conflicts: BitMatrix(1x1) {
34+
(_0, _0),
35+
},
36+
} */
37+
38+
fn outer::{closure#0}(_1: Pin<&mut {async fn body of outer()}>, _2: &mut Context<'_>) -> Poll<()> {
39+
debug _task_context => _2;
40+
let mut _0: std::task::Poll<()>;
41+
let mut _3: {async fn body of inner()};
42+
let mut _4: {async fn body of inner()};
43+
let mut _5: std::task::Poll<()>;
44+
let mut _6: std::pin::Pin<&mut {async fn body of inner()}>;
45+
let mut _7: &mut {async fn body of inner()};
46+
let mut _8: &mut std::task::Context<'_>;
47+
let mut _9: isize;
48+
let mut _11: ();
49+
let mut _12: &mut std::task::Context<'_>;
50+
let mut _13: u32;
51+
let mut _14: &mut {async fn body of outer()};
52+
scope 1 {
53+
debug __awaitee => (((*(_1.0: &mut {async fn body of outer()})) as variant#3).0: {async fn body of inner()});
54+
let _10: ();
55+
scope 2 {
56+
debug result => const ();
57+
}
58+
}
59+
60+
bb0: {
61+
_14 = copy (_1.0: &mut {async fn body of outer()});
62+
_13 = discriminant((*_14));
63+
- switchInt(move _13) -> [0: bb1, 1: bb15, 3: bb14, otherwise: bb8];
64+
+ switchInt(move _13) -> [0: bb2, 1: bb4, otherwise: bb1];
65+
}
66+
67+
bb1: {
68+
- nop;
69+
- goto -> bb12;
70+
- }
71+
-
72+
- bb2: {
73+
- StorageLive(_3);
74+
- StorageLive(_4);
75+
- _4 = inner() -> [return: bb3, unwind unreachable];
76+
- }
77+
-
78+
- bb3: {
79+
- _3 = <{async fn body of inner()} as IntoFuture>::into_future(move _4) -> [return: bb4, unwind unreachable];
80+
- }
81+
-
82+
- bb4: {
83+
- StorageDead(_4);
84+
- (((*_14) as variant#3).0: {async fn body of inner()}) = move _3;
85+
- goto -> bb5;
86+
- }
87+
-
88+
- bb5: {
89+
- StorageLive(_5);
90+
- StorageLive(_6);
91+
- _7 = &mut (((*_14) as variant#3).0: {async fn body of inner()});
92+
- _6 = Pin::<&mut {async fn body of inner()}>::new_unchecked(copy _7) -> [return: bb6, unwind unreachable];
93+
- }
94+
-
95+
- bb6: {
96+
- nop;
97+
- _5 = <{async fn body of inner()} as Future>::poll(move _6, copy _2) -> [return: bb7, unwind unreachable];
98+
- }
99+
-
100+
- bb7: {
101+
- StorageDead(_6);
102+
- _9 = discriminant(_5);
103+
- switchInt(move _9) -> [0: bb10, 1: bb9, otherwise: bb8];
104+
- }
105+
-
106+
- bb8: {
107+
unreachable;
108+
}
109+
110+
- bb9: {
111+
- StorageDead(_5);
112+
- _0 = const Poll::<()>::Pending;
113+
- StorageDead(_3);
114+
- discriminant((*_14)) = 3;
115+
- return;
116+
- }
117+
-
118+
- bb10: {
119+
- StorageLive(_10);
120+
- nop;
121+
- StorageDead(_10);
122+
- StorageDead(_5);
123+
- drop((((*_14) as variant#3).0: {async fn body of inner()})) -> [return: bb11, unwind unreachable];
124+
- }
125+
-
126+
- bb11: {
127+
- StorageDead(_3);
128+
+ bb2: {
129+
_11 = const ();
130+
- goto -> bb13;
131+
+ goto -> bb3;
132+
}
133+
134+
- bb12: {
135+
- _11 = const ();
136+
- goto -> bb13;
137+
- }
138+
-
139+
- bb13: {
140+
+ bb3: {
141+
_0 = Poll::<()>::Ready(const ());
142+
discriminant((*_14)) = 1;
143+
return;
144+
}
145+
146+
- bb14: {
147+
- StorageLive(_3);
148+
- nop;
149+
- goto -> bb5;
150+
- }
151+
-
152+
- bb15: {
153+
- assert(const false, "`async fn` resumed after completion") -> [success: bb15, unwind unreachable];
154+
+ bb4: {
155+
+ assert(const false, "`async fn` resumed after completion") -> [success: bb4, unwind unreachable];
156+
}
157+
}
158+

0 commit comments

Comments
 (0)