|
27 | 27 | //! naively generate still contains the `_a = ()` write in the unreachable block "after" the
|
28 | 28 | //! return.
|
29 | 29 |
|
| 30 | +use rustc_abi::{FieldIdx, VariantIdx}; |
| 31 | +use rustc_data_structures::fx::FxHashSet; |
| 32 | +use rustc_hir::{CoroutineDesugaring, CoroutineKind}; |
| 33 | +use rustc_index::bit_set::DenseBitSet; |
30 | 34 | use rustc_index::{Idx, IndexSlice, IndexVec};
|
31 | 35 | use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
|
32 | 36 | use rustc_middle::mir::*;
|
@@ -68,6 +72,7 @@ impl SimplifyCfg {
|
68 | 72 |
|
69 | 73 | pub(super) fn simplify_cfg(body: &mut Body<'_>) {
|
70 | 74 | CfgSimplifier::new(body).simplify();
|
| 75 | + remove_dead_coroutine_switch_variants(body); |
71 | 76 | remove_dead_blocks(body);
|
72 | 77 |
|
73 | 78 | // FIXME: Should probably be moved into some kind of pass manager
|
@@ -292,6 +297,149 @@ pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>)
|
292 | 297 | }
|
293 | 298 | }
|
294 | 299 |
|
| 300 | +const SELF_LOCAL: Local = Local::from_u32(1); |
| 301 | +const FIELD_ZERO: FieldIdx = FieldIdx::from_u32(0); |
| 302 | + |
| 303 | +pub(super) fn remove_dead_coroutine_switch_variants(body: &mut Body<'_>) { |
| 304 | + let Some(coroutine_layout) = body.coroutine_layout_raw() else { |
| 305 | + // Not a coroutine; no coroutine variants to remove. |
| 306 | + return; |
| 307 | + }; |
| 308 | + |
| 309 | + let bb0 = &body.basic_blocks[START_BLOCK]; |
| 310 | + |
| 311 | + let is_pinned = match body.coroutine_kind().unwrap() { |
| 312 | + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => false, |
| 313 | + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) |
| 314 | + | CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) |
| 315 | + | CoroutineKind::Coroutine(_) => true, |
| 316 | + }; |
| 317 | + // FIXME: Explain how we read discriminants. |
| 318 | + let mut discr_locals = if is_pinned { |
| 319 | + // FIXME: Mention GVN quirk. |
| 320 | + let StatementKind::Assign(box ( |
| 321 | + place, |
| 322 | + Rvalue::Use(Operand::Copy(deref_place)) | Rvalue::CopyForDeref(deref_place), |
| 323 | + )) = &bb0.statements[0].kind |
| 324 | + else { |
| 325 | + panic!("The first statement of a coroutine is not a self deref"); |
| 326 | + }; |
| 327 | + let PlaceRef { local: SELF_LOCAL, projection: &[PlaceElem::Field(FIELD_ZERO, _)] } = |
| 328 | + deref_place.as_ref() |
| 329 | + else { |
| 330 | + panic!("The first statement of a coroutine is not a self deref"); |
| 331 | + }; |
| 332 | + FxHashSet::from_iter([place.as_local().unwrap()]) |
| 333 | + } else { |
| 334 | + FxHashSet::from_iter([SELF_LOCAL]) |
| 335 | + }; |
| 336 | + |
| 337 | + // The starting block of all coroutines is a switch for the coroutine variants. |
| 338 | + // This is preceded by a read of the discriminant. If we don't find this, then |
| 339 | + // we must have optimized away the switch, so bail. |
| 340 | + let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(discr_local))) = |
| 341 | + &bb0.statements[if is_pinned { 1 } else { 0 }].kind |
| 342 | + else { |
| 343 | + // The following statement is not a discriminant read. |
| 344 | + return; |
| 345 | + }; |
| 346 | + let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = (*discr_local).as_ref() |
| 347 | + else { |
| 348 | + // We expect the discriminant to have read `&mut self`, |
| 349 | + // so we expect the place to be a deref. |
| 350 | + return; |
| 351 | + }; |
| 352 | + if !discr_locals.contains(&deref_local) { |
| 353 | + // The place being read isn't `_1` (self) or a `Derefer`-inserted local. |
| 354 | + return; |
| 355 | + } |
| 356 | + let TerminatorKind::SwitchInt { discr: Operand::Move(place), targets } = &bb0.terminator().kind |
| 357 | + else { |
| 358 | + // When panic=abort, we may end up folding away the other variants of the |
| 359 | + // coroutine, and end up with ths `SwitchInt` getting replaced. |
| 360 | + return; |
| 361 | + }; |
| 362 | + if place != discr_place { |
| 363 | + // Make sure we don't try to match on some other `SwitchInt`; we should be |
| 364 | + // matching on the discriminant we just read. |
| 365 | + return; |
| 366 | + } |
| 367 | + |
| 368 | + let mut visited = DenseBitSet::new_empty(body.basic_blocks.len()); |
| 369 | + let mut worklist = vec![]; |
| 370 | + let mut visited_variants = DenseBitSet::new_empty(coroutine_layout.variant_fields.len()); |
| 371 | + |
| 372 | + // Insert unresumed (initial), returned, panicked variants. |
| 373 | + // We treat these as always reachable. |
| 374 | + visited_variants.insert(VariantIdx::from_usize(0)); |
| 375 | + visited_variants.insert(VariantIdx::from_usize(1)); |
| 376 | + visited_variants.insert(VariantIdx::from_usize(2)); |
| 377 | + worklist.push(targets.target_for_value(0)); |
| 378 | + worklist.push(targets.target_for_value(1)); |
| 379 | + worklist.push(targets.target_for_value(2)); |
| 380 | + |
| 381 | + // Walk all of the reachable variant blocks. |
| 382 | + while let Some(block) = worklist.pop() { |
| 383 | + if !visited.insert(block) { |
| 384 | + continue; |
| 385 | + } |
| 386 | + |
| 387 | + let data = &body.basic_blocks[block]; |
| 388 | + for stmt in &data.statements { |
| 389 | + match &stmt.kind { |
| 390 | + // If we see a `SetDiscriminant` statement for our coroutine, |
| 391 | + // mark that variant as reachable and add it to the worklist. |
| 392 | + StatementKind::SetDiscriminant { place, variant_index } => { |
| 393 | + let PlaceRef { local: deref_local, projection: &[PlaceElem::Deref] } = |
| 394 | + (**place).as_ref() |
| 395 | + else { |
| 396 | + continue; |
| 397 | + }; |
| 398 | + if !discr_locals.contains(&deref_local) { |
| 399 | + continue; |
| 400 | + } |
| 401 | + visited_variants.insert(*variant_index); |
| 402 | + worklist.push(targets.target_for_value(variant_index.as_u32().into())); |
| 403 | + } |
| 404 | + // The derefer may have inserted a local to access the variant. |
| 405 | + // Make sure we keep track of it here. |
| 406 | + StatementKind::Assign(box (place, Rvalue::CopyForDeref(deref_place))) => { |
| 407 | + if !is_pinned { |
| 408 | + continue; |
| 409 | + } |
| 410 | + let PlaceRef { |
| 411 | + local: SELF_LOCAL, |
| 412 | + projection: &[PlaceElem::Field(FIELD_ZERO, _)], |
| 413 | + } = deref_place.as_ref() |
| 414 | + else { |
| 415 | + continue; |
| 416 | + }; |
| 417 | + discr_locals.insert(place.as_local().unwrap()); |
| 418 | + } |
| 419 | + _ => {} |
| 420 | + } |
| 421 | + } |
| 422 | + |
| 423 | + // Also walk all the successors of this block. |
| 424 | + if let Some(term) = &data.terminator { |
| 425 | + worklist.extend(term.successors()); |
| 426 | + } |
| 427 | + } |
| 428 | + |
| 429 | + let TerminatorKind::SwitchInt { targets, .. } = |
| 430 | + &mut body.basic_blocks.as_mut()[START_BLOCK].terminator_mut().kind |
| 431 | + else { |
| 432 | + unreachable!(); |
| 433 | + }; |
| 434 | + // Filter out the variants that are unreachable. |
| 435 | + *targets = SwitchTargets::new( |
| 436 | + targets |
| 437 | + .iter() |
| 438 | + .filter(|(idx, _)| visited_variants.contains(VariantIdx::from_u32(*idx as u32))), |
| 439 | + targets.otherwise(), |
| 440 | + ); |
| 441 | +} |
| 442 | + |
295 | 443 | pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
|
296 | 444 | let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
|
297 | 445 | // CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
|
|
0 commit comments