1
1
use rustc_index:: IndexVec ;
2
2
use rustc_middle:: mir:: * ;
3
- use rustc_middle:: ty:: { ParamEnv , Ty , TyCtxt } ;
3
+ use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
4
4
use std:: iter;
5
5
6
6
use super :: simplify:: simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
38
38
should_cleanup = true ;
39
39
continue ;
40
40
}
41
+ if SimplifyToExp :: default ( ) . simplify ( tcx, & mut body. local_decls , bbs, bb_idx, param_env)
42
+ {
43
+ should_cleanup = true ;
44
+ continue ;
45
+ }
41
46
}
42
47
43
48
if should_cleanup {
@@ -47,8 +52,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
47
52
}
48
53
49
54
trait SimplifyMatch < ' tcx > {
55
+ /// Simplifies a match statement, returning true if the simplification succeeds, false otherwise.
56
+ /// Generic code is written here, and we generally don't need a custom implementation.
50
57
fn simplify (
51
- & self ,
58
+ & mut self ,
52
59
tcx : TyCtxt < ' tcx > ,
53
60
local_decls : & mut IndexVec < Local , LocalDecl < ' tcx > > ,
54
61
bbs : & mut IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
@@ -72,9 +79,9 @@ trait SimplifyMatch<'tcx> {
72
79
let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
73
80
let discr_local = local_decls. push ( LocalDecl :: new ( discr_ty, source_info. span ) ) ;
74
81
75
- // We already checked that first and second are different blocks,
82
+ // We already checked that targets are different blocks,
76
83
// and bb_idx has a different terminator from both of them.
77
- let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local. clone ( ) , discr_ty) ;
84
+ let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local, discr_ty) ;
78
85
let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
79
86
let ( from, first) = bbs. pick2_mut ( switch_bb_idx, first) ;
80
87
from. statements
@@ -91,7 +98,7 @@ trait SimplifyMatch<'tcx> {
91
98
}
92
99
93
100
fn can_simplify (
94
- & self ,
101
+ & mut self ,
95
102
tcx : TyCtxt < ' tcx > ,
96
103
targets : & SwitchTargets ,
97
104
param_env : ParamEnv < ' tcx > ,
@@ -144,7 +151,7 @@ struct SimplifyToIf;
144
151
/// ```
145
152
impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToIf {
146
153
fn can_simplify (
147
- & self ,
154
+ & mut self ,
148
155
tcx : TyCtxt < ' tcx > ,
149
156
targets : & SwitchTargets ,
150
157
param_env : ParamEnv < ' tcx > ,
@@ -250,3 +257,210 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250
257
new_stmts. collect ( )
251
258
}
252
259
}
260
+
261
+ #[ derive( Default ) ]
262
+ struct SimplifyToExp {
263
+ transfrom_types : Vec < TransfromType > ,
264
+ }
265
+
266
+ #[ derive( Clone , Copy ) ]
267
+ enum CompareType < ' tcx , ' a > {
268
+ Same ( & ' a StatementKind < ' tcx > ) ,
269
+ Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
270
+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
271
+ }
272
+
273
+ enum TransfromType {
274
+ Same ,
275
+ Eq ,
276
+ Discr ,
277
+ }
278
+
279
+ impl From < CompareType < ' _ , ' _ > > for TransfromType {
280
+ fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
281
+ match compare_type {
282
+ CompareType :: Same ( _) => TransfromType :: Same ,
283
+ CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
284
+ CompareType :: Discr ( _, _) => TransfromType :: Discr ,
285
+ }
286
+ }
287
+ }
288
+
289
+ /// If we find that the value of match is the same as the assignment,
290
+ /// merge a target block statements into the source block,
291
+ /// using cast to transform different integer types.
292
+ ///
293
+ /// For example:
294
+ ///
295
+ /// ```ignore (MIR)
296
+ /// bb0: {
297
+ /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
298
+ /// }
299
+ ///
300
+ /// bb1: {
301
+ /// unreachable;
302
+ /// }
303
+ ///
304
+ /// bb2: {
305
+ /// _0 = const 1_i16;
306
+ /// goto -> bb5;
307
+ /// }
308
+ ///
309
+ /// bb3: {
310
+ /// _0 = const 2_i16;
311
+ /// goto -> bb5;
312
+ /// }
313
+ ///
314
+ /// bb4: {
315
+ /// _0 = const 3_i16;
316
+ /// goto -> bb5;
317
+ /// }
318
+ /// ```
319
+ ///
320
+ /// into:
321
+ ///
322
+ /// ```ignore (MIR)
323
+ /// bb0: {
324
+ /// _0 = _3 as i16 (IntToInt);
325
+ /// goto -> bb5;
326
+ /// }
327
+ /// ```
328
+ impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToExp {
329
+ fn can_simplify (
330
+ & mut self ,
331
+ tcx : TyCtxt < ' tcx > ,
332
+ targets : & SwitchTargets ,
333
+ param_env : ParamEnv < ' tcx > ,
334
+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
335
+ ) -> bool {
336
+ if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
337
+ return false ;
338
+ }
339
+ // We require that the possible target blocks all be distinct.
340
+ if !targets. is_distinct ( ) {
341
+ return false ;
342
+ }
343
+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
344
+ return false ;
345
+ }
346
+ let mut iter = targets. iter ( ) ;
347
+ let ( first_val, first_target) = iter. next ( ) . unwrap ( ) ;
348
+ let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
349
+ // Check that destinations are identical, and if not, then don't optimize this block
350
+ if !targets
351
+ . iter ( )
352
+ . all ( |( _, other_target) | first_terminator_kind == & bbs[ other_target] . terminator ( ) . kind )
353
+ {
354
+ return false ;
355
+ }
356
+
357
+ let first_stmts = & bbs[ first_target] . statements ;
358
+ let ( second_val, second_target) = iter. next ( ) . unwrap ( ) ;
359
+ let second_stmts = & bbs[ second_target] . statements ;
360
+ if first_stmts. len ( ) != second_stmts. len ( ) {
361
+ return false ;
362
+ }
363
+
364
+ let mut compare_types = Vec :: new ( ) ;
365
+ for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
366
+ let compare_type = match ( & f. kind , & s. kind ) {
367
+ // If two statements are exactly the same, we can optimize.
368
+ ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
369
+
370
+ // If two statements are assignments with the match values to the same place, we can optimize.
371
+ (
372
+ StatementKind :: Assign ( box ( lhs_f, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
373
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
374
+ ) if lhs_f == lhs_s
375
+ && f_c. const_ . ty ( ) == s_c. const_ . ty ( )
376
+ && f_c. const_ . ty ( ) . is_integral ( ) =>
377
+ {
378
+ match (
379
+ f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
380
+ s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
381
+ ) {
382
+ ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
383
+ ( Some ( f) , Some ( s) )
384
+ if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
385
+ && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
386
+ {
387
+ CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
388
+ }
389
+ _ => return false ,
390
+ }
391
+ }
392
+
393
+ // Otherwise we cannot optimize. Try another block.
394
+ _ => return false ,
395
+ } ;
396
+ compare_types. push ( compare_type) ;
397
+ }
398
+
399
+ for ( other_val, other_target) in iter {
400
+ let other_stmts = & bbs[ other_target] . statements ;
401
+ if compare_types. len ( ) != other_stmts. len ( ) {
402
+ return false ;
403
+ }
404
+ for ( f, s) in iter:: zip ( & compare_types, other_stmts) {
405
+ match ( * f, & s. kind ) {
406
+ ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
407
+ (
408
+ CompareType :: Eq ( lhs_f, f_ty, val) ,
409
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
410
+ ) if lhs_f == lhs_s
411
+ && s_c. const_ . ty ( ) == f_ty
412
+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
413
+ (
414
+ CompareType :: Discr ( lhs_f, f_ty) ,
415
+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
416
+ ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
417
+ let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
418
+ return false ;
419
+ } ;
420
+ if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
421
+ return false ;
422
+ }
423
+ }
424
+ _ => return false ,
425
+ }
426
+ }
427
+ }
428
+ self . transfrom_types = compare_types. into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
429
+ true
430
+ }
431
+
432
+ fn new_stmts (
433
+ & self ,
434
+ _tcx : TyCtxt < ' tcx > ,
435
+ targets : & SwitchTargets ,
436
+ _param_env : ParamEnv < ' tcx > ,
437
+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
438
+ discr_local : Local ,
439
+ discr_ty : Ty < ' tcx > ,
440
+ ) -> Vec < Statement < ' tcx > > {
441
+ let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
442
+ let first = & bbs[ first] ;
443
+
444
+ let new_stmts =
445
+ iter:: zip ( & self . transfrom_types , & first. statements ) . map ( |( t, s) | match ( t, & s. kind ) {
446
+ ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => ( * s) . clone ( ) ,
447
+ (
448
+ TransfromType :: Discr ,
449
+ StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
450
+ ) => {
451
+ let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
452
+ let r_val = if f_c. const_ . ty ( ) == discr_ty {
453
+ Rvalue :: Use ( operand)
454
+ } else {
455
+ Rvalue :: Cast ( CastKind :: IntToInt , operand, f_c. const_ . ty ( ) )
456
+ } ;
457
+ Statement {
458
+ source_info : s. source_info ,
459
+ kind : StatementKind :: Assign ( Box :: new ( ( * lhs, r_val) ) ) ,
460
+ }
461
+ }
462
+ _ => unreachable ! ( ) ,
463
+ } ) ;
464
+ new_stmts. collect ( )
465
+ }
466
+ }
0 commit comments