Skip to content

Commit e64f688

Browse files
authored
Auto merge of #35348 - scottcarr:discriminant2, r=nikomatsakis
[MIR] Add explicit SetDiscriminant StatementKind for deaggregating enums cc #35186 To deaggregate enums, we need to be able to explicitly set the discriminant. This PR implements a new StatementKind that does that. I think some of the places that have `panics!` now could maybe do something smarter.
2 parents d3c3de8 + d77a136 commit e64f688

File tree

12 files changed

+156
-18
lines changed

12 files changed

+156
-18
lines changed

src/librustc/mir/repr.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -689,13 +689,17 @@ pub struct Statement<'tcx> {
689689
#[derive(Clone, Debug, RustcEncodable, RustcDecodable)]
690690
pub enum StatementKind<'tcx> {
691691
Assign(Lvalue<'tcx>, Rvalue<'tcx>),
692+
SetDiscriminant{ lvalue: Lvalue<'tcx>, variant_index: usize },
692693
}
693694

694695
impl<'tcx> Debug for Statement<'tcx> {
695696
fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
696697
use self::StatementKind::*;
697698
match self.kind {
698-
Assign(ref lv, ref rv) => write!(fmt, "{:?} = {:?}", lv, rv)
699+
Assign(ref lv, ref rv) => write!(fmt, "{:?} = {:?}", lv, rv),
700+
SetDiscriminant{lvalue: ref lv, variant_index: index} => {
701+
write!(fmt, "discriminant({:?}) = {:?}", lv, index)
702+
}
699703
}
700704
}
701705
}

src/librustc/mir/visit.rs

+3
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ macro_rules! make_mir_visitor {
323323
ref $($mutability)* rvalue) => {
324324
self.visit_assign(block, lvalue, rvalue);
325325
}
326+
StatementKind::SetDiscriminant{ ref $($mutability)* lvalue, .. } => {
327+
self.visit_lvalue(lvalue, LvalueContext::Store);
328+
}
326329
}
327330
}
328331

src/librustc_borrowck/borrowck/mir/dataflow/impls.rs

+3
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,9 @@ impl<'a, 'tcx> BitDenotation for MovingOutStatements<'a, 'tcx> {
442442
}
443443
let bits_per_block = self.bits_per_block(ctxt);
444444
match stmt.kind {
445+
repr::StatementKind::SetDiscriminant { .. } => {
446+
span_bug!(stmt.source_info.span, "SetDiscriminant should not exist in borrowck");
447+
}
445448
repr::StatementKind::Assign(ref lvalue, _) => {
446449
// assigning into this `lvalue` kills all
447450
// MoveOuts from it, and *also* all MoveOuts

src/librustc_borrowck/borrowck/mir/dataflow/sanity_check.rs

+3
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ fn each_block<'a, 'tcx, O>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
104104
repr::StatementKind::Assign(ref lvalue, ref rvalue) => {
105105
(lvalue, rvalue)
106106
}
107+
repr::StatementKind::SetDiscriminant{ .. } =>
108+
span_bug!(stmt.source_info.span,
109+
"sanity_check should run before Deaggregator inserts SetDiscriminant"),
107110
};
108111

109112
if lvalue == peek_arg_lval {

src/librustc_borrowck/borrowck/mir/gather_moves.rs

+4
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,10 @@ fn gather_moves<'a, 'tcx>(mir: &Mir<'tcx>, tcx: TyCtxt<'a, 'tcx, 'tcx>) -> MoveD
616616
Rvalue::InlineAsm { .. } => {}
617617
}
618618
}
619+
StatementKind::SetDiscriminant{ .. } => {
620+
span_bug!(stmt.source_info.span,
621+
"SetDiscriminant should not exist during borrowck");
622+
}
619623
}
620624
}
621625

src/librustc_borrowck/borrowck/mir/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@ fn drop_flag_effects_for_location<'a, 'tcx, F>(
369369
let block = &mir[loc.block];
370370
match block.statements.get(loc.index) {
371371
Some(stmt) => match stmt.kind {
372+
repr::StatementKind::SetDiscriminant{ .. } => {
373+
span_bug!(stmt.source_info.span, "SetDiscrimant should not exist during borrowck");
374+
}
372375
repr::StatementKind::Assign(ref lvalue, _) => {
373376
debug!("drop_flag_effects: assignment {:?}", stmt);
374377
on_all_children_bits(tcx, mir, move_data,

src/librustc_mir/transform/deaggregator.rs

+34-11
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
3939

4040
let mut curr: usize = 0;
4141
for bb in mir.basic_blocks_mut() {
42-
let idx = match get_aggregate_statement(curr, &bb.statements) {
42+
let idx = match get_aggregate_statement_index(curr, &bb.statements) {
4343
Some(idx) => idx,
4444
None => continue,
4545
};
@@ -48,7 +48,11 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
4848
let src_info = bb.statements[idx].source_info;
4949
let suffix_stmts = bb.statements.split_off(idx+1);
5050
let orig_stmt = bb.statements.pop().unwrap();
51-
let StatementKind::Assign(ref lhs, ref rhs) = orig_stmt.kind;
51+
let (lhs, rhs) = match orig_stmt.kind {
52+
StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs),
53+
StatementKind::SetDiscriminant{ .. } =>
54+
span_bug!(src_info.span, "expected aggregate, not {:?}", orig_stmt.kind),
55+
};
5256
let (agg_kind, operands) = match rhs {
5357
&Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands),
5458
_ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs),
@@ -64,10 +68,14 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
6468
let ty = variant_def.fields[i].ty(tcx, substs);
6569
let rhs = Rvalue::Use(op.clone());
6670

67-
// since we don't handle enums, we don't need a cast
68-
let lhs_cast = lhs.clone();
69-
70-
// FIXME we cannot deaggregate enums issue: #35186
71+
let lhs_cast = if adt_def.variants.len() > 1 {
72+
Lvalue::Projection(Box::new(LvalueProjection {
73+
base: lhs.clone(),
74+
elem: ProjectionElem::Downcast(adt_def, variant),
75+
}))
76+
} else {
77+
lhs.clone()
78+
};
7179

7280
let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection {
7381
base: lhs_cast,
@@ -80,18 +88,34 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
8088
debug!("inserting: {:?} @ {:?}", new_statement, idx + i);
8189
bb.statements.push(new_statement);
8290
}
91+
92+
// if the aggregate was an enum, we need to set the discriminant
93+
if adt_def.variants.len() > 1 {
94+
let set_discriminant = Statement {
95+
kind: StatementKind::SetDiscriminant {
96+
lvalue: lhs.clone(),
97+
variant_index: variant,
98+
},
99+
source_info: src_info,
100+
};
101+
bb.statements.push(set_discriminant);
102+
};
103+
83104
curr = bb.statements.len();
84105
bb.statements.extend(suffix_stmts);
85106
}
86107
}
87108
}
88109

89-
fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
110+
fn get_aggregate_statement_index<'a, 'tcx, 'b>(start: usize,
90111
statements: &Vec<Statement<'tcx>>)
91112
-> Option<usize> {
92-
for i in curr..statements.len() {
113+
for i in start..statements.len() {
93114
let ref statement = statements[i];
94-
let StatementKind::Assign(_, ref rhs) = statement.kind;
115+
let rhs = match statement.kind {
116+
StatementKind::Assign(_, ref rhs) => rhs,
117+
StatementKind::SetDiscriminant{ .. } => continue,
118+
};
95119
let (kind, operands) = match rhs {
96120
&Rvalue::Aggregate(ref kind, ref operands) => (kind, operands),
97121
_ => continue,
@@ -100,9 +124,8 @@ fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
100124
&AggregateKind::Adt(adt_def, variant, _) => (adt_def, variant),
101125
_ => continue,
102126
};
103-
if operands.len() == 0 || adt_def.variants.len() > 1 {
127+
if operands.len() == 0 {
104128
// don't deaggregate ()
105-
// don't deaggregate enums ... for now
106129
continue;
107130
}
108131
debug!("getting variant {:?}", variant);

src/librustc_mir/transform/promote_consts.rs

+19-3
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,13 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
219219
let (mut rvalue, mut call) = (None, None);
220220
let source_info = if stmt_idx < no_stmts {
221221
let statement = &mut self.source[bb].statements[stmt_idx];
222-
let StatementKind::Assign(_, ref mut rhs) = statement.kind;
222+
let mut rhs = match statement.kind {
223+
StatementKind::Assign(_, ref mut rhs) => rhs,
224+
StatementKind::SetDiscriminant{ .. } =>
225+
span_bug!(statement.source_info.span,
226+
"cannot promote SetDiscriminant {:?}",
227+
statement),
228+
};
223229
if self.keep_original {
224230
rvalue = Some(rhs.clone());
225231
} else {
@@ -300,10 +306,16 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
300306
});
301307
let mut rvalue = match candidate {
302308
Candidate::Ref(Location { block: bb, statement_index: stmt_idx }) => {
303-
match self.source[bb].statements[stmt_idx].kind {
309+
let ref mut statement = self.source[bb].statements[stmt_idx];
310+
match statement.kind {
304311
StatementKind::Assign(_, ref mut rvalue) => {
305312
mem::replace(rvalue, Rvalue::Use(new_operand))
306313
}
314+
StatementKind::SetDiscriminant{ .. } => {
315+
span_bug!(statement.source_info.span,
316+
"cannot promote SetDiscriminant {:?}",
317+
statement);
318+
}
307319
}
308320
}
309321
Candidate::ShuffleIndices(bb) => {
@@ -340,7 +352,11 @@ pub fn promote_candidates<'a, 'tcx>(mir: &mut Mir<'tcx>,
340352
let (span, ty) = match candidate {
341353
Candidate::Ref(Location { block: bb, statement_index: stmt_idx }) => {
342354
let statement = &mir[bb].statements[stmt_idx];
343-
let StatementKind::Assign(ref dest, _) = statement.kind;
355+
let dest = match statement.kind {
356+
StatementKind::Assign(ref dest, _) => dest,
357+
StatementKind::SetDiscriminant{ .. } =>
358+
panic!("cannot promote SetDiscriminant"),
359+
};
344360
if let Lvalue::Temp(index) = *dest {
345361
if temps[index] == TempState::PromotedOut {
346362
// Already promoted.

src/librustc_mir/transform/type_check.rs

+20-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
use rustc::infer::{self, InferCtxt, InferOk};
1515
use rustc::traits::{self, Reveal};
1616
use rustc::ty::fold::TypeFoldable;
17-
use rustc::ty::{self, Ty, TyCtxt};
17+
use rustc::ty::{self, Ty, TyCtxt, TypeVariants};
1818
use rustc::mir::repr::*;
1919
use rustc::mir::tcx::LvalueTy;
2020
use rustc::mir::transform::{MirPass, MirSource, Pass};
@@ -360,10 +360,27 @@ impl<'a, 'gcx, 'tcx> TypeChecker<'a, 'gcx, 'tcx> {
360360
span_mirbug!(self, stmt, "bad assignment ({:?} = {:?}): {:?}",
361361
lv_ty, rv_ty, terr);
362362
}
363-
}
364-
365363
// FIXME: rvalue with undeterminable type - e.g. inline
366364
// asm.
365+
}
366+
}
367+
StatementKind::SetDiscriminant{ ref lvalue, variant_index } => {
368+
let lvalue_type = lvalue.ty(mir, tcx).to_ty(tcx);
369+
let adt = match lvalue_type.sty {
370+
TypeVariants::TyEnum(adt, _) => adt,
371+
_ => {
372+
span_bug!(stmt.source_info.span,
373+
"bad set discriminant ({:?} = {:?}): lhs is not an enum",
374+
lvalue,
375+
variant_index);
376+
}
377+
};
378+
if variant_index >= adt.variants.len() {
379+
span_bug!(stmt.source_info.span,
380+
"bad set discriminant ({:?} = {:?}): value of of range",
381+
lvalue,
382+
variant_index);
383+
};
367384
}
368385
}
369386
}

src/librustc_trans/mir/constant.rs

+3
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ impl<'a, 'tcx> MirConstContext<'a, 'tcx> {
285285
Err(err) => if failure.is_ok() { failure = Err(err); }
286286
}
287287
}
288+
mir::StatementKind::SetDiscriminant{ .. } => {
289+
span_bug!(span, "SetDiscriminant should not appear in constants?");
290+
}
288291
}
289292
}
290293

src/librustc_trans/mir/statement.rs

+14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ use common::{self, BlockAndBuilder};
1414

1515
use super::MirContext;
1616
use super::LocalRef;
17+
use super::super::adt;
18+
use super::super::disr::Disr;
1719

1820
impl<'bcx, 'tcx> MirContext<'bcx, 'tcx> {
1921
pub fn trans_statement(&mut self,
@@ -57,6 +59,18 @@ impl<'bcx, 'tcx> MirContext<'bcx, 'tcx> {
5759
self.trans_rvalue(bcx, tr_dest, rvalue, debug_loc)
5860
}
5961
}
62+
mir::StatementKind::SetDiscriminant{ref lvalue, variant_index} => {
63+
let ty = self.monomorphized_lvalue_ty(lvalue);
64+
let repr = adt::represent_type(bcx.ccx(), ty);
65+
let lvalue_transed = self.trans_lvalue(&bcx, lvalue);
66+
bcx.with_block(|bcx|
67+
adt::trans_set_discr(bcx,
68+
&repr,
69+
lvalue_transed.llval,
70+
Disr::from(variant_index))
71+
);
72+
bcx
73+
}
6074
}
6175
}
6276
}
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// http://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
enum Baz {
12+
Empty,
13+
Foo { x: usize },
14+
}
15+
16+
fn bar(a: usize) -> Baz {
17+
Baz::Foo { x: a }
18+
}
19+
20+
fn main() {
21+
let x = bar(10);
22+
match x {
23+
Baz::Empty => println!("empty"),
24+
Baz::Foo { x } => println!("{}", x),
25+
};
26+
}
27+
28+
// END RUST SOURCE
29+
// START rustc.node10.Deaggregator.before.mir
30+
// bb0: {
31+
// var0 = arg0; // scope 0 at main.rs:7:8: 7:9
32+
// tmp0 = var0; // scope 1 at main.rs:8:19: 8:20
33+
// return = Baz::Foo { x: tmp0 }; // scope 1 at main.rs:8:5: 8:21
34+
// goto -> bb1; // scope 1 at main.rs:7:1: 9:2
35+
// }
36+
// END rustc.node10.Deaggregator.before.mir
37+
// START rustc.node10.Deaggregator.after.mir
38+
// bb0: {
39+
// var0 = arg0; // scope 0 at main.rs:7:8: 7:9
40+
// tmp0 = var0; // scope 1 at main.rs:8:19: 8:20
41+
// ((return as Foo).0: usize) = tmp0; // scope 1 at main.rs:8:5: 8:21
42+
// discriminant(return) = 1; // scope 1 at main.rs:8:5: 8:21
43+
// goto -> bb1; // scope 1 at main.rs:7:1: 9:2
44+
// }
45+
// END rustc.node10.Deaggregator.after.mir

0 commit comments

Comments
 (0)