@@ -39,7 +39,7 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
39
39
40
40
let mut curr: usize = 0 ;
41
41
for bb in mir. basic_blocks_mut ( ) {
42
- let idx = match get_aggregate_statement ( curr, & bb. statements ) {
42
+ let idx = match get_aggregate_statement_index ( curr, & bb. statements ) {
43
43
Some ( idx) => idx,
44
44
None => continue ,
45
45
} ;
@@ -48,7 +48,11 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
48
48
let src_info = bb. statements [ idx] . source_info ;
49
49
let suffix_stmts = bb. statements . split_off ( idx+1 ) ;
50
50
let orig_stmt = bb. statements . pop ( ) . unwrap ( ) ;
51
- let StatementKind :: Assign ( ref lhs, ref rhs) = orig_stmt. kind ;
51
+ let ( lhs, rhs) = match orig_stmt. kind {
52
+ StatementKind :: Assign ( ref lhs, ref rhs) => ( lhs, rhs) ,
53
+ StatementKind :: SetDiscriminant { .. } =>
54
+ span_bug ! ( src_info. span, "expected aggregate, not {:?}" , orig_stmt. kind) ,
55
+ } ;
52
56
let ( agg_kind, operands) = match rhs {
53
57
& Rvalue :: Aggregate ( ref agg_kind, ref operands) => ( agg_kind, operands) ,
54
58
_ => span_bug ! ( src_info. span, "expected aggregate, not {:?}" , rhs) ,
@@ -64,10 +68,14 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
64
68
let ty = variant_def. fields [ i] . ty ( tcx, substs) ;
65
69
let rhs = Rvalue :: Use ( op. clone ( ) ) ;
66
70
67
- // since we don't handle enums, we don't need a cast
68
- let lhs_cast = lhs. clone ( ) ;
69
-
70
- // FIXME we cannot deaggregate enums issue: #35186
71
+ let lhs_cast = if adt_def. variants . len ( ) > 1 {
72
+ Lvalue :: Projection ( Box :: new ( LvalueProjection {
73
+ base : lhs. clone ( ) ,
74
+ elem : ProjectionElem :: Downcast ( adt_def, variant) ,
75
+ } ) )
76
+ } else {
77
+ lhs. clone ( )
78
+ } ;
71
79
72
80
let lhs_proj = Lvalue :: Projection ( Box :: new ( LvalueProjection {
73
81
base : lhs_cast,
@@ -80,18 +88,34 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
80
88
debug ! ( "inserting: {:?} @ {:?}" , new_statement, idx + i) ;
81
89
bb. statements . push ( new_statement) ;
82
90
}
91
+
92
+ // if the aggregate was an enum, we need to set the discriminant
93
+ if adt_def. variants . len ( ) > 1 {
94
+ let set_discriminant = Statement {
95
+ kind : StatementKind :: SetDiscriminant {
96
+ lvalue : lhs. clone ( ) ,
97
+ variant_index : variant,
98
+ } ,
99
+ source_info : src_info,
100
+ } ;
101
+ bb. statements . push ( set_discriminant) ;
102
+ } ;
103
+
83
104
curr = bb. statements . len ( ) ;
84
105
bb. statements . extend ( suffix_stmts) ;
85
106
}
86
107
}
87
108
}
88
109
89
- fn get_aggregate_statement < ' a , ' tcx , ' b > ( curr : usize ,
110
+ fn get_aggregate_statement_index < ' a , ' tcx , ' b > ( start : usize ,
90
111
statements : & Vec < Statement < ' tcx > > )
91
112
-> Option < usize > {
92
- for i in curr ..statements. len ( ) {
113
+ for i in start ..statements. len ( ) {
93
114
let ref statement = statements[ i] ;
94
- let StatementKind :: Assign ( _, ref rhs) = statement. kind ;
115
+ let rhs = match statement. kind {
116
+ StatementKind :: Assign ( _, ref rhs) => rhs,
117
+ StatementKind :: SetDiscriminant { .. } => continue ,
118
+ } ;
95
119
let ( kind, operands) = match rhs {
96
120
& Rvalue :: Aggregate ( ref kind, ref operands) => ( kind, operands) ,
97
121
_ => continue ,
@@ -100,9 +124,8 @@ fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
100
124
& AggregateKind :: Adt ( adt_def, variant, _) => ( adt_def, variant) ,
101
125
_ => continue ,
102
126
} ;
103
- if operands. len ( ) == 0 || adt_def . variants . len ( ) > 1 {
127
+ if operands. len ( ) == 0 {
104
128
// don't deaggregate ()
105
- // don't deaggregate enums ... for now
106
129
continue ;
107
130
}
108
131
debug ! ( "getting variant {:?}" , variant) ;
0 commit comments