Skip to content

Commit 70ca429

Browse files
committed
Transforms match into an assignment statement
1 parent 6dba010 commit 70ca429

9 files changed

+367
-116
lines changed

compiler/rustc_middle/src/mir/terminator.rs

+6
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ impl SwitchTargets {
7474
pub fn target_for_value(&self, value: u128) -> BasicBlock {
7575
self.iter().find_map(|(v, t)| (v == value).then_some(t)).unwrap_or_else(|| self.otherwise())
7676
}
77+
78+
/// Returns true if all targets (including the fallback target) are distinct.
79+
#[inline]
80+
pub fn is_distinct(&self) -> bool {
81+
self.targets.iter().collect::<FxHashSet<_>>().len() == self.targets.len()
82+
}
7783
}
7884

7985
pub struct SwitchTargetsIter<'a> {

compiler/rustc_mir_transform/src/match_branches.rs

+220-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use rustc_index::IndexVec;
22
use rustc_middle::mir::*;
3-
use rustc_middle::ty::{ParamEnv, Ty, TyCtxt};
3+
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
44
use std::iter;
55

66
use super::simplify::simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
3838
should_cleanup = true;
3939
continue;
4040
}
41+
if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env)
42+
{
43+
should_cleanup = true;
44+
continue;
45+
}
4146
}
4247

4348
if should_cleanup {
@@ -47,8 +52,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
4752
}
4853

4954
trait SimplifyMatch<'tcx> {
55+
/// Simplifies a match statement, returning true if the simplification succeeds, false otherwise.
56+
/// Generic code is written here, and we generally don't need a custom implementation.
5057
fn simplify(
51-
&self,
58+
&mut self,
5259
tcx: TyCtxt<'tcx>,
5360
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
5461
bbs: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
@@ -72,9 +79,9 @@ trait SimplifyMatch<'tcx> {
7279
let source_info = bbs[switch_bb_idx].terminator().source_info;
7380
let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span));
7481

75-
// We already checked that first and second are different blocks,
82+
// We already checked that targets are different blocks,
7683
// and bb_idx has a different terminator from both of them.
77-
let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty);
84+
let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local, discr_ty);
7885
let (_, first) = targets.iter().next().unwrap();
7986
let (from, first) = bbs.pick2_mut(switch_bb_idx, first);
8087
from.statements
@@ -91,7 +98,7 @@ trait SimplifyMatch<'tcx> {
9198
}
9299

93100
fn can_simplify(
94-
&self,
101+
&mut self,
95102
tcx: TyCtxt<'tcx>,
96103
targets: &SwitchTargets,
97104
param_env: ParamEnv<'tcx>,
@@ -144,7 +151,7 @@ struct SimplifyToIf;
144151
/// ```
145152
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
146153
fn can_simplify(
147-
&self,
154+
&mut self,
148155
tcx: TyCtxt<'tcx>,
149156
targets: &SwitchTargets,
150157
param_env: ParamEnv<'tcx>,
@@ -250,3 +257,210 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250257
new_stmts.collect()
251258
}
252259
}
260+
261+
#[derive(Default)]
262+
struct SimplifyToExp {
263+
transfrom_types: Vec<TransfromType>,
264+
}
265+
266+
#[derive(Clone, Copy)]
267+
enum CompareType<'tcx, 'a> {
268+
Same(&'a StatementKind<'tcx>),
269+
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
270+
Discr(&'a Place<'tcx>, Ty<'tcx>),
271+
}
272+
273+
enum TransfromType {
274+
Same,
275+
Eq,
276+
Discr,
277+
}
278+
279+
impl From<CompareType<'_, '_>> for TransfromType {
280+
fn from(compare_type: CompareType<'_, '_>) -> Self {
281+
match compare_type {
282+
CompareType::Same(_) => TransfromType::Same,
283+
CompareType::Eq(_, _, _) => TransfromType::Eq,
284+
CompareType::Discr(_, _) => TransfromType::Discr,
285+
}
286+
}
287+
}
288+
289+
/// If we find that the value of match is the same as the assignment,
290+
/// merge a target block statements into the source block,
291+
/// using cast to transform different integer types.
292+
///
293+
/// For example:
294+
///
295+
/// ```ignore (MIR)
296+
/// bb0: {
297+
/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
298+
/// }
299+
///
300+
/// bb1: {
301+
/// unreachable;
302+
/// }
303+
///
304+
/// bb2: {
305+
/// _0 = const 1_i16;
306+
/// goto -> bb5;
307+
/// }
308+
///
309+
/// bb3: {
310+
/// _0 = const 2_i16;
311+
/// goto -> bb5;
312+
/// }
313+
///
314+
/// bb4: {
315+
/// _0 = const 3_i16;
316+
/// goto -> bb5;
317+
/// }
318+
/// ```
319+
///
320+
/// into:
321+
///
322+
/// ```ignore (MIR)
323+
/// bb0: {
324+
/// _0 = _3 as i16 (IntToInt);
325+
/// goto -> bb5;
326+
/// }
327+
/// ```
328+
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
329+
fn can_simplify(
330+
&mut self,
331+
tcx: TyCtxt<'tcx>,
332+
targets: &SwitchTargets,
333+
param_env: ParamEnv<'tcx>,
334+
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
335+
) -> bool {
336+
if targets.iter().len() < 2 || targets.iter().len() > 64 {
337+
return false;
338+
}
339+
// We require that the possible target blocks all be distinct.
340+
if !targets.is_distinct() {
341+
return false;
342+
}
343+
if !bbs[targets.otherwise()].is_empty_unreachable() {
344+
return false;
345+
}
346+
let mut iter = targets.iter();
347+
let (first_val, first_target) = iter.next().unwrap();
348+
let first_terminator_kind = &bbs[first_target].terminator().kind;
349+
// Check that destinations are identical, and if not, then don't optimize this block
350+
if !targets
351+
.iter()
352+
.all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
353+
{
354+
return false;
355+
}
356+
357+
let first_stmts = &bbs[first_target].statements;
358+
let (second_val, second_target) = iter.next().unwrap();
359+
let second_stmts = &bbs[second_target].statements;
360+
if first_stmts.len() != second_stmts.len() {
361+
return false;
362+
}
363+
364+
let mut compare_types = Vec::new();
365+
for (f, s) in iter::zip(first_stmts, second_stmts) {
366+
let compare_type = match (&f.kind, &s.kind) {
367+
// If two statements are exactly the same, we can optimize.
368+
(f_s, s_s) if f_s == s_s => CompareType::Same(f_s),
369+
370+
// If two statements are assignments with the match values to the same place, we can optimize.
371+
(
372+
StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
373+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
374+
) if lhs_f == lhs_s
375+
&& f_c.const_.ty() == s_c.const_.ty()
376+
&& f_c.const_.ty().is_integral() =>
377+
{
378+
match (
379+
f_c.const_.try_eval_scalar_int(tcx, param_env),
380+
s_c.const_.try_eval_scalar_int(tcx, param_env),
381+
) {
382+
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
383+
(Some(f), Some(s))
384+
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
385+
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
386+
{
387+
CompareType::Discr(lhs_f, f_c.const_.ty())
388+
}
389+
_ => return false,
390+
}
391+
}
392+
393+
// Otherwise we cannot optimize. Try another block.
394+
_ => return false,
395+
};
396+
compare_types.push(compare_type);
397+
}
398+
399+
for (other_val, other_target) in iter {
400+
let other_stmts = &bbs[other_target].statements;
401+
if compare_types.len() != other_stmts.len() {
402+
return false;
403+
}
404+
for (f, s) in iter::zip(&compare_types, other_stmts) {
405+
match (*f, &s.kind) {
406+
(CompareType::Same(f_s), s_s) if f_s == s_s => {}
407+
(
408+
CompareType::Eq(lhs_f, f_ty, val),
409+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
410+
) if lhs_f == lhs_s
411+
&& s_c.const_.ty() == f_ty
412+
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
413+
(
414+
CompareType::Discr(lhs_f, f_ty),
415+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
416+
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
417+
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
418+
return false;
419+
};
420+
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
421+
return false;
422+
}
423+
}
424+
_ => return false,
425+
}
426+
}
427+
}
428+
self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect();
429+
true
430+
}
431+
432+
fn new_stmts(
433+
&self,
434+
_tcx: TyCtxt<'tcx>,
435+
targets: &SwitchTargets,
436+
_param_env: ParamEnv<'tcx>,
437+
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
438+
discr_local: Local,
439+
discr_ty: Ty<'tcx>,
440+
) -> Vec<Statement<'tcx>> {
441+
let (_, first) = targets.iter().next().unwrap();
442+
let first = &bbs[first];
443+
444+
let new_stmts =
445+
iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) {
446+
(TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(),
447+
(
448+
TransfromType::Discr,
449+
StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
450+
) => {
451+
let operand = Operand::Copy(Place::from(discr_local));
452+
let r_val = if f_c.const_.ty() == discr_ty {
453+
Rvalue::Use(operand)
454+
} else {
455+
Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
456+
};
457+
Statement {
458+
source_info: s.source_info,
459+
kind: StatementKind::Assign(Box::new((*lhs, r_val))),
460+
}
461+
}
462+
_ => unreachable!(),
463+
});
464+
new_stmts.collect()
465+
}
466+
}

tests/codegen/match-optimized.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ pub fn exhaustive_match(e: E) -> u8 {
2626
// CHECK-NEXT: store i8 1, ptr %_0, align 1
2727
// CHECK-NEXT: br label %[[EXIT]]
2828
// CHECK: [[C]]:
29-
// CHECK-NEXT: store i8 2, ptr %_0, align 1
29+
// CHECK-NEXT: store i8 3, ptr %_0, align 1
3030
// CHECK-NEXT: br label %[[EXIT]]
3131
match e {
3232
E::A => 0,
3333
E::B => 1,
34-
E::C => 2,
34+
E::C => 3,
3535
}
3636
}
3737

tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff

+33-28
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,42 @@
55
debug i => _1;
66
let mut _0: u128;
77
let mut _2: i128;
8+
+ let mut _3: i128;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1];
12-
}
13-
14-
bb1: {
15-
unreachable;
16-
}
17-
18-
bb2: {
19-
_0 = const _;
20-
goto -> bb6;
21-
}
22-
23-
bb3: {
24-
_0 = const 1_u128;
25-
goto -> bb6;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_u128;
30-
goto -> bb6;
31-
}
32-
33-
bb5: {
34-
_0 = const 3_u128;
35-
goto -> bb6;
36-
}
37-
38-
bb6: {
12+
- switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb2, otherwise: bb1];
13+
- }
14+
-
15+
- bb1: {
16+
- unreachable;
17+
- }
18+
-
19+
- bb2: {
20+
- _0 = const _;
21+
- goto -> bb6;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const 1_u128;
26+
- goto -> bb6;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_u128;
31+
- goto -> bb6;
32+
- }
33+
-
34+
- bb5: {
35+
- _0 = const 3_u128;
36+
- goto -> bb6;
37+
- }
38+
-
39+
- bb6: {
40+
+ StorageLive(_3);
41+
+ _3 = move _2;
42+
+ _0 = _3 as u128 (IntToInt);
43+
+ StorageDead(_3);
3944
return;
4045
}
4146
}

0 commit comments

Comments
 (0)