Skip to content

Commit 6b56603

Browse files
committed
Auto merge of #80475 - simonvandel:fix-77355, r=oli-obk
New mir-opt pass to simplify gotos with const values (reopening #77486) Reopening PR #77486 Fixes #77355 This pass optimizes the following sequence ```rust bb2: { _2 = const true; goto -> bb3; } bb3: { switchInt(_2) -> [false: bb4, otherwise: bb5]; } ``` into ```rust bb2: { _2 = const true; goto -> bb5; } ```
2 parents 301ad8a + a6dccfe commit 6b56603

22 files changed

+652
-374
lines changed

compiler/rustc_middle/src/mir/mod.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -1518,7 +1518,14 @@ pub enum StatementKind<'tcx> {
15181518
}
15191519

15201520
impl<'tcx> StatementKind<'tcx> {
1521-
pub fn as_assign_mut(&mut self) -> Option<&mut Box<(Place<'tcx>, Rvalue<'tcx>)>> {
1521+
pub fn as_assign_mut(&mut self) -> Option<&mut (Place<'tcx>, Rvalue<'tcx>)> {
1522+
match self {
1523+
StatementKind::Assign(x) => Some(x),
1524+
_ => None,
1525+
}
1526+
}
1527+
1528+
pub fn as_assign(&self) -> Option<&(Place<'tcx>, Rvalue<'tcx>)> {
15221529
match self {
15231530
StatementKind::Assign(x) => Some(x),
15241531
_ => None,

compiler/rustc_middle/src/mir/terminator.rs

+16
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,22 @@ impl<'tcx> TerminatorKind<'tcx> {
407407
| TerminatorKind::FalseUnwind { ref mut unwind, .. } => Some(unwind),
408408
}
409409
}
410+
411+
pub fn as_switch(&self) -> Option<(&Operand<'tcx>, Ty<'tcx>, &SwitchTargets)> {
412+
match self {
413+
TerminatorKind::SwitchInt { discr, switch_ty, targets } => {
414+
Some((discr, switch_ty, targets))
415+
}
416+
_ => None,
417+
}
418+
}
419+
420+
pub fn as_goto(&self) -> Option<BasicBlock> {
421+
match self {
422+
TerminatorKind::Goto { target } => Some(*target),
423+
_ => None,
424+
}
425+
}
410426
}
411427

412428
impl<'tcx> Debug for TerminatorKind<'tcx> {

compiler/rustc_mir/src/interpret/operand.rs

+1
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
514514
/// Evaluate the operand, returning a place where you can then find the data.
515515
/// If you already know the layout, you can save two table lookups
516516
/// by passing it in here.
517+
#[inline]
517518
pub fn eval_operand(
518519
&self,
519520
mir_op: &mir::Operand<'tcx>,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//! This pass optimizes the following sequence
2+
//! ```rust,ignore (example)
3+
//! bb2: {
4+
//! _2 = const true;
5+
//! goto -> bb3;
6+
//! }
7+
//!
8+
//! bb3: {
9+
//! switchInt(_2) -> [false: bb4, otherwise: bb5];
10+
//! }
11+
//! ```
12+
//! into
13+
//! ```rust,ignore (example)
14+
//! bb2: {
15+
//! _2 = const true;
16+
//! goto -> bb5;
17+
//! }
18+
//! ```
19+
20+
use crate::transform::MirPass;
21+
use rustc_middle::mir::*;
22+
use rustc_middle::ty::TyCtxt;
23+
use rustc_middle::{mir::visit::Visitor, ty::ParamEnv};
24+
25+
use super::simplify::{simplify_cfg, simplify_locals};
26+
27+
pub struct ConstGoto;
28+
29+
impl<'tcx> MirPass<'tcx> for ConstGoto {
30+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
31+
if tcx.sess.opts.debugging_opts.mir_opt_level < 3 {
32+
return;
33+
}
34+
trace!("Running ConstGoto on {:?}", body.source);
35+
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
36+
let mut opt_finder =
37+
ConstGotoOptimizationFinder { tcx, body, optimizations: vec![], param_env };
38+
opt_finder.visit_body(body);
39+
let should_simplify = !opt_finder.optimizations.is_empty();
40+
for opt in opt_finder.optimizations {
41+
let terminator = body.basic_blocks_mut()[opt.bb_with_goto].terminator_mut();
42+
let new_goto = TerminatorKind::Goto { target: opt.target_to_use_in_goto };
43+
debug!("SUCCESS: replacing `{:?}` with `{:?}`", terminator.kind, new_goto);
44+
terminator.kind = new_goto;
45+
}
46+
47+
// if we applied optimizations, we potentially have some cfg to cleanup to
48+
// make it easier for further passes
49+
if should_simplify {
50+
simplify_cfg(body);
51+
simplify_locals(body, tcx);
52+
}
53+
}
54+
}
55+
56+
impl<'a, 'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'a, 'tcx> {
57+
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
58+
let _: Option<_> = try {
59+
let target = terminator.kind.as_goto()?;
60+
// We only apply this optimization if the last statement is a const assignment
61+
let last_statement = self.body.basic_blocks()[location.block].statements.last()?;
62+
63+
if let (place, Rvalue::Use(Operand::Constant(_const))) =
64+
last_statement.kind.as_assign()?
65+
{
66+
// We found a constant being assigned to `place`.
67+
// Now check that the target of this Goto switches on this place.
68+
let target_bb = &self.body.basic_blocks()[target];
69+
70+
// FIXME(simonvandel): We are conservative here when we don't allow
71+
// any statements in the target basic block.
72+
// This could probably be relaxed to allow `StorageDead`s which could be
73+
// copied to the predecessor of this block.
74+
if !target_bb.statements.is_empty() {
75+
None?
76+
}
77+
78+
let target_bb_terminator = target_bb.terminator();
79+
let (discr, switch_ty, targets) = target_bb_terminator.kind.as_switch()?;
80+
if discr.place() == Some(*place) {
81+
// We now know that the Switch matches on the const place, and it is statementless
82+
// Now find which value in the Switch matches the const value.
83+
let const_value =
84+
_const.literal.try_eval_bits(self.tcx, self.param_env, switch_ty)?;
85+
let found_value_idx_option = targets
86+
.iter()
87+
.enumerate()
88+
.find(|(_, (value, _))| const_value == *value)
89+
.map(|(idx, _)| idx);
90+
91+
let target_to_use_in_goto =
92+
if let Some(found_value_idx) = found_value_idx_option {
93+
targets.iter().nth(found_value_idx).unwrap().1
94+
} else {
95+
// If we did not find the const value in values, it must be the otherwise case
96+
targets.otherwise()
97+
};
98+
99+
self.optimizations.push(OptimizationToApply {
100+
bb_with_goto: location.block,
101+
target_to_use_in_goto,
102+
});
103+
}
104+
}
105+
Some(())
106+
};
107+
108+
self.super_terminator(terminator, location);
109+
}
110+
}
111+
112+
struct OptimizationToApply {
113+
bb_with_goto: BasicBlock,
114+
target_to_use_in_goto: BasicBlock,
115+
}
116+
117+
pub struct ConstGotoOptimizationFinder<'a, 'tcx> {
118+
tcx: TyCtxt<'tcx>,
119+
body: &'a Body<'tcx>,
120+
param_env: ParamEnv<'tcx>,
121+
optimizations: Vec<OptimizationToApply>,
122+
}

compiler/rustc_mir/src/transform/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub mod check_packed_ref;
2222
pub mod check_unsafety;
2323
pub mod cleanup_post_borrowck;
2424
pub mod const_debuginfo;
25+
pub mod const_goto;
2526
pub mod const_prop;
2627
pub mod coverage;
2728
pub mod deaggregator;
@@ -492,6 +493,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
492493

493494
// The main optimizations that we do on MIR.
494495
let optimizations: &[&dyn MirPass<'tcx>] = &[
496+
&const_goto::ConstGoto,
495497
&remove_unneeded_drops::RemoveUnneededDrops,
496498
&match_branches::MatchBranchSimplification,
497499
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)

compiler/rustc_mir/src/transform/simplify.rs

+20-17
Original file line numberDiff line numberDiff line change
@@ -320,28 +320,31 @@ pub struct SimplifyLocals;
320320
impl<'tcx> MirPass<'tcx> for SimplifyLocals {
321321
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
322322
trace!("running SimplifyLocals on {:?}", body.source);
323+
simplify_locals(body, tcx);
324+
}
325+
}
323326

324-
// First, we're going to get a count of *actual* uses for every `Local`.
325-
let mut used_locals = UsedLocals::new(body);
327+
pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) {
328+
// First, we're going to get a count of *actual* uses for every `Local`.
329+
let mut used_locals = UsedLocals::new(body);
326330

327-
// Next, we're going to remove any `Local` with zero actual uses. When we remove those
328-
// `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
329-
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
330-
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
331-
// fixedpoint where there are no more unused locals.
332-
remove_unused_definitions(&mut used_locals, body);
331+
// Next, we're going to remove any `Local` with zero actual uses. When we remove those
332+
// `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
333+
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
334+
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
335+
// fixedpoint where there are no more unused locals.
336+
remove_unused_definitions(&mut used_locals, body);
333337

334-
// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
335-
let map = make_local_map(&mut body.local_decls, &used_locals);
338+
// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
339+
let map = make_local_map(&mut body.local_decls, &used_locals);
336340

337-
// Only bother running the `LocalUpdater` if we actually found locals to remove.
338-
if map.iter().any(Option::is_none) {
339-
// Update references to all vars and tmps now
340-
let mut updater = LocalUpdater { map, tcx };
341-
updater.visit_body(body);
341+
// Only bother running the `LocalUpdater` if we actually found locals to remove.
342+
if map.iter().any(Option::is_none) {
343+
// Update references to all vars and tmps now
344+
let mut updater = LocalUpdater { map, tcx };
345+
updater.visit_body(body);
342346

343-
body.local_decls.shrink_to_fit();
344-
}
347+
body.local_decls.shrink_to_fit();
345348
}
346349
}
347350

compiler/rustc_mir/src/transform/simplify_comparison_integral.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ impl<'tcx> MirPass<'tcx> for SimplifyComparisonIntegral {
8080
// we convert the move in the comparison statement to a copy.
8181

8282
// unwrap is safe as we know this statement is an assign
83-
let box (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();
83+
let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();
8484

8585
use Operand::*;
8686
match rhs {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
- // MIR for `issue_77355_opt` before ConstGoto
2+
+ // MIR for `issue_77355_opt` after ConstGoto
3+
4+
fn issue_77355_opt(_1: Foo) -> u64 {
5+
debug num => _1; // in scope 0 at $DIR/const_goto.rs:11:20: 11:23
6+
let mut _0: u64; // return place in scope 0 at $DIR/const_goto.rs:11:33: 11:36
7+
- let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
8+
- let mut _3: isize; // in scope 0 at $DIR/const_goto.rs:12:22: 12:28
9+
+ let mut _2: isize; // in scope 0 at $DIR/const_goto.rs:12:22: 12:28
10+
11+
bb0: {
12+
- StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
13+
- _3 = discriminant(_1); // scope 0 at $DIR/const_goto.rs:12:22: 12:28
14+
- switchInt(move _3) -> [1_isize: bb2, 2_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto.rs:12:22: 12:28
15+
+ _2 = discriminant(_1); // scope 0 at $DIR/const_goto.rs:12:22: 12:28
16+
+ switchInt(move _2) -> [1_isize: bb2, 2_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto.rs:12:22: 12:28
17+
}
18+
19+
bb1: {
20+
- _2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
21+
- goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
22+
+ _0 = const 42_u64; // scope 0 at $DIR/const_goto.rs:12:53: 12:55
23+
+ goto -> bb3; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
24+
}
25+
26+
bb2: {
27+
- _2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
28+
- goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
29+
- }
30+
-
31+
- bb3: {
32+
- switchInt(move _2) -> [false: bb5, otherwise: bb4]; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
33+
- }
34+
-
35+
- bb4: {
36+
_0 = const 23_u64; // scope 0 at $DIR/const_goto.rs:12:41: 12:43
37+
- goto -> bb6; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
38+
+ goto -> bb3; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
39+
}
40+
41+
- bb5: {
42+
- _0 = const 42_u64; // scope 0 at $DIR/const_goto.rs:12:53: 12:55
43+
- goto -> bb6; // scope 0 at $DIR/const_goto.rs:12:5: 12:57
44+
- }
45+
-
46+
- bb6: {
47+
- StorageDead(_2); // scope 0 at $DIR/const_goto.rs:12:56: 12:57
48+
+ bb3: {
49+
return; // scope 0 at $DIR/const_goto.rs:13:2: 13:2
50+
}
51+
}
52+

src/test/mir-opt/const_goto.rs

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
pub enum Foo {
2+
A,
3+
B,
4+
C,
5+
D,
6+
E,
7+
F,
8+
}
9+
10+
// EMIT_MIR const_goto.issue_77355_opt.ConstGoto.diff
11+
fn issue_77355_opt(num: Foo) -> u64 {
12+
if matches!(num, Foo::B | Foo::C) { 23 } else { 42 }
13+
}
14+
fn main() {
15+
issue_77355_opt(Foo::A);
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
- // MIR for `f` before ConstGoto
2+
+ // MIR for `f` after ConstGoto
3+
4+
fn f() -> u64 {
5+
let mut _0: u64; // return place in scope 0 at $DIR/const_goto_const_eval_fail.rs:6:44: 6:47
6+
let mut _1: bool; // in scope 0 at $DIR/const_goto_const_eval_fail.rs:7:11: 12:6
7+
let mut _2: i32; // in scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16
8+
9+
bb0: {
10+
StorageLive(_1); // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:11: 12:6
11+
StorageLive(_2); // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16
12+
_2 = const A; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:15: 8:16
13+
switchInt(_2) -> [1_i32: bb2, 2_i32: bb2, 3_i32: bb2, otherwise: bb1]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:9:13: 9:14
14+
}
15+
16+
bb1: {
17+
_1 = const true; // scope 0 at $DIR/const_goto_const_eval_fail.rs:10:18: 10:22
18+
goto -> bb3; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:9: 11:10
19+
}
20+
21+
bb2: {
22+
_1 = const B; // scope 0 at $DIR/const_goto_const_eval_fail.rs:9:26: 9:27
23+
- goto -> bb3; // scope 0 at $DIR/const_goto_const_eval_fail.rs:8:9: 11:10
24+
+ switchInt(_1) -> [false: bb4, otherwise: bb3]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:9: 13:14
25+
}
26+
27+
bb3: {
28+
- switchInt(_1) -> [false: bb5, otherwise: bb4]; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:9: 13:14
29+
- }
30+
-
31+
- bb4: {
32+
_0 = const 2_u64; // scope 0 at $DIR/const_goto_const_eval_fail.rs:14:17: 14:18
33+
- goto -> bb6; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
34+
+ goto -> bb5; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
35+
}
36+
37+
- bb5: {
38+
+ bb4: {
39+
_0 = const 1_u64; // scope 0 at $DIR/const_goto_const_eval_fail.rs:13:18: 13:19
40+
- goto -> bb6; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
41+
+ goto -> bb5; // scope 0 at $DIR/const_goto_const_eval_fail.rs:7:5: 15:6
42+
}
43+
44+
- bb6: {
45+
+ bb5: {
46+
StorageDead(_2); // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:1: 16:2
47+
StorageDead(_1); // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:1: 16:2
48+
return; // scope 0 at $DIR/const_goto_const_eval_fail.rs:16:2: 16:2
49+
}
50+
}
51+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#![feature(min_const_generics)]
2+
#![crate_type = "lib"]
3+
4+
// If const eval fails, then don't crash
5+
// EMIT_MIR const_goto_const_eval_fail.f.ConstGoto.diff
6+
pub fn f<const A: i32, const B: bool>() -> u64 {
7+
match {
8+
match A {
9+
1 | 2 | 3 => B,
10+
_ => true,
11+
}
12+
} {
13+
false => 1,
14+
true => 2,
15+
}
16+
}

0 commit comments

Comments
 (0)