|
12 | 12 | use crate::transform::{simplify, MirPass, MirSource};
|
13 | 13 | use itertools::Itertools as _;
|
14 | 14 | use rustc_index::vec::IndexVec;
|
| 15 | +use rustc_middle::mir::visit::{PlaceContext, Visitor}; |
15 | 16 | use rustc_middle::mir::*;
|
16 | 17 | use rustc_middle::ty::{Ty, TyCtxt};
|
17 | 18 | use rustc_target::abi::VariantIdx;
|
@@ -75,7 +76,9 @@ struct ArmIdentityInfo<'tcx> {
|
75 | 76 | stmts_to_remove: Vec<usize>,
|
76 | 77 | }
|
77 | 78 |
|
78 |
| -fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmIdentityInfo<'tcx>> { |
| 79 | +fn get_arm_identity_info<'a, 'tcx>( |
| 80 | + stmts: &'a [Statement<'tcx>], |
| 81 | +) -> Option<ArmIdentityInfo<'tcx>> { |
79 | 82 | // This can't possibly match unless there are at least 3 statements in the block
|
80 | 83 | // so fail fast on tiny blocks.
|
81 | 84 | if stmts.len() < 3 {
|
@@ -249,6 +252,7 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
|
249 | 252 | fn optimization_applies<'tcx>(
|
250 | 253 | opt_info: &ArmIdentityInfo<'tcx>,
|
251 | 254 | local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
|
| 255 | + local_uses: &IndexVec<Local, usize>, |
252 | 256 | ) -> bool {
|
253 | 257 | trace!("testing if optimization applies...");
|
254 | 258 |
|
@@ -285,6 +289,26 @@ fn optimization_applies<'tcx>(
|
285 | 289 | last_assigned_to = *l;
|
286 | 290 | }
|
287 | 291 |
|
| 292 | + // Check that the first and last used locals are only used twice |
| 293 | + // since they are of the form: |
| 294 | + // |
| 295 | + // ``` |
| 296 | + // _first = ((_x as Variant).n: ty); |
| 297 | + // _n = _first; |
| 298 | + // ... |
| 299 | + // ((_y as Variant).n: ty) = _n; |
| 300 | + // discriminant(_y) = z; |
| 301 | + // ``` |
| 302 | + for (l, r) in &opt_info.field_tmp_assignments { |
| 303 | + if local_uses[*l] != 2 { |
| 304 | + warn!("NO: FAILED assignment chain local {:?} was used more than twice", l); |
| 305 | + return false; |
| 306 | + } else if local_uses[*r] != 2 { |
| 307 | + warn!("NO: FAILED assignment chain local {:?} was used more than twice", r); |
| 308 | + return false; |
| 309 | + } |
| 310 | + } |
| 311 | + |
288 | 312 | if source_local != opt_info.local_temp_0 {
|
289 | 313 | trace!(
|
290 | 314 | "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
|
@@ -312,11 +336,12 @@ impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
|
312 | 336 | }
|
313 | 337 |
|
314 | 338 | trace!("running SimplifyArmIdentity on {:?}", source);
|
| 339 | + let local_uses = LocalUseCounter::get_local_uses(body); |
315 | 340 | let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();
|
316 | 341 | for bb in basic_blocks {
|
317 | 342 | if let Some(opt_info) = get_arm_identity_info(&bb.statements) {
|
318 | 343 | trace!("got opt_info = {:#?}", opt_info);
|
319 |
| - if !optimization_applies(&opt_info, local_decls) { |
| 344 | + if !optimization_applies(&opt_info, local_decls, &local_uses) { |
320 | 345 | debug!("optimization skipped for {:?}", source);
|
321 | 346 | continue;
|
322 | 347 | }
|
@@ -358,6 +383,28 @@ impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
|
358 | 383 | }
|
359 | 384 | }
|
360 | 385 |
|
| 386 | +struct LocalUseCounter { |
| 387 | + local_uses: IndexVec<Local, usize>, |
| 388 | +} |
| 389 | + |
| 390 | +impl LocalUseCounter { |
| 391 | + fn get_local_uses<'tcx>(body: &Body<'tcx>) -> IndexVec<Local, usize> { |
| 392 | + let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) }; |
| 393 | + counter.visit_body(body); |
| 394 | + counter.local_uses |
| 395 | + } |
| 396 | +} |
| 397 | + |
| 398 | +impl<'tcx> Visitor<'tcx> for LocalUseCounter { |
| 399 | + fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) { |
| 400 | + if context.is_storage_marker() { |
| 401 | + return; |
| 402 | + } |
| 403 | + |
| 404 | + self.local_uses[*local] += 1; |
| 405 | + } |
| 406 | +} |
| 407 | + |
361 | 408 | /// Match on:
|
362 | 409 | /// ```rust
|
363 | 410 | /// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
|
|
0 commit comments