|
| 1 | +use crate::transform::{simplify, MirPass, MirSource}; |
| 2 | +use rustc_middle::mir::*; |
| 3 | +use rustc_middle::ty::TyCtxt; |
| 4 | + |
| 5 | +pub struct MatchBranchSimplification; |
| 6 | + |
| 7 | +// What's the intent of this pass? |
| 8 | +// If one block is found that switches between blocks which both go to the same place |
| 9 | +// AND both of these blocks set a similar const in their -> |
| 10 | +// condense into 1 block based on discriminant AND goto the destination afterwards |
| 11 | + |
| 12 | +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { |
| 13 | + fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) { |
| 14 | + let param_env = tcx.param_env(src.def_id()); |
| 15 | + let mut did_remove_blocks = false; |
| 16 | + let bbs = body.basic_blocks_mut(); |
| 17 | + 'outer: for bb_idx in bbs.indices() { |
| 18 | + let (discr, val, switch_ty, targets) = match bbs[bb_idx].terminator().kind { |
| 19 | + TerminatorKind::SwitchInt { |
| 20 | + discr: Operand::Move(ref place), |
| 21 | + switch_ty, |
| 22 | + ref targets, |
| 23 | + ref values, |
| 24 | + .. |
| 25 | + } if targets.len() == 2 && values.len() == 1 => { |
| 26 | + (place.clone(), values[0], switch_ty, targets) |
| 27 | + } |
| 28 | + _ => continue, |
| 29 | + }; |
| 30 | + let (first, rest) = if let ([first], rest) = targets.split_at(1) { |
| 31 | + (*first, rest) |
| 32 | + } else { |
| 33 | + unreachable!(); |
| 34 | + }; |
| 35 | + let first_dest = bbs[first].terminator().kind.clone(); |
| 36 | + let same_destinations = rest |
| 37 | + .iter() |
| 38 | + .map(|target| &bbs[*target].terminator().kind) |
| 39 | + .all(|t_kind| t_kind == &first_dest); |
| 40 | + if !same_destinations { |
| 41 | + continue; |
| 42 | + } |
| 43 | + let first_stmts = &bbs[first].statements; |
| 44 | + for s in first_stmts.iter() { |
| 45 | + match &s.kind { |
| 46 | + StatementKind::Assign(box (_, rhs)) => { |
| 47 | + if let Rvalue::Use(Operand::Constant(_)) = rhs { |
| 48 | + } else { |
| 49 | + continue 'outer; |
| 50 | + } |
| 51 | + } |
| 52 | + _ => continue 'outer, |
| 53 | + } |
| 54 | + } |
| 55 | + for target in rest.iter() { |
| 56 | + for s in bbs[*target].statements.iter() { |
| 57 | + if let StatementKind::Assign(box (ref lhs, rhs)) = &s.kind { |
| 58 | + if let Rvalue::Use(Operand::Constant(_)) = rhs { |
| 59 | + let has_matching_assn = first_stmts |
| 60 | + .iter() |
| 61 | + .find(|s| { |
| 62 | + if let StatementKind::Assign(box (lhs_f, _)) = &s.kind { |
| 63 | + lhs_f == lhs |
| 64 | + } else { |
| 65 | + false |
| 66 | + } |
| 67 | + }) |
| 68 | + .is_some(); |
| 69 | + if has_matching_assn { |
| 70 | + continue; |
| 71 | + } |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + continue 'outer; |
| 76 | + } |
| 77 | + } |
| 78 | + let (first_block, to_add) = bbs.pick2_mut(first, bb_idx); |
| 79 | + let new_stmts = first_block.statements.iter().cloned().map(|mut s| { |
| 80 | + if let StatementKind::Assign(box (_, ref mut rhs)) = s.kind { |
| 81 | + let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size; |
| 82 | + let const_cmp = Operand::const_from_scalar( |
| 83 | + tcx, |
| 84 | + switch_ty, |
| 85 | + crate::interpret::Scalar::from_uint(val, size), |
| 86 | + rustc_span::DUMMY_SP, |
| 87 | + ); |
| 88 | + *rhs = Rvalue::BinaryOp(BinOp::Eq, Operand::Move(discr), const_cmp); |
| 89 | + } else { |
| 90 | + unreachable!() |
| 91 | + } |
| 92 | + s |
| 93 | + }); |
| 94 | + to_add.statements.extend(new_stmts); |
| 95 | + to_add.terminator_mut().kind = first_dest; |
| 96 | + did_remove_blocks = true; |
| 97 | + } |
| 98 | + if did_remove_blocks { |
| 99 | + simplify::remove_dead_blocks(body); |
| 100 | + } |
| 101 | + } |
| 102 | +} |
0 commit comments