Skip to content

Commit 82b6247

Browse files
committed
Propagate places through assignments.
1 parent f36b838 commit 82b6247

13 files changed

+177
-33
lines changed

compiler/rustc_mir_dataflow/src/value_analysis.rs

+109-10
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use std::fmt::{Debug, Formatter};
3636
use std::ops::Range;
3737

3838
use rustc_data_structures::captures::Captures;
39-
use rustc_data_structures::fx::{FxHashMap, StdEntry};
39+
use rustc_data_structures::fx::{FxHashMap, FxIndexSet, StdEntry};
4040
use rustc_data_structures::stack::ensure_sufficient_stack;
4141
use rustc_index::bit_set::BitSet;
4242
use rustc_index::IndexVec;
@@ -796,7 +796,52 @@ impl<'tcx> Map<'tcx> {
796796
self.locals[local] = Some(place);
797797
}
798798

799-
PlaceCollector { tcx, body, map: self }.visit_body(body);
799+
// Collect syntactic places and assignments between them.
800+
let mut collector =
801+
PlaceCollector { tcx, body, map: self, assignments: Default::default() };
802+
collector.visit_body(body);
803+
let PlaceCollector { mut assignments, .. } = collector;
804+
805+
// Just collecting syntactic places is not enough. We may need to propagate this pattern:
806+
// _1 = (const 5u32, const 13i64);
807+
// _2 = _1;
808+
// _3 = (_2.0 as u32);
809+
//
810+
// `_1.0` does not appear, but we still need to track it. This is achieved by propagating
811+
// projections from assignments. We recorded an assignment between `_2` and `_1`, so we
812+
// want `_1` and `_2` to have the same sub-places.
813+
//
814+
// This is what this fixpoint loop does. While we are still creating places, run through
815+
// all the assignments, and register places for children.
816+
let mut num_places = 0;
817+
while num_places < self.places.len() {
818+
num_places = self.places.len();
819+
820+
for assign in 0.. {
821+
let Some(&(lhs, rhs)) = assignments.get_index(assign) else { break };
822+
823+
// Mirror children from `lhs` in `rhs`.
824+
let mut child = self.places[lhs].first_child;
825+
while let Some(lhs_child) = child {
826+
let PlaceInfo { ty, proj_elem, next_sibling, .. } = self.places[lhs_child];
827+
let rhs_child =
828+
self.register_place(ty, rhs, proj_elem.expect("child is not a projection"));
829+
assignments.insert((lhs_child, rhs_child));
830+
child = next_sibling;
831+
}
832+
833+
// Conversely, mirror children from `rhs` in `lhs`.
834+
let mut child = self.places[rhs].first_child;
835+
while let Some(rhs_child) = child {
836+
let PlaceInfo { ty, proj_elem, next_sibling, .. } = self.places[rhs_child];
837+
let lhs_child =
838+
self.register_place(ty, lhs, proj_elem.expect("child is not a projection"));
839+
assignments.insert((lhs_child, rhs_child));
840+
child = next_sibling;
841+
}
842+
}
843+
}
844+
drop(assignments);
800845

801846
// Create values for places whose type have scalar layout.
802847
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
@@ -879,17 +924,14 @@ struct PlaceCollector<'a, 'b, 'tcx> {
879924
tcx: TyCtxt<'tcx>,
880925
body: &'b Body<'tcx>,
881926
map: &'a mut Map<'tcx>,
927+
assignments: FxIndexSet<(PlaceIndex, PlaceIndex)>,
882928
}
883929

884-
impl<'tcx> Visitor<'tcx> for PlaceCollector<'_, '_, 'tcx> {
930+
impl<'tcx> PlaceCollector<'_, '_, 'tcx> {
885931
#[tracing::instrument(level = "trace", skip(self))]
886-
fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, _: Location) {
887-
if !ctxt.is_use() {
888-
return;
889-
}
890-
932+
fn register_place(&mut self, place: Place<'tcx>) -> Option<PlaceIndex> {
891933
// Create a place for this projection.
892-
let Some(mut place_index) = self.map.locals[place.local] else { return };
934+
let mut place_index = self.map.locals[place.local]?;
893935
let mut ty = PlaceTy::from_ty(self.body.local_decls[place.local].ty);
894936
tracing::trace!(?place_index, ?ty);
895937

@@ -903,7 +945,7 @@ impl<'tcx> Visitor<'tcx> for PlaceCollector<'_, '_, 'tcx> {
903945
}
904946

905947
for proj in place.projection {
906-
let Ok(track_elem) = proj.try_into() else { return };
948+
let track_elem = proj.try_into().ok()?;
907949
ty = ty.projection_ty(self.tcx, proj);
908950
place_index = self.map.register_place(ty.ty, place_index, track_elem);
909951
tracing::trace!(?proj, ?place_index, ?ty);
@@ -917,6 +959,63 @@ impl<'tcx> Visitor<'tcx> for PlaceCollector<'_, '_, 'tcx> {
917959
self.map.register_place(discriminant_ty, place_index, TrackElem::Discriminant);
918960
}
919961
}
962+
963+
Some(place_index)
964+
}
965+
}
966+
967+
impl<'tcx> Visitor<'tcx> for PlaceCollector<'_, '_, 'tcx> {
968+
#[tracing::instrument(level = "trace", skip(self))]
969+
fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, _: Location) {
970+
if !ctxt.is_use() {
971+
return;
972+
}
973+
974+
self.register_place(*place);
975+
}
976+
977+
fn visit_assign(&mut self, lhs: &Place<'tcx>, rhs: &Rvalue<'tcx>, location: Location) {
978+
self.super_assign(lhs, rhs, location);
979+
980+
match rhs {
981+
Rvalue::Use(Operand::Move(rhs) | Operand::Copy(rhs)) | Rvalue::CopyForDeref(rhs) => {
982+
let Some(lhs) = self.register_place(*lhs) else { return };
983+
let Some(rhs) = self.register_place(*rhs) else { return };
984+
self.assignments.insert((lhs, rhs));
985+
}
986+
Rvalue::Aggregate(kind, fields) => {
987+
let Some(mut lhs) = self.register_place(*lhs) else { return };
988+
match **kind {
989+
// Do not propagate unions.
990+
AggregateKind::Adt(_, _, _, _, Some(_)) => return,
991+
AggregateKind::Adt(_, variant, _, _, None) => {
992+
let ty = self.map.places[lhs].ty;
993+
if ty.is_enum() {
994+
lhs = self.map.register_place(ty, lhs, TrackElem::Variant(variant));
995+
}
996+
}
997+
AggregateKind::RawPtr(..)
998+
| AggregateKind::Array(_)
999+
| AggregateKind::Tuple
1000+
| AggregateKind::Closure(..)
1001+
| AggregateKind::Coroutine(..)
1002+
| AggregateKind::CoroutineClosure(..) => {}
1003+
}
1004+
for (index, field) in fields.iter_enumerated() {
1005+
if let Some(rhs) = field.place()
1006+
&& let Some(rhs) = self.register_place(rhs)
1007+
{
1008+
let lhs = self.map.register_place(
1009+
self.map.places[rhs].ty,
1010+
lhs,
1011+
TrackElem::Field(index),
1012+
);
1013+
self.assignments.insert((lhs, rhs));
1014+
}
1015+
}
1016+
}
1017+
_ => {}
1018+
}
9201019
}
9211020
}
9221021

tests/mir-opt/dataflow-const-prop/aggregate_copy.foo.DataflowConstProp.diff

+16-6
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,25 @@
2222
StorageLive(_1);
2323
_1 = const Foo;
2424
StorageLive(_2);
25-
_2 = _1;
25+
- _2 = _1;
26+
+ _2 = const (5_u32, 3_u32);
2627
StorageLive(_3);
27-
_3 = (_2.1: u32);
28+
- _3 = (_2.1: u32);
29+
+ _3 = const 3_u32;
2830
StorageLive(_4);
2931
StorageLive(_5);
30-
_5 = _3;
31-
_4 = Ge(move _5, const 2_u32);
32-
switchInt(move _4) -> [0: bb2, otherwise: bb1];
32+
- _5 = _3;
33+
- _4 = Ge(move _5, const 2_u32);
34+
- switchInt(move _4) -> [0: bb2, otherwise: bb1];
35+
+ _5 = const 3_u32;
36+
+ _4 = const true;
37+
+ switchInt(const true) -> [0: bb2, otherwise: bb1];
3338
}
3439

3540
bb1: {
3641
StorageDead(_5);
37-
_0 = (_2.0: u32);
42+
- _0 = (_2.0: u32);
43+
+ _0 = const 5_u32;
3844
goto -> bb3;
3945
}
4046

@@ -51,5 +57,9 @@
5157
StorageDead(_1);
5258
return;
5359
}
60+
+ }
61+
+
62+
+ ALLOC0 (size: 8, align: 4) {
63+
+ 05 00 00 00 03 00 00 00 │ ........
5464
}
5565

tests/mir-opt/dataflow-const-prop/aggregate_copy.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@ fn foo() -> u32 {
1212

1313
// CHECK:bb0: {
1414
// CHECK: [[a]] = const Foo;
15-
// CHECK: [[b]] = [[a]];
16-
// CHECK: [[c]] = ([[b]].1: u32);
17-
// CHECK: switchInt(move {{_.*}}) -> [0: bb2, otherwise: bb1];
15+
// CHECK: [[b]] = const (5_u32, 3_u32);
16+
// CHECK: [[c]] = const 3_u32;
17+
// CHECK: {{_.*}} = const 3_u32;
18+
// CHECK: {{_.*}} = const true;
19+
// CHECK: switchInt(const true) -> [0: bb2, otherwise: bb1];
1820

1921
// CHECK:bb1: {
20-
// CHECK: _0 = ([[b]].0: u32);
22+
// CHECK: _0 = const 5_u32;
2123
// CHECK: goto -> bb3;
2224

2325
// CHECK:bb2: {

tests/mir-opt/dataflow-const-prop/repr_transparent.main.DataflowConstProp.diff

+5-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
StorageDead(_5);
3333
StorageDead(_4);
3434
- _2 = I32(move _3);
35-
+ _2 = I32(const 0_i32);
35+
+ _2 = const I32(0_i32);
3636
StorageDead(_3);
3737
_0 = const ();
3838
StorageDead(_2);
@@ -42,6 +42,10 @@
4242
+ }
4343
+
4444
+ ALLOC0 (size: 4, align: 4) {
45+
+ 00 00 00 00 │ ....
46+
+ }
47+
+
48+
+ ALLOC1 (size: 4, align: 4) {
4549
+ 00 00 00 00 │ ....
4650
}
4751

tests/mir-opt/dataflow-const-prop/repr_transparent.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ fn main() {
1515
// CHECK: [[x]] = const I32(0_i32);
1616
let x = I32(0);
1717

18-
// CHECK: [[y]] = I32(const 0_i32);
18+
// CHECK: [[y]] = const I32(0_i32);
1919
let y = I32(x.0 + x.0);
2020
}

tests/mir-opt/dataflow-const-prop/transmute.less_as_i8.DataflowConstProp.32bit.diff

+12-2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,20 @@
77

88
bb0: {
99
StorageLive(_1);
10-
_1 = Less;
11-
_0 = move _1 as i8 (Transmute);
10+
- _1 = Less;
11+
- _0 = move _1 as i8 (Transmute);
12+
+ _1 = const Less;
13+
+ _0 = const std::cmp::Ordering::Less as i8 (Transmute);
1214
StorageDead(_1);
1315
return;
1416
}
17+
+ }
18+
+
19+
+ ALLOC0 (size: 1, align: 1) {
20+
+ ff │ .
21+
+ }
22+
+
23+
+ ALLOC1 (size: 1, align: 1) {
24+
+ ff │ .
1525
}
1626

tests/mir-opt/dataflow-const-prop/transmute.less_as_i8.DataflowConstProp.64bit.diff

+12-2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,20 @@
77

88
bb0: {
99
StorageLive(_1);
10-
_1 = Less;
11-
_0 = move _1 as i8 (Transmute);
10+
- _1 = Less;
11+
- _0 = move _1 as i8 (Transmute);
12+
+ _1 = const Less;
13+
+ _0 = const std::cmp::Ordering::Less as i8 (Transmute);
1214
StorageDead(_1);
1315
return;
1416
}
17+
+ }
18+
+
19+
+ ALLOC0 (size: 1, align: 1) {
20+
+ ff │ .
21+
+ }
22+
+
23+
+ ALLOC1 (size: 1, align: 1) {
24+
+ ff │ .
1525
}
1626

tests/mir-opt/dataflow-const-prop/tuple.main.DataflowConstProp.32bit.diff

+5-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
- _14 = _6;
8282
- _11 = (move _12, move _13, move _14);
8383
+ _14 = const 11_i32;
84-
+ _11 = (const 6_i32, move _13, const 11_i32);
84+
+ _11 = (const 6_i32, const (2_i32, 3_i32), const 11_i32);
8585
StorageDead(_14);
8686
StorageDead(_13);
8787
StorageDead(_12);
@@ -103,6 +103,10 @@
103103
+ }
104104
+
105105
+ ALLOC2 (size: 8, align: 4) {
106+
+ 02 00 00 00 03 00 00 00 │ ........
107+
+ }
108+
+
109+
+ ALLOC3 (size: 8, align: 4) {
106110
+ 01 00 00 00 02 00 00 00 │ ........
107111
}
108112

tests/mir-opt/dataflow-const-prop/tuple.main.DataflowConstProp.64bit.diff

+5-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
- _14 = _6;
8282
- _11 = (move _12, move _13, move _14);
8383
+ _14 = const 11_i32;
84-
+ _11 = (const 6_i32, move _13, const 11_i32);
84+
+ _11 = (const 6_i32, const (2_i32, 3_i32), const 11_i32);
8585
StorageDead(_14);
8686
StorageDead(_13);
8787
StorageDead(_12);
@@ -103,6 +103,10 @@
103103
+ }
104104
+
105105
+ ALLOC2 (size: 8, align: 4) {
106+
+ 02 00 00 00 03 00 00 00 │ ........
107+
+ }
108+
+
109+
+ ALLOC3 (size: 8, align: 4) {
106110
+ 01 00 00 00 02 00 00 00 │ ........
107111
}
108112

tests/mir-opt/dataflow-const-prop/tuple.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ fn main() {
2222
// CHECK: [[c]] = const 11_i32;
2323
let c = a.0 + a.1 + b;
2424

25-
// CHECK: [[a2:_.*]] = const (2_i32, 3_i32);
26-
// CHECK: [[d]] = (const 6_i32, move [[a2]], const 11_i32);
25+
// CHECK: [[d]] = (const 6_i32, const (2_i32, 3_i32), const 11_i32);
2726
let d = (b, a, c);
2827
}

tests/mir-opt/jump_threading.aggregate_copy.JumpThreading.panic-abort.diff

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
StorageLive(_5);
3030
_5 = _3;
3131
_4 = Eq(move _5, const 2_u32);
32-
switchInt(move _4) -> [0: bb2, otherwise: bb1];
32+
- switchInt(move _4) -> [0: bb2, otherwise: bb1];
33+
+ goto -> bb2;
3334
}
3435

3536
bb1: {

tests/mir-opt/jump_threading.aggregate_copy.JumpThreading.panic-unwind.diff

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
StorageLive(_5);
3030
_5 = _3;
3131
_4 = Eq(move _5, const 2_u32);
32-
switchInt(move _4) -> [0: bb2, otherwise: bb1];
32+
- switchInt(move _4) -> [0: bb2, otherwise: bb1];
33+
+ goto -> bb2;
3334
}
3435

3536
bb1: {

tests/mir-opt/jump_threading.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ fn assume(a: u8, b: bool) -> u8 {
509509
/// Verify that jump threading succeeds seeing through copies of aggregates.
510510
fn aggregate_copy() -> u32 {
511511
// CHECK-LABEL: fn aggregate_copy(
512-
// CHECK: switchInt(
512+
// CHECK-NOT: switchInt(
513513

514514
const Foo: (u32, u32) = (5, 3);
515515

0 commit comments

Comments
 (0)