Skip to content

Commit 83162cd

Browse files
committed
SROA short arrays too.
1 parent 879be36 commit 83162cd

File tree

5 files changed

+135
-36
lines changed

5 files changed

+135
-36
lines changed

compiler/rustc_const_eval/src/transform/validate.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,12 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
748748
format!("invalid empty projection in debuginfo for {:?}", debuginfo.name),
749749
);
750750
}
751-
if !projection.iter().all(|p| matches!(p, PlaceElem::Field(..) | PlaceElem::Deref)) {
751+
if !projection.iter().all(|p| {
752+
matches!(
753+
p,
754+
PlaceElem::Field(..) | PlaceElem::Deref | PlaceElem::ConstantIndex { .. }
755+
)
756+
}) {
752757
self.fail(
753758
START_BLOCK.start_location(),
754759
format!(

compiler/rustc_middle/src/mir/visit.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ macro_rules! make_mir_visitor {
847847
self.visit_ty($(& $mutability)? *ty, TyContext::Location(location));
848848
for elem in projection {
849849
match elem {
850-
ProjectionElem::Deref => {}
850+
ProjectionElem::Deref | ProjectionElem::ConstantIndex { .. } => {}
851851
ProjectionElem::Field(_, ty) => {
852852
self.visit_ty($(& $mutability)? *ty, TyContext::Location(location))
853853
}

compiler/rustc_mir_transform/src/instsimplify.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ use rustc_middle::mir::*;
77
use rustc_middle::ty::layout::ValidityRequirement;
88
use rustc_middle::ty::{self, GenericArgsRef, ParamEnv, Ty, TyCtxt};
99
use rustc_span::symbol::Symbol;
10-
use rustc_target::abi::FieldIdx;
11-
use rustc_target::abi::FIRST_VARIANT;
10+
use rustc_target::abi::{FieldIdx, FIRST_VARIANT};
1211

1312
pub struct InstSimplify;
1413

compiler/rustc_mir_transform/src/sroa.rs

+124-32
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ fn escaping_locals<'tcx>(
9393
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
9494
for (local, decl) in body.local_decls().iter_enumerated() {
9595
if excluded.contains(local) || is_excluded_ty(decl.ty) {
96+
trace!(?local, "early exclusion");
9697
set.insert(local);
9798
}
9899
}
@@ -105,13 +106,17 @@ fn escaping_locals<'tcx>(
105106
}
106107

107108
impl<'tcx> Visitor<'tcx> for EscapeVisitor {
108-
fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
109-
self.set.insert(local);
109+
fn visit_local(&mut self, local: Local, _: PlaceContext, loc: Location) {
110+
if self.set.insert(local) {
111+
trace!(?local, "escapes at {loc:?}");
112+
}
110113
}
111114

112115
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
113116
// Mirror the implementation in PreFlattenVisitor.
114-
if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
117+
if let &[PlaceElem::Field(..) | PlaceElem::ConstantIndex { from_end: false, .. }, ..] =
118+
&place.projection[..]
119+
{
115120
return;
116121
}
117122
self.super_place(place, context, location);
@@ -126,7 +131,7 @@ fn escaping_locals<'tcx>(
126131
if lvalue.as_local().is_some() {
127132
match rvalue {
128133
// Aggregate assignments are expanded in run_pass.
129-
Rvalue::Aggregate(..) | Rvalue::Use(..) => {
134+
Rvalue::Repeat(..) | Rvalue::Aggregate(..) | Rvalue::Use(..) => {
130135
self.visit_rvalue(rvalue, location);
131136
return;
132137
}
@@ -152,32 +157,62 @@ fn escaping_locals<'tcx>(
152157
}
153158
}
154159

160+
#[derive(Copy, Clone, Debug)]
161+
enum LocalMode<'tcx> {
162+
Field(Ty<'tcx>, Local),
163+
Index(Local),
164+
}
165+
166+
impl<'tcx> LocalMode<'tcx> {
167+
fn local(self) -> Local {
168+
match self {
169+
LocalMode::Field(_, l) | LocalMode::Index(l) => l,
170+
}
171+
}
172+
173+
fn elem(self, field: FieldIdx) -> PlaceElem<'tcx> {
174+
match self {
175+
LocalMode::Field(ty, _) => PlaceElem::Field(field, ty),
176+
LocalMode::Index(_) => PlaceElem::ConstantIndex {
177+
offset: field.as_u32() as u64,
178+
min_length: field.as_u32() as u64 + 1,
179+
from_end: false,
180+
},
181+
}
182+
}
183+
}
184+
155185
#[derive(Default, Debug)]
156186
struct ReplacementMap<'tcx> {
157187
/// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
158188
/// and deinit statement and debuginfo.
159-
fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<(Ty<'tcx>, Local)>>>>,
189+
fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<LocalMode<'tcx>>>>>,
160190
}
161191

162192
impl<'tcx> ReplacementMap<'tcx> {
163193
fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
164-
let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else {
194+
let &[first, ref rest @ ..] = place.projection else {
165195
return None;
166196
};
197+
let f = match first {
198+
PlaceElem::Field(f, _) => f,
199+
PlaceElem::ConstantIndex { offset, .. } => FieldIdx::from_u32(offset.try_into().ok()?),
200+
_ => return None,
201+
};
167202
let fields = self.fragments[place.local].as_ref()?;
168-
let (_, new_local) = fields[f]?;
203+
let new_local = fields[f]?.local();
169204
Some(Place { local: new_local, projection: tcx.mk_place_elems(rest) })
170205
}
171206

172207
fn place_fragments(
173208
&self,
174209
place: Place<'tcx>,
175-
) -> Option<impl Iterator<Item = (FieldIdx, Ty<'tcx>, Local)> + '_> {
210+
) -> Option<impl Iterator<Item = (FieldIdx, LocalMode<'tcx>)> + '_> {
176211
let local = place.as_local()?;
177212
let fields = self.fragments[local].as_ref()?;
178-
Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
179-
let (ty, local) = opt_ty_local?;
180-
Some((field, ty, local))
213+
Some(fields.iter_enumerated().filter_map(|(field, &local)| {
214+
let local = local?;
215+
Some((field, local))
181216
}))
182217
}
183218
}
@@ -200,15 +235,32 @@ fn compute_flattening<'tcx>(
200235
}
201236
let decl = body.local_decls[local].clone();
202237
let ty = decl.ty;
203-
iter_fields(ty, tcx, param_env, |variant, field, field_ty| {
204-
if variant.is_some() {
205-
// Downcasts are currently not supported.
206-
return;
207-
};
208-
let new_local =
209-
body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
210-
fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
211-
});
238+
if let ty::Array(inner, count) = ty.kind()
239+
&& let Some(count) = count.try_eval_target_usize(tcx, param_env)
240+
&& count <= 64
241+
{
242+
let fragments = fragments.get_or_insert_with(local, IndexVec::new);
243+
for field in 0..(count as u32) {
244+
let new_local =
245+
body.local_decls.push(LocalDecl { ty: *inner, user_ty: None, ..decl.clone() });
246+
fragments.insert(FieldIdx::from_u32(field), LocalMode::Index(new_local));
247+
}
248+
} else {
249+
iter_fields(ty, tcx, param_env, |variant, field, field_ty| {
250+
if variant.is_some() {
251+
// Downcasts are currently not supported.
252+
return;
253+
};
254+
let new_local = body.local_decls.push(LocalDecl {
255+
ty: field_ty,
256+
user_ty: None,
257+
..decl.clone()
258+
});
259+
fragments
260+
.get_or_insert_with(local, IndexVec::new)
261+
.insert(field, LocalMode::Field(field_ty, new_local));
262+
});
263+
}
212264
}
213265
ReplacementMap { fragments }
214266
}
@@ -284,14 +336,16 @@ impl<'tcx> ReplacementVisitor<'tcx, '_> {
284336
let ty = place.ty(self.local_decls, self.tcx).ty;
285337

286338
parts
287-
.map(|(field, field_ty, replacement_local)| {
339+
.map(|(field, replacement_local)| {
288340
let mut var_debug_info = var_debug_info.clone();
289341
let composite = var_debug_info.composite.get_or_insert_with(|| {
290342
Box::new(VarDebugInfoFragment { ty, projection: Vec::new() })
291343
});
292-
composite.projection.push(PlaceElem::Field(field, field_ty));
344+
let elem = replacement_local.elem(field);
345+
composite.projection.push(elem);
293346

294-
var_debug_info.value = VarDebugInfoContents::Place(replacement_local.into());
347+
let local = replacement_local.local();
348+
var_debug_info.value = VarDebugInfoContents::Place(local.into());
295349
var_debug_info
296350
})
297351
.collect()
@@ -318,7 +372,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
318372
// Duplicate storage and deinit statements, as they pretty much apply to all fields.
319373
StatementKind::StorageLive(l) => {
320374
if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
321-
for (_, _, fl) in final_locals {
375+
for (_, fl) in final_locals {
376+
let fl = fl.local();
322377
self.patch.add_statement(location, StatementKind::StorageLive(fl));
323378
}
324379
statement.make_nop();
@@ -327,7 +382,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
327382
}
328383
StatementKind::StorageDead(l) => {
329384
if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
330-
for (_, _, fl) in final_locals {
385+
for (_, fl) in final_locals {
386+
let fl = fl.local();
331387
self.patch.add_statement(location, StatementKind::StorageDead(fl));
332388
}
333389
statement.make_nop();
@@ -336,7 +392,8 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
336392
}
337393
StatementKind::Deinit(box place) => {
338394
if let Some(final_locals) = self.replacements.place_fragments(place) {
339-
for (_, _, fl) in final_locals {
395+
for (_, fl) in final_locals {
396+
let fl = fl.local();
340397
self.patch
341398
.add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
342399
}
@@ -345,6 +402,35 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
345402
}
346403
}
347404

405+
// We have `a = [x; N]`
406+
// We replace it by
407+
// ```
408+
// a_0 = x
409+
// a_1 = x
410+
// ...
411+
// ```
412+
StatementKind::Assign(box (place, Rvalue::Repeat(ref mut operand, _))) => {
413+
if let Some(local) = place.as_local()
414+
&& let Some(final_locals) = &self.replacements.fragments[local]
415+
{
416+
// Replace mentions of SROA'd locals that appear in the operand.
417+
self.visit_operand(&mut *operand, location);
418+
419+
for &new_local in final_locals.iter() {
420+
if let Some(new_local) = new_local {
421+
let new_local = new_local.local();
422+
let rvalue = Rvalue::Use(operand.to_copy());
423+
self.patch.add_statement(
424+
location,
425+
StatementKind::Assign(Box::new((new_local.into(), rvalue))),
426+
);
427+
}
428+
}
429+
statement.make_nop();
430+
return;
431+
}
432+
}
433+
348434
// We have `a = Struct { 0: x, 1: y, .. }`.
349435
// We replace it by
350436
// ```
@@ -358,8 +444,10 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
358444
{
359445
// This is ok as we delete the statement later.
360446
let operands = std::mem::take(operands);
361-
for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
362-
if let Some((_, new_local)) = opt_ty_local {
447+
for (&new_local, mut operand) in final_locals.iter().zip(operands) {
448+
if let Some(new_local) = new_local {
449+
let new_local = new_local.local();
450+
363451
// Replace mentions of SROA'd locals that appear in the operand.
364452
self.visit_operand(&mut operand, location);
365453

@@ -387,8 +475,10 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
387475
if let Some(final_locals) = self.replacements.place_fragments(place) {
388476
// Put the deaggregated statements *after* the original one.
389477
let location = location.successor_within_block();
390-
for (field, ty, new_local) in final_locals {
391-
let rplace = self.tcx.mk_place_field(place, field, ty);
478+
for (field, new_local) in final_locals {
479+
let elem = new_local.elem(field);
480+
let new_local = new_local.local();
481+
let rplace = self.tcx.mk_place_elem(place, elem);
392482
let rvalue = Rvalue::Use(Operand::Move(rplace));
393483
self.patch.add_statement(
394484
location,
@@ -414,8 +504,10 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
414504
Operand::Constant(_) => bug!(),
415505
};
416506
if let Some(final_locals) = self.replacements.place_fragments(lhs) {
417-
for (field, ty, new_local) in final_locals {
418-
let rplace = self.tcx.mk_place_field(rplace, field, ty);
507+
for (field, new_local) in final_locals {
508+
let elem = new_local.elem(field);
509+
let new_local = new_local.local();
510+
let rplace = self.tcx.mk_place_elem(rplace, elem);
419511
debug!(?rplace);
420512
let rplace = self
421513
.replacements

tests/mir-opt/sroa/lifetimes.foo.ScalarReplacementOfAggregates.diff

+3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
let mut _30: isize;
3232
+ let _31: std::result::Result<std::boxed::Box<dyn std::fmt::Display>, <T as Err>::Err>;
3333
+ let _32: u32;
34+
+ let _33: &str;
35+
+ let _34: &str;
36+
+ let _35: &str;
3437
scope 1 {
3538
- debug foo => _1;
3639
+ debug ((foo: Foo<T>).0: std::result::Result<std::boxed::Box<dyn std::fmt::Display>, <T as Err>::Err>) => _31;

0 commit comments

Comments
 (0)