|
| 1 | +use rustc_index::IndexVec; |
1 | 2 | use rustc_middle::mir::*;
|
2 |
| -use rustc_middle::ty::TyCtxt; |
| 3 | +use rustc_middle::ty::{ParamEnv, Ty, TyCtxt}; |
3 | 4 | use std::iter;
|
4 | 5 |
|
5 | 6 | use super::simplify::simplify_cfg;
|
6 | 7 |
|
7 | 8 | pub struct MatchBranchSimplification;
|
8 | 9 |
|
| 10 | +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { |
| 11 | + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { |
| 12 | + sess.mir_opt_level() >= 1 |
| 13 | + } |
| 14 | + |
| 15 | + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
| 16 | + let def_id = body.source.def_id(); |
| 17 | + let param_env = tcx.param_env_reveal_all_normalized(def_id); |
| 18 | + |
| 19 | + let bbs = body.basic_blocks.as_mut(); |
| 20 | + let mut should_cleanup = false; |
| 21 | + for bb_idx in bbs.indices() { |
| 22 | + if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { |
| 23 | + continue; |
| 24 | + } |
| 25 | + |
| 26 | + match bbs[bb_idx].terminator().kind { |
| 27 | + TerminatorKind::SwitchInt { |
| 28 | + discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)), |
| 29 | + ref targets, |
| 30 | + .. |
| 31 | + // We require that the possible target blocks don't contain this block. |
| 32 | + } if !targets.all_targets().contains(&bb_idx) => {} |
| 33 | + // Only optimize switch int statements |
| 34 | + _ => continue, |
| 35 | + }; |
| 36 | + |
| 37 | + if SimplifyToIf.simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) { |
| 38 | + should_cleanup = true; |
| 39 | + continue; |
| 40 | + } |
| 41 | + } |
| 42 | + |
| 43 | + if should_cleanup { |
| 44 | + simplify_cfg(body); |
| 45 | + } |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +trait SimplifyMatch<'tcx> { |
| 50 | + fn simplify( |
| 51 | + &self, |
| 52 | + tcx: TyCtxt<'tcx>, |
| 53 | + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, |
| 54 | + bbs: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 55 | + switch_bb_idx: BasicBlock, |
| 56 | + param_env: ParamEnv<'tcx>, |
| 57 | + ) -> bool { |
| 58 | + let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { |
| 59 | + TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), |
| 60 | + _ => unreachable!(), |
| 61 | + }; |
| 62 | + |
| 63 | + if !self.can_simplify(tcx, targets, param_env, bbs) { |
| 64 | + return false; |
| 65 | + } |
| 66 | + |
| 67 | + // Take ownership of items now that we know we can optimize. |
| 68 | + let discr = discr.clone(); |
| 69 | + let discr_ty = discr.ty(local_decls, tcx); |
| 70 | + |
| 71 | + // Introduce a temporary for the discriminant value. |
| 72 | + let source_info = bbs[switch_bb_idx].terminator().source_info; |
| 73 | + let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); |
| 74 | + |
| 75 | + // We already checked that first and second are different blocks, |
| 76 | + // and bb_idx has a different terminator from both of them. |
| 77 | + let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty); |
| 78 | + let (_, first) = targets.iter().next().unwrap(); |
| 79 | + let (from, first) = bbs.pick2_mut(switch_bb_idx, first); |
| 80 | + from.statements |
| 81 | + .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); |
| 82 | + from.statements.push(Statement { |
| 83 | + source_info, |
| 84 | + kind: StatementKind::Assign(Box::new((Place::from(discr_local), Rvalue::Use(discr)))), |
| 85 | + }); |
| 86 | + from.statements.extend(new_stmts); |
| 87 | + from.statements |
| 88 | + .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); |
| 89 | + from.terminator_mut().kind = first.terminator().kind.clone(); |
| 90 | + true |
| 91 | + } |
| 92 | + |
| 93 | + fn can_simplify( |
| 94 | + &self, |
| 95 | + tcx: TyCtxt<'tcx>, |
| 96 | + targets: &SwitchTargets, |
| 97 | + param_env: ParamEnv<'tcx>, |
| 98 | + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 99 | + ) -> bool; |
| 100 | + |
| 101 | + fn new_stmts( |
| 102 | + &self, |
| 103 | + tcx: TyCtxt<'tcx>, |
| 104 | + targets: &SwitchTargets, |
| 105 | + param_env: ParamEnv<'tcx>, |
| 106 | + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 107 | + discr_local: Local, |
| 108 | + discr_ty: Ty<'tcx>, |
| 109 | + ) -> Vec<Statement<'tcx>>; |
| 110 | +} |
| 111 | + |
| 112 | +struct SimplifyToIf; |
| 113 | + |
9 | 114 | /// If a source block is found that switches between two blocks that are exactly
|
10 | 115 | /// the same modulo const bool assignments (e.g., one assigns true another false
|
11 | 116 | /// to the same place), merge a target block statements into the source block,
|
@@ -37,144 +142,111 @@ pub struct MatchBranchSimplification;
|
37 | 142 | /// goto -> bb3;
|
38 | 143 | /// }
|
39 | 144 | /// ```
|
| 145 | +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { |
| 146 | + fn can_simplify( |
| 147 | + &self, |
| 148 | + tcx: TyCtxt<'tcx>, |
| 149 | + targets: &SwitchTargets, |
| 150 | + param_env: ParamEnv<'tcx>, |
| 151 | + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 152 | + ) -> bool { |
| 153 | + if targets.iter().len() != 1 { |
| 154 | + return false; |
| 155 | + } |
| 156 | + // We require that the possible target blocks all be distinct. |
| 157 | + let (_, first) = targets.iter().next().unwrap(); |
| 158 | + let second = targets.otherwise(); |
| 159 | + if first == second { |
| 160 | + return false; |
| 161 | + } |
| 162 | + // Check that destinations are identical, and if not, then don't optimize this block |
| 163 | + if bbs[first].terminator().kind != bbs[second].terminator().kind { |
| 164 | + return false; |
| 165 | + } |
40 | 166 |
|
41 |
| -impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { |
42 |
| - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { |
43 |
| - sess.mir_opt_level() >= 1 |
44 |
| - } |
45 |
| - |
46 |
| - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
47 |
| - let def_id = body.source.def_id(); |
48 |
| - let param_env = tcx.param_env_reveal_all_normalized(def_id); |
49 |
| - |
50 |
| - let bbs = body.basic_blocks.as_mut(); |
51 |
| - let mut should_cleanup = false; |
52 |
| - 'outer: for bb_idx in bbs.indices() { |
53 |
| - if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { |
54 |
| - continue; |
55 |
| - } |
56 |
| - |
57 |
| - let (discr, val, first, second) = match bbs[bb_idx].terminator().kind { |
58 |
| - TerminatorKind::SwitchInt { |
59 |
| - discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)), |
60 |
| - ref targets, |
61 |
| - .. |
62 |
| - } if targets.iter().len() == 1 => { |
63 |
| - let (value, target) = targets.iter().next().unwrap(); |
64 |
| - // We require that this block and the two possible target blocks all be |
65 |
| - // distinct. |
66 |
| - if target == targets.otherwise() |
67 |
| - || bb_idx == target |
68 |
| - || bb_idx == targets.otherwise() |
69 |
| - { |
70 |
| - continue; |
71 |
| - } |
72 |
| - (discr, value, target, targets.otherwise()) |
73 |
| - } |
74 |
| - // Only optimize switch int statements |
75 |
| - _ => continue, |
76 |
| - }; |
77 |
| - |
78 |
| - // Check that destinations are identical, and if not, then don't optimize this block |
79 |
| - if bbs[first].terminator().kind != bbs[second].terminator().kind { |
80 |
| - continue; |
| 167 | + // Check that blocks are assignments of consts to the same place or same statement, |
| 168 | + // and match up 1-1, if not don't optimize this block. |
| 169 | + let first_stmts = &bbs[first].statements; |
| 170 | + let second_stmts = &bbs[second].statements; |
| 171 | + if first_stmts.len() != second_stmts.len() { |
| 172 | + return false; |
| 173 | + } |
| 174 | + for (f, s) in iter::zip(first_stmts, second_stmts) { |
| 175 | + match (&f.kind, &s.kind) { |
| 176 | + // If two statements are exactly the same, we can optimize. |
| 177 | + (f_s, s_s) if f_s == s_s => {} |
| 178 | + |
| 179 | + // If two statements are const bool assignments to the same place, we can optimize. |
| 180 | + ( |
| 181 | + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), |
| 182 | + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), |
| 183 | + ) if lhs_f == lhs_s |
| 184 | + && f_c.const_.ty().is_bool() |
| 185 | + && s_c.const_.ty().is_bool() |
| 186 | + && f_c.const_.try_eval_bool(tcx, param_env).is_some() |
| 187 | + && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} |
| 188 | + |
| 189 | + // Otherwise we cannot optimize. Try another block. |
| 190 | + _ => return false, |
81 | 191 | }
|
| 192 | + } |
| 193 | + true |
| 194 | + } |
82 | 195 |
|
83 |
| - // Check that blocks are assignments of consts to the same place or same statement, |
84 |
| - // and match up 1-1, if not don't optimize this block. |
85 |
| - let first_stmts = &bbs[first].statements; |
86 |
| - let scnd_stmts = &bbs[second].statements; |
87 |
| - if first_stmts.len() != scnd_stmts.len() { |
88 |
| - continue; |
89 |
| - } |
90 |
| - for (f, s) in iter::zip(first_stmts, scnd_stmts) { |
91 |
| - match (&f.kind, &s.kind) { |
92 |
| - // If two statements are exactly the same, we can optimize. |
93 |
| - (f_s, s_s) if f_s == s_s => {} |
94 |
| - |
95 |
| - // If two statements are const bool assignments to the same place, we can optimize. |
96 |
| - ( |
97 |
| - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), |
98 |
| - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), |
99 |
| - ) if lhs_f == lhs_s |
100 |
| - && f_c.const_.ty().is_bool() |
101 |
| - && s_c.const_.ty().is_bool() |
102 |
| - && f_c.const_.try_eval_bool(tcx, param_env).is_some() |
103 |
| - && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} |
104 |
| - |
105 |
| - // Otherwise we cannot optimize. Try another block. |
106 |
| - _ => continue 'outer, |
107 |
| - } |
108 |
| - } |
109 |
| - // Take ownership of items now that we know we can optimize. |
110 |
| - let discr = discr.clone(); |
111 |
| - let discr_ty = discr.ty(&body.local_decls, tcx); |
112 |
| - |
113 |
| - // Introduce a temporary for the discriminant value. |
114 |
| - let source_info = bbs[bb_idx].terminator().source_info; |
115 |
| - let discr_local = body.local_decls.push(LocalDecl::new(discr_ty, source_info.span)); |
116 |
| - |
117 |
| - // We already checked that first and second are different blocks, |
118 |
| - // and bb_idx has a different terminator from both of them. |
119 |
| - let (from, first, second) = bbs.pick3_mut(bb_idx, first, second); |
120 |
| - |
121 |
| - let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { |
122 |
| - match (&f.kind, &s.kind) { |
123 |
| - (f_s, s_s) if f_s == s_s => (*f).clone(), |
124 |
| - |
125 |
| - ( |
126 |
| - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), |
127 |
| - StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), |
128 |
| - ) => { |
129 |
| - // From earlier loop we know that we are dealing with bool constants only: |
130 |
| - let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); |
131 |
| - let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); |
132 |
| - if f_b == s_b { |
133 |
| - // Same value in both blocks. Use statement as is. |
134 |
| - (*f).clone() |
135 |
| - } else { |
136 |
| - // Different value between blocks. Make value conditional on switch condition. |
137 |
| - let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; |
138 |
| - let const_cmp = Operand::const_from_scalar( |
139 |
| - tcx, |
140 |
| - discr_ty, |
141 |
| - rustc_const_eval::interpret::Scalar::from_uint(val, size), |
142 |
| - rustc_span::DUMMY_SP, |
143 |
| - ); |
144 |
| - let op = if f_b { BinOp::Eq } else { BinOp::Ne }; |
145 |
| - let rhs = Rvalue::BinaryOp( |
146 |
| - op, |
147 |
| - Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), |
148 |
| - ); |
149 |
| - Statement { |
150 |
| - source_info: f.source_info, |
151 |
| - kind: StatementKind::Assign(Box::new((*lhs, rhs))), |
152 |
| - } |
| 196 | + fn new_stmts( |
| 197 | + &self, |
| 198 | + tcx: TyCtxt<'tcx>, |
| 199 | + targets: &SwitchTargets, |
| 200 | + param_env: ParamEnv<'tcx>, |
| 201 | + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 202 | + discr_local: Local, |
| 203 | + discr_ty: Ty<'tcx>, |
| 204 | + ) -> Vec<Statement<'tcx>> { |
| 205 | + let (val, first) = targets.iter().next().unwrap(); |
| 206 | + let second = targets.otherwise(); |
| 207 | + // We already checked that first and second are different blocks, |
| 208 | + // and bb_idx has a different terminator from both of them. |
| 209 | + let first = &bbs[first]; |
| 210 | + let second = &bbs[second]; |
| 211 | + |
| 212 | + let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { |
| 213 | + match (&f.kind, &s.kind) { |
| 214 | + (f_s, s_s) if f_s == s_s => (*f).clone(), |
| 215 | + |
| 216 | + ( |
| 217 | + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), |
| 218 | + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), |
| 219 | + ) => { |
| 220 | + // From earlier loop we know that we are dealing with bool constants only: |
| 221 | + let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); |
| 222 | + let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); |
| 223 | + if f_b == s_b { |
| 224 | + // Same value in both blocks. Use statement as is. |
| 225 | + (*f).clone() |
| 226 | + } else { |
| 227 | + // Different value between blocks. Make value conditional on switch condition. |
| 228 | + let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; |
| 229 | + let const_cmp = Operand::const_from_scalar( |
| 230 | + tcx, |
| 231 | + discr_ty, |
| 232 | + rustc_const_eval::interpret::Scalar::from_uint(val, size), |
| 233 | + rustc_span::DUMMY_SP, |
| 234 | + ); |
| 235 | + let op = if f_b { BinOp::Eq } else { BinOp::Ne }; |
| 236 | + let rhs = Rvalue::BinaryOp( |
| 237 | + op, |
| 238 | + Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), |
| 239 | + ); |
| 240 | + Statement { |
| 241 | + source_info: f.source_info, |
| 242 | + kind: StatementKind::Assign(Box::new((*lhs, rhs))), |
153 | 243 | }
|
154 | 244 | }
|
155 |
| - |
156 |
| - _ => unreachable!(), |
157 | 245 | }
|
158 |
| - }); |
159 |
| - |
160 |
| - from.statements |
161 |
| - .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); |
162 |
| - from.statements.push(Statement { |
163 |
| - source_info, |
164 |
| - kind: StatementKind::Assign(Box::new(( |
165 |
| - Place::from(discr_local), |
166 |
| - Rvalue::Use(discr), |
167 |
| - ))), |
168 |
| - }); |
169 |
| - from.statements.extend(new_stmts); |
170 |
| - from.statements |
171 |
| - .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); |
172 |
| - from.terminator_mut().kind = first.terminator().kind.clone(); |
173 |
| - should_cleanup = true; |
174 |
| - } |
175 | 246 |
|
176 |
| - if should_cleanup { |
177 |
| - simplify_cfg(body); |
178 |
| - } |
| 247 | + _ => unreachable!(), |
| 248 | + } |
| 249 | + }); |
| 250 | + new_stmts.collect() |
179 | 251 | }
|
180 | 252 | }
|
0 commit comments