Skip to content

Commit 286a346

Browse files
committed
Auto merge of #75370 - simonvandel:optimize-if-condition-on-int-to-switch, r=oli-obk
New pass to optimize `if` conditions on integrals to switches on the integer. Fixes #75144. Pass to convert `if` conditions on integrals into switches on the integral. For example, it turns something like ``` _3 = Eq(move _4, const 43i32); StorageDead(_4); switchInt(_3) -> [false: bb2, otherwise: bb3]; ``` into: ``` switchInt(_4) -> [43i32: bb3, otherwise: bb2]; ```
2 parents 65d071e + 23dda1b commit 286a346

13 files changed

+662
-0
lines changed

src/librustc_middle/mir/interpret/value.rs

+10
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,11 @@ impl<'tcx, Tag> Scalar<Tag> {
503503
self.to_unsigned_with_bit_width(64).map(|v| u64::try_from(v).unwrap())
504504
}
505505

506+
/// Converts the scalar to produce an `u128`. Fails if the scalar is a pointer.
507+
pub fn to_u128(self) -> InterpResult<'static, u128> {
508+
self.to_unsigned_with_bit_width(128)
509+
}
510+
506511
pub fn to_machine_usize(self, cx: &impl HasDataLayout) -> InterpResult<'static, u64> {
507512
let b = self.to_bits(cx.data_layout().pointer_size)?;
508513
Ok(u64::try_from(b).unwrap())
@@ -535,6 +540,11 @@ impl<'tcx, Tag> Scalar<Tag> {
535540
self.to_signed_with_bit_width(64).map(|v| i64::try_from(v).unwrap())
536541
}
537542

543+
/// Converts the scalar to produce an `i128`. Fails if the scalar is a pointer.
544+
pub fn to_i128(self) -> InterpResult<'static, i128> {
545+
self.to_signed_with_bit_width(128)
546+
}
547+
538548
pub fn to_machine_isize(self, cx: &impl HasDataLayout) -> InterpResult<'static, i64> {
539549
let sz = cx.data_layout().pointer_size;
540550
let b = self.to_bits(sz)?;

src/librustc_middle/mir/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,15 @@ pub enum StatementKind<'tcx> {
14301430
Nop,
14311431
}
14321432

1433+
impl<'tcx> StatementKind<'tcx> {
1434+
pub fn as_assign_mut(&mut self) -> Option<&mut Box<(Place<'tcx>, Rvalue<'tcx>)>> {
1435+
match self {
1436+
StatementKind::Assign(x) => Some(x),
1437+
_ => None,
1438+
}
1439+
}
1440+
}
1441+
14331442
/// Describes what kind of retag is to be performed.
14341443
#[derive(Copy, Clone, TyEncodable, TyDecodable, Debug, PartialEq, Eq, HashStable)]
14351444
pub enum RetagKind {
@@ -1843,6 +1852,10 @@ impl<'tcx> Operand<'tcx> {
18431852
})
18441853
}
18451854

1855+
pub fn is_move(&self) -> bool {
1856+
matches!(self, Operand::Move(..))
1857+
}
1858+
18461859
/// Convenience helper to make a literal-like constant from a given scalar value.
18471860
/// Since this is used to synthesize MIR, assumes `user_ty` is None.
18481861
pub fn const_from_scalar(

src/librustc_mir/transform/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ pub mod required_consts;
3939
pub mod rustc_peek;
4040
pub mod simplify;
4141
pub mod simplify_branches;
42+
pub mod simplify_comparison_integral;
4243
pub mod simplify_try;
4344
pub mod uninhabited_enum_branching;
4445
pub mod unreachable_prop;
@@ -456,6 +457,7 @@ fn run_optimization_passes<'tcx>(
456457
&match_branches::MatchBranchSimplification,
457458
&const_prop::ConstProp,
458459
&simplify_branches::SimplifyBranches::new("after-const-prop"),
460+
&simplify_comparison_integral::SimplifyComparisonIntegral,
459461
&simplify_try::SimplifyArmIdentity,
460462
&simplify_try::SimplifyBranchSame,
461463
&copy_prop::CopyPropagation,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
use super::{MirPass, MirSource};
2+
use rustc_middle::{
3+
mir::{
4+
interpret::Scalar, BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement,
5+
StatementKind, TerminatorKind,
6+
},
7+
ty::{Ty, TyCtxt},
8+
};
9+
10+
/// Pass to convert `if` conditions on integrals into switches on the integral.
11+
/// For an example, it turns something like
12+
///
13+
/// ```
14+
/// _3 = Eq(move _4, const 43i32);
15+
/// StorageDead(_4);
16+
/// switchInt(_3) -> [false: bb2, otherwise: bb3];
17+
/// ```
18+
///
19+
/// into:
20+
///
21+
/// ```
22+
/// switchInt(_4) -> [43i32: bb3, otherwise: bb2];
23+
/// ```
24+
pub struct SimplifyComparisonIntegral;
25+
26+
impl<'tcx> MirPass<'tcx> for SimplifyComparisonIntegral {
27+
fn run_pass(&self, _: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
28+
trace!("Running SimplifyComparisonIntegral on {:?}", source);
29+
30+
let helper = OptimizationFinder { body };
31+
let opts = helper.find_optimizations();
32+
let mut storage_deads_to_insert = vec![];
33+
let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
34+
for opt in opts {
35+
trace!("SUCCESS: Applying {:?}", opt);
36+
// replace terminator with a switchInt that switches on the integer directly
37+
let bbs = &mut body.basic_blocks_mut();
38+
let bb = &mut bbs[opt.bb_idx];
39+
// We only use the bits for the untyped, not length checked `values` field. Thus we are
40+
// not using any of the convenience wrappers here and directly access the bits.
41+
let new_value = match opt.branch_value_scalar {
42+
Scalar::Raw { data, .. } => data,
43+
Scalar::Ptr(_) => continue,
44+
};
45+
const FALSE: u128 = 0;
46+
let mut new_targets = opt.targets.clone();
47+
let first_is_false_target = opt.values[0] == FALSE;
48+
match opt.op {
49+
BinOp::Eq => {
50+
// if the assignment was Eq we want the true case to be first
51+
if first_is_false_target {
52+
new_targets.swap(0, 1);
53+
}
54+
}
55+
BinOp::Ne => {
56+
// if the assignment was Ne we want the false case to be first
57+
if !first_is_false_target {
58+
new_targets.swap(0, 1);
59+
}
60+
}
61+
_ => unreachable!(),
62+
}
63+
64+
let terminator = bb.terminator_mut();
65+
66+
// add StorageDead for the place switched on at the top of each target
67+
for bb_idx in new_targets.iter() {
68+
storage_deads_to_insert.push((
69+
*bb_idx,
70+
Statement {
71+
source_info: terminator.source_info,
72+
kind: StatementKind::StorageDead(opt.to_switch_on.local),
73+
},
74+
));
75+
}
76+
77+
terminator.kind = TerminatorKind::SwitchInt {
78+
discr: Operand::Move(opt.to_switch_on),
79+
switch_ty: opt.branch_value_ty,
80+
values: vec![new_value].into(),
81+
targets: new_targets,
82+
};
83+
84+
// delete comparison statement if it the value being switched on was moved, which means it can not be user later on
85+
if opt.can_remove_bin_op_stmt {
86+
bb.statements[opt.bin_op_stmt_idx].make_nop();
87+
} else {
88+
// if the integer being compared to a const integral is being moved into the comparison,
89+
// e.g `_2 = Eq(move _3, const 'x');`
90+
// we want to avoid making a double move later on in the switchInt on _3.
91+
// So to avoid `switchInt(move _3) -> ['x': bb2, otherwise: bb1];`,
92+
// we convert the move in the comparison statement to a copy.
93+
94+
// unwrap is safe as we know this statement is an assign
95+
let box (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();
96+
97+
use Operand::*;
98+
match rhs {
99+
Rvalue::BinaryOp(_, ref mut left @ Move(_), Constant(_)) => {
100+
*left = Copy(opt.to_switch_on);
101+
}
102+
Rvalue::BinaryOp(_, Constant(_), ref mut right @ Move(_)) => {
103+
*right = Copy(opt.to_switch_on);
104+
}
105+
_ => (),
106+
}
107+
}
108+
109+
// remove StorageDead (if it exists) being used in the assign of the comparison
110+
for (stmt_idx, stmt) in bb.statements.iter().enumerate() {
111+
if !matches!(stmt.kind, StatementKind::StorageDead(local) if local == opt.to_switch_on.local)
112+
{
113+
continue;
114+
}
115+
storage_deads_to_remove.push((stmt_idx, opt.bb_idx))
116+
}
117+
}
118+
119+
for (idx, bb_idx) in storage_deads_to_remove {
120+
body.basic_blocks_mut()[bb_idx].statements[idx].make_nop();
121+
}
122+
123+
for (idx, stmt) in storage_deads_to_insert {
124+
body.basic_blocks_mut()[idx].statements.insert(0, stmt);
125+
}
126+
}
127+
}
128+
129+
struct OptimizationFinder<'a, 'tcx> {
130+
body: &'a Body<'tcx>,
131+
}
132+
133+
impl<'a, 'tcx> OptimizationFinder<'a, 'tcx> {
134+
fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> {
135+
self.body
136+
.basic_blocks()
137+
.iter_enumerated()
138+
.filter_map(|(bb_idx, bb)| {
139+
// find switch
140+
let (place_switched_on, values, targets, place_switched_on_moved) = match &bb
141+
.terminator()
142+
.kind
143+
{
144+
rustc_middle::mir::TerminatorKind::SwitchInt {
145+
discr, values, targets, ..
146+
} => Some((discr.place()?, values, targets, discr.is_move())),
147+
_ => None,
148+
}?;
149+
150+
// find the statement that assigns the place being switched on
151+
bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
152+
match &stmt.kind {
153+
rustc_middle::mir::StatementKind::Assign(box (lhs, rhs))
154+
if *lhs == place_switched_on =>
155+
{
156+
match rhs {
157+
Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), left, right) => {
158+
let (branch_value_scalar, branch_value_ty, to_switch_on) =
159+
find_branch_value_info(left, right)?;
160+
161+
Some(OptimizationInfo {
162+
bin_op_stmt_idx: stmt_idx,
163+
bb_idx,
164+
can_remove_bin_op_stmt: place_switched_on_moved,
165+
to_switch_on,
166+
branch_value_scalar,
167+
branch_value_ty,
168+
op: *op,
169+
values: values.clone().into_owned(),
170+
targets: targets.clone(),
171+
})
172+
}
173+
_ => None,
174+
}
175+
}
176+
_ => None,
177+
}
178+
})
179+
})
180+
.collect()
181+
}
182+
}
183+
184+
fn find_branch_value_info<'tcx>(
185+
left: &Operand<'tcx>,
186+
right: &Operand<'tcx>,
187+
) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> {
188+
// check that either left or right is a constant.
189+
// if any are, we can use the other to switch on, and the constant as a value in a switch
190+
use Operand::*;
191+
match (left, right) {
192+
(Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on))
193+
| (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => {
194+
let branch_value_ty = branch_value.literal.ty;
195+
// we only want to apply this optimization if we are matching on integrals (and chars), as it is not possible to switch on floats
196+
if !branch_value_ty.is_integral() && !branch_value_ty.is_char() {
197+
return None;
198+
};
199+
let branch_value_scalar = branch_value.literal.val.try_to_scalar()?;
200+
Some((branch_value_scalar, branch_value_ty, *to_switch_on))
201+
}
202+
_ => None,
203+
}
204+
}
205+
206+
#[derive(Debug)]
207+
struct OptimizationInfo<'tcx> {
208+
/// Basic block to apply the optimization
209+
bb_idx: BasicBlock,
210+
/// Statement index of Eq/Ne assignment that can be removed. None if the assignment can not be removed - i.e the statement is used later on
211+
bin_op_stmt_idx: usize,
212+
/// Can remove Eq/Ne assignment
213+
can_remove_bin_op_stmt: bool,
214+
/// Place that needs to be switched on. This place is of type integral
215+
to_switch_on: Place<'tcx>,
216+
/// Constant to use in switch target value
217+
branch_value_scalar: Scalar,
218+
/// Type of the constant value
219+
branch_value_ty: Ty<'tcx>,
220+
/// Either Eq or Ne
221+
op: BinOp,
222+
/// Current values used in the switch target. This needs to be replaced with the branch_value
223+
values: Vec<u128>,
224+
/// Current targets used in the switch
225+
targets: Vec<BasicBlock>,
226+
}

src/test/mir-opt/if-condition-int.rs

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// compile-flags: -O
2+
// EMIT_MIR if_condition_int.opt_u32.SimplifyComparisonIntegral.diff
3+
// EMIT_MIR if_condition_int.opt_negative.SimplifyComparisonIntegral.diff
4+
// EMIT_MIR if_condition_int.opt_char.SimplifyComparisonIntegral.diff
5+
// EMIT_MIR if_condition_int.opt_i8.SimplifyComparisonIntegral.diff
6+
// EMIT_MIR if_condition_int.dont_opt_bool.SimplifyComparisonIntegral.diff
7+
// EMIT_MIR if_condition_int.opt_multiple_ifs.SimplifyComparisonIntegral.diff
8+
// EMIT_MIR if_condition_int.dont_remove_comparison.SimplifyComparisonIntegral.diff
9+
// EMIT_MIR if_condition_int.dont_opt_floats.SimplifyComparisonIntegral.diff
10+
11+
fn opt_u32(x: u32) -> u32 {
12+
if x == 42 { 0 } else { 1 }
13+
}
14+
15+
// don't opt: it is already optimal to switch on the bool
16+
fn dont_opt_bool(x: bool) -> u32 {
17+
if x { 0 } else { 1 }
18+
}
19+
20+
fn opt_char(x: char) -> u32 {
21+
if x == 'x' { 0 } else { 1 }
22+
}
23+
24+
fn opt_i8(x: i8) -> u32 {
25+
if x == 42 { 0 } else { 1 }
26+
}
27+
28+
fn opt_negative(x: i32) -> u32 {
29+
if x == -42 { 0 } else { 1 }
30+
}
31+
32+
fn opt_multiple_ifs(x: u32) -> u32 {
33+
if x == 42 {
34+
0
35+
} else if x != 21 {
36+
1
37+
} else {
38+
2
39+
}
40+
}
41+
42+
// test that we optimize, but do not remove the b statement, as that is used later on
43+
fn dont_remove_comparison(a: i8) -> i32 {
44+
let b = a == 17;
45+
match b {
46+
false => 10 + b as i32,
47+
true => 100 + b as i32,
48+
}
49+
}
50+
51+
// test that we do not optimize on floats
52+
fn dont_opt_floats(a: f32) -> i32 {
53+
if a == -42.0 { 0 } else { 1 }
54+
}
55+
56+
fn main() {
57+
opt_u32(0);
58+
opt_char('0');
59+
opt_i8(22);
60+
dont_opt_bool(false);
61+
opt_negative(0);
62+
opt_multiple_ifs(0);
63+
dont_remove_comparison(11);
64+
dont_opt_floats(1.0);
65+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
- // MIR for `dont_opt_bool` before SimplifyComparisonIntegral
2+
+ // MIR for `dont_opt_bool` after SimplifyComparisonIntegral
3+
4+
fn dont_opt_bool(_1: bool) -> u32 {
5+
debug x => _1; // in scope 0 at $DIR/if-condition-int.rs:16:18: 16:19
6+
let mut _0: u32; // return place in scope 0 at $DIR/if-condition-int.rs:16:30: 16:33
7+
let mut _2: bool; // in scope 0 at $DIR/if-condition-int.rs:17:8: 17:9
8+
9+
bb0: {
10+
StorageLive(_2); // scope 0 at $DIR/if-condition-int.rs:17:8: 17:9
11+
_2 = _1; // scope 0 at $DIR/if-condition-int.rs:17:8: 17:9
12+
switchInt(_2) -> [false: bb1, otherwise: bb2]; // scope 0 at $DIR/if-condition-int.rs:17:5: 17:26
13+
}
14+
15+
bb1: {
16+
_0 = const 1_u32; // scope 0 at $DIR/if-condition-int.rs:17:23: 17:24
17+
goto -> bb3; // scope 0 at $DIR/if-condition-int.rs:17:5: 17:26
18+
}
19+
20+
bb2: {
21+
_0 = const 0_u32; // scope 0 at $DIR/if-condition-int.rs:17:12: 17:13
22+
goto -> bb3; // scope 0 at $DIR/if-condition-int.rs:17:5: 17:26
23+
}
24+
25+
bb3: {
26+
StorageDead(_2); // scope 0 at $DIR/if-condition-int.rs:18:1: 18:2
27+
return; // scope 0 at $DIR/if-condition-int.rs:18:2: 18:2
28+
}
29+
}
30+

0 commit comments

Comments
 (0)