Skip to content

Commit 2985618

Browse files
committed
Simplify ReplacementMap.
1 parent 8e05ab0 commit 2985618

File tree

1 file changed

+109
-84
lines changed
  • compiler/rustc_mir_transform/src

1 file changed

+109
-84
lines changed

compiler/rustc_mir_transform/src/sroa.rs

+109-84
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
use crate::MirPass;
2-
use rustc_data_structures::fx::FxIndexMap;
32
use rustc_index::bit_set::BitSet;
43
use rustc_index::vec::IndexVec;
54
use rustc_middle::mir::patch::MirPatch;
65
use rustc_middle::mir::visit::*;
76
use rustc_middle::mir::*;
8-
use rustc_middle::ty::TyCtxt;
7+
use rustc_middle::ty::{Ty, TyCtxt};
98
use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
109

1110
pub struct ScalarReplacementOfAggregates;
@@ -26,13 +25,13 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
2625
let replacements = compute_flattening(tcx, body, escaping);
2726
debug!(?replacements);
2827
let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
29-
if !all_dead_locals.is_empty() && tcx.sess.mir_opt_level() >= 4 {
28+
if !all_dead_locals.is_empty() {
3029
for local in excluded.indices() {
31-
excluded[local] |= all_dead_locals.contains(local) ;
30+
excluded[local] |= all_dead_locals.contains(local);
3231
}
3332
excluded.raw.resize(body.local_decls.len(), false);
3433
} else {
35-
break
34+
break;
3635
}
3736
}
3837
}
@@ -111,36 +110,29 @@ fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<
111110

112111
#[derive(Default, Debug)]
113112
struct ReplacementMap<'tcx> {
114-
fields: FxIndexMap<PlaceRef<'tcx>, Local>,
115113
/// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
116114
/// and deinit statement and debuginfo.
117-
fragments: IndexVec<Local, Option<Vec<(&'tcx [PlaceElem<'tcx>], Local)>>>,
115+
fragments: IndexVec<Local, Option<IndexVec<Field, Option<(Ty<'tcx>, Local)>>>>,
118116
}
119117

120118
impl<'tcx> ReplacementMap<'tcx> {
121-
fn gather_debug_info_fragments(
122-
&self,
123-
place: PlaceRef<'tcx>,
124-
) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
125-
let mut fragments = Vec::new();
126-
let Some(parts) = &self.fragments[place.local] else { return None };
127-
for (proj, replacement_local) in parts {
128-
if proj.starts_with(place.projection) {
129-
fragments.push(VarDebugInfoFragment {
130-
projection: proj[place.projection.len()..].to_vec(),
131-
contents: Place::from(*replacement_local),
132-
});
133-
}
134-
}
135-
Some(fragments)
119+
fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
120+
let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; };
121+
let fields = self.fragments[place.local].as_ref()?;
122+
let (_, new_local) = fields[f]?;
123+
Some(Place { local: new_local, projection: tcx.intern_place_elems(&rest) })
136124
}
137125

138126
fn place_fragments(
139127
&self,
140128
place: Place<'tcx>,
141-
) -> Option<&Vec<(&'tcx [PlaceElem<'tcx>], Local)>> {
129+
) -> Option<impl Iterator<Item = (Field, Ty<'tcx>, Local)> + '_> {
142130
let local = place.as_local()?;
143-
self.fragments[local].as_ref()
131+
let fields = self.fragments[local].as_ref()?;
132+
Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
133+
let (ty, local) = opt_ty_local?;
134+
Some((field, ty, local))
135+
}))
144136
}
145137
}
146138

@@ -153,8 +145,7 @@ fn compute_flattening<'tcx>(
153145
body: &mut Body<'tcx>,
154146
escaping: BitSet<Local>,
155147
) -> ReplacementMap<'tcx> {
156-
let mut fields = FxIndexMap::default();
157-
let mut fragments = IndexVec::from_elem(None::<Vec<_>>, &body.local_decls);
148+
let mut fragments = IndexVec::from_elem(None, &body.local_decls);
158149

159150
for local in body.local_decls.indices() {
160151
if escaping.contains(local) {
@@ -169,14 +160,10 @@ fn compute_flattening<'tcx>(
169160
};
170161
let new_local =
171162
body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
172-
let place = Place::from(local)
173-
.project_deeper(&[PlaceElem::Field(field, field_ty)], tcx)
174-
.as_ref();
175-
fields.insert(place, new_local);
176-
fragments[local].get_or_insert_default().push((place.projection, new_local));
163+
fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
177164
});
178165
}
179-
ReplacementMap { fields, fragments }
166+
ReplacementMap { fragments }
180167
}
181168

182169
/// Perform the replacement computed by `compute_flattening`.
@@ -186,8 +173,10 @@ fn replace_flattened_locals<'tcx>(
186173
replacements: ReplacementMap<'tcx>,
187174
) -> BitSet<Local> {
188175
let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
189-
for p in replacements.fields.keys() {
190-
all_dead_locals.insert(p.local);
176+
for (local, replacements) in replacements.fragments.iter_enumerated() {
177+
if replacements.is_some() {
178+
all_dead_locals.insert(local);
179+
}
191180
}
192181
debug!(?all_dead_locals);
193182
if all_dead_locals.is_empty() {
@@ -197,7 +186,7 @@ fn replace_flattened_locals<'tcx>(
197186
let mut visitor = ReplacementVisitor {
198187
tcx,
199188
local_decls: &body.local_decls,
200-
replacements,
189+
replacements: &replacements,
201190
all_dead_locals,
202191
patch: MirPatch::new(body),
203192
};
@@ -223,21 +212,23 @@ struct ReplacementVisitor<'tcx, 'll> {
223212
/// This is only used to compute the type for `VarDebugInfoContents::Composite`.
224213
local_decls: &'ll LocalDecls<'tcx>,
225214
/// Work to do.
226-
replacements: ReplacementMap<'tcx>,
215+
replacements: &'ll ReplacementMap<'tcx>,
227216
/// This is used to check that we are not leaving references to replaced locals behind.
228217
all_dead_locals: BitSet<Local>,
229218
patch: MirPatch<'tcx>,
230219
}
231220

232-
impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
233-
fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
234-
if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
235-
let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
236-
let local = self.replacements.fields.get(&pr)?;
237-
Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
238-
} else {
239-
None
221+
impl<'tcx> ReplacementVisitor<'tcx, '_> {
222+
fn gather_debug_info_fragments(&self, local: Local) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
223+
let mut fragments = Vec::new();
224+
let parts = self.replacements.place_fragments(local.into())?;
225+
for (field, ty, replacement_local) in parts {
226+
fragments.push(VarDebugInfoFragment {
227+
projection: vec![PlaceElem::Field(field, ty)],
228+
contents: Place::from(replacement_local),
229+
});
240230
}
231+
Some(fragments)
241232
}
242233
}
243234

@@ -246,21 +237,30 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
246237
self.tcx
247238
}
248239

240+
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
241+
if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
242+
*place = repl
243+
} else {
244+
self.super_place(place, context, location)
245+
}
246+
}
247+
249248
#[instrument(level = "trace", skip(self))]
250249
fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
251250
match statement.kind {
251+
// Duplicate storage and deinit statements, as they pretty much apply to all fields.
252252
StatementKind::StorageLive(l) => {
253-
if let Some(final_locals) = &self.replacements.fragments[l] {
254-
for &(_, fl) in final_locals {
253+
if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
254+
for (_, _, fl) in final_locals {
255255
self.patch.add_statement(location, StatementKind::StorageLive(fl));
256256
}
257257
statement.make_nop();
258258
}
259259
return;
260260
}
261261
StatementKind::StorageDead(l) => {
262-
if let Some(final_locals) = &self.replacements.fragments[l] {
263-
for &(_, fl) in final_locals {
262+
if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
263+
for (_, _, fl) in final_locals {
264264
self.patch.add_statement(location, StatementKind::StorageDead(fl));
265265
}
266266
statement.make_nop();
@@ -269,7 +269,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
269269
}
270270
StatementKind::Deinit(box place) => {
271271
if let Some(final_locals) = self.replacements.place_fragments(place) {
272-
for &(_, fl) in final_locals {
272+
for (_, _, fl) in final_locals {
273273
self.patch
274274
.add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
275275
}
@@ -278,48 +278,80 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
278278
}
279279
}
280280

281-
StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref operands))) => {
282-
if let Some(final_locals) = self.replacements.place_fragments(place) {
283-
for &(projection, fl) in final_locals {
284-
let &[PlaceElem::Field(index, _)] = projection else { bug!() };
285-
let index = index.as_usize();
286-
let rvalue = Rvalue::Use(operands[index].clone());
287-
self.patch.add_statement(
288-
location,
289-
StatementKind::Assign(Box::new((fl.into(), rvalue))),
290-
);
281+
// We have `a = Struct { 0: x, 1: y, .. }`.
282+
// We replace it by
283+
// ```
284+
// a_0 = x
285+
// a_1 = y
286+
// ...
287+
// ```
288+
StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
289+
if let Some(local) = place.as_local()
290+
&& let Some(final_locals) = &self.replacements.fragments[local]
291+
{
292+
// This is ok as we delete the statement later.
293+
let operands = std::mem::take(operands);
294+
for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
295+
if let Some((_, new_local)) = opt_ty_local {
296+
// Replace mentions of SROA'd locals that appear in the operand.
297+
self.visit_operand(&mut operand, location);
298+
299+
let rvalue = Rvalue::Use(operand);
300+
self.patch.add_statement(
301+
location,
302+
StatementKind::Assign(Box::new((new_local.into(), rvalue))),
303+
);
304+
}
291305
}
292306
statement.make_nop();
293307
return;
294308
}
295309
}
296310

311+
// We have `a = some constant`
312+
// We add the projections.
313+
// ```
314+
// a_0 = a.0
315+
// a_1 = a.1
316+
// ...
317+
// ```
318+
// ConstProp will pick up the pieces and replace them by actual constants.
297319
StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
298320
if let Some(final_locals) = self.replacements.place_fragments(place) {
299-
for &(projection, fl) in final_locals {
300-
let rvalue =
301-
Rvalue::Use(Operand::Move(place.project_deeper(projection, self.tcx)));
321+
for (field, ty, new_local) in final_locals {
322+
let rplace = self.tcx.mk_place_field(place, field, ty);
323+
let rvalue = Rvalue::Use(Operand::Move(rplace));
302324
self.patch.add_statement(
303325
location,
304-
StatementKind::Assign(Box::new((fl.into(), rvalue))),
326+
StatementKind::Assign(Box::new((new_local.into(), rvalue))),
305327
);
306328
}
307-
self.all_dead_locals.remove(place.local);
329+
// We still need `place.local` to exist, so don't make it nop.
308330
return;
309331
}
310332
}
311333

334+
// We have `a = move? place`
335+
// We replace it by
336+
// ```
337+
// a_0 = move? place.0
338+
// a_1 = move? place.1
339+
// ...
340+
// ```
312341
StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
313-
let (rplace, copy) = match op {
342+
let (rplace, copy) = match *op {
314343
Operand::Copy(rplace) => (rplace, true),
315344
Operand::Move(rplace) => (rplace, false),
316345
Operand::Constant(_) => bug!(),
317346
};
318347
if let Some(final_locals) = self.replacements.place_fragments(lhs) {
319-
for &(projection, fl) in final_locals {
320-
let rplace = rplace.project_deeper(projection, self.tcx);
348+
for (field, ty, new_local) in final_locals {
349+
let rplace = self.tcx.mk_place_field(rplace, field, ty);
321350
debug!(?rplace);
322-
let rplace = self.replace_place(rplace.as_ref()).unwrap_or(rplace);
351+
let rplace = self
352+
.replacements
353+
.replace_place(self.tcx, rplace.as_ref())
354+
.unwrap_or(rplace);
323355
debug!(?rplace);
324356
let rvalue = if copy {
325357
Rvalue::Use(Operand::Copy(rplace))
@@ -328,7 +360,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
328360
};
329361
self.patch.add_statement(
330362
location,
331-
StatementKind::Assign(Box::new((fl.into(), rvalue))),
363+
StatementKind::Assign(Box::new((new_local.into(), rvalue))),
332364
);
333365
}
334366
statement.make_nop();
@@ -341,22 +373,14 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
341373
self.super_statement(statement, location)
342374
}
343375

344-
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
345-
if let Some(repl) = self.replace_place(place.as_ref()) {
346-
*place = repl
347-
} else {
348-
self.super_place(place, context, location)
349-
}
350-
}
351-
352376
#[instrument(level = "trace", skip(self))]
353377
fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
354378
match &mut var_debug_info.value {
355379
VarDebugInfoContents::Place(ref mut place) => {
356-
if let Some(repl) = self.replace_place(place.as_ref()) {
380+
if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
357381
*place = repl;
358-
} else if let Some(fragments) =
359-
self.replacements.gather_debug_info_fragments(place.as_ref())
382+
} else if let Some(local) = place.as_local()
383+
&& let Some(fragments) = self.gather_debug_info_fragments(local)
360384
{
361385
let ty = place.ty(self.local_decls, self.tcx).ty;
362386
var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
@@ -367,12 +391,13 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
367391
debug!(?fragments);
368392
fragments
369393
.drain_filter(|fragment| {
370-
if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
394+
if let Some(repl) =
395+
self.replacements.replace_place(self.tcx, fragment.contents.as_ref())
396+
{
371397
fragment.contents = repl;
372398
false
373-
} else if let Some(frg) = self
374-
.replacements
375-
.gather_debug_info_fragments(fragment.contents.as_ref())
399+
} else if let Some(local) = fragment.contents.as_local()
400+
&& let Some(frg) = self.gather_debug_info_fragments(local)
376401
{
377402
new_fragments.extend(frg.into_iter().map(|mut f| {
378403
f.projection.splice(0..0, fragment.projection.iter().copied());

0 commit comments

Comments
 (0)