Skip to content

Commit eccc782

Browse files
committed
Transforms match into an assignment statement
1 parent beadbcb commit eccc782

9 files changed

+364
-115
lines changed

compiler/rustc_middle/src/mir/terminator.rs

+6
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ impl SwitchTargets {
7474
pub fn target_for_value(&self, value: u128) -> BasicBlock {
7575
self.iter().find_map(|(v, t)| (v == value).then_some(t)).unwrap_or_else(|| self.otherwise())
7676
}
77+
78+
/// Returns true if all targets (including the fallback target) are distinct.
79+
#[inline]
80+
pub fn is_distinct(&self) -> bool {
81+
self.targets.iter().collect::<FxHashSet<_>>().len() == self.targets.len()
82+
}
7783
}
7884

7985
pub struct SwitchTargetsIter<'a> {

compiler/rustc_mir_transform/src/match_branches.rs

+217-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use rustc_index::IndexVec;
22
use rustc_middle::mir::*;
3-
use rustc_middle::ty::{ParamEnv, Ty, TyCtxt};
3+
use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
44
use std::iter;
55

66
use super::simplify::simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
3838
should_cleanup = true;
3939
continue;
4040
}
41+
if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env)
42+
{
43+
should_cleanup = true;
44+
continue;
45+
}
4146
}
4247

4348
if should_cleanup {
@@ -48,7 +53,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
4853

4954
trait SimplifyMatch<'tcx> {
5055
fn simplify(
51-
&self,
56+
&mut self,
5257
tcx: TyCtxt<'tcx>,
5358
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
5459
bbs: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
@@ -72,7 +77,7 @@ trait SimplifyMatch<'tcx> {
7277
let source_info = bbs[switch_bb_idx].terminator().source_info;
7378
let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span));
7479

75-
// We already checked that first and second are different blocks,
80+
// We already checked that targets are different blocks,
7681
// and bb_idx has a different terminator from both of them.
7782
let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty);
7883
let (_, first) = targets.iter().next().unwrap();
@@ -91,7 +96,7 @@ trait SimplifyMatch<'tcx> {
9196
}
9297

9398
fn can_simplify(
94-
&self,
99+
&mut self,
95100
tcx: TyCtxt<'tcx>,
96101
targets: &SwitchTargets,
97102
param_env: ParamEnv<'tcx>,
@@ -144,7 +149,7 @@ struct SimplifyToIf;
144149
/// ```
145150
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
146151
fn can_simplify(
147-
&self,
152+
&mut self,
148153
tcx: TyCtxt<'tcx>,
149154
targets: &SwitchTargets,
150155
param_env: ParamEnv<'tcx>,
@@ -250,3 +255,210 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250255
new_stmts.collect()
251256
}
252257
}
258+
259+
#[derive(Default)]
260+
struct SimplifyToExp {
261+
transfrom_types: Vec<TransfromType>,
262+
}
263+
264+
#[derive(Clone, Copy)]
265+
enum CompareType<'tcx, 'a> {
266+
Same(&'a StatementKind<'tcx>),
267+
Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
268+
Discr(&'a Place<'tcx>, Ty<'tcx>),
269+
}
270+
271+
enum TransfromType {
272+
Same,
273+
Eq,
274+
Discr,
275+
}
276+
277+
impl From<CompareType<'_, '_>> for TransfromType {
278+
fn from(compare_type: CompareType<'_, '_>) -> Self {
279+
match compare_type {
280+
CompareType::Same(_) => TransfromType::Same,
281+
CompareType::Eq(_, _, _) => TransfromType::Eq,
282+
CompareType::Discr(_, _) => TransfromType::Discr,
283+
}
284+
}
285+
}
286+
287+
/// If we find that the value of match is the same as the assignment,
288+
/// merge a target block statements into the source block,
289+
/// using cast to transform different integer types.
290+
///
291+
/// For example:
292+
///
293+
/// ```ignore (MIR)
294+
/// bb0: {
295+
/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
296+
/// }
297+
///
298+
/// bb1: {
299+
/// unreachable;
300+
/// }
301+
///
302+
/// bb2: {
303+
/// _0 = const 1_i16;
304+
/// goto -> bb5;
305+
/// }
306+
///
307+
/// bb3: {
308+
/// _0 = const 2_i16;
309+
/// goto -> bb5;
310+
/// }
311+
///
312+
/// bb4: {
313+
/// _0 = const 3_i16;
314+
/// goto -> bb5;
315+
/// }
316+
/// ```
317+
///
318+
/// into:
319+
///
320+
/// ```ignore (MIR)
321+
/// bb0: {
322+
/// _0 = _3 as i16 (IntToInt);
323+
/// goto -> bb5;
324+
/// }
325+
/// ```
326+
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
327+
fn can_simplify(
328+
&mut self,
329+
tcx: TyCtxt<'tcx>,
330+
targets: &SwitchTargets,
331+
param_env: ParamEnv<'tcx>,
332+
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
333+
) -> bool {
334+
if targets.iter().len() < 2 || targets.iter().len() > 64 {
335+
return false;
336+
}
337+
// We require that the possible target blocks all be distinct.
338+
if !targets.is_distinct() {
339+
return false;
340+
}
341+
if !bbs[targets.otherwise()].is_empty_unreachable() {
342+
return false;
343+
}
344+
let mut iter = targets.iter();
345+
let (first_val, first_target) = iter.next().unwrap();
346+
let first_terminator_kind = &bbs[first_target].terminator().kind;
347+
// Check that destinations are identical, and if not, then don't optimize this block
348+
if !targets
349+
.iter()
350+
.all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
351+
{
352+
return false;
353+
}
354+
355+
let first_stmts = &bbs[first_target].statements;
356+
let (second_val, second_target) = iter.next().unwrap();
357+
let second_stmts = &bbs[second_target].statements;
358+
if first_stmts.len() != second_stmts.len() {
359+
return false;
360+
}
361+
362+
let mut compare_types = Vec::new();
363+
for (f, s) in iter::zip(first_stmts, second_stmts) {
364+
let compare_type = match (&f.kind, &s.kind) {
365+
// If two statements are exactly the same, we can optimize.
366+
(f_s, s_s) if f_s == s_s => CompareType::Same(f_s),
367+
368+
// If two statements are assignments with the match values to the same place, we can optimize.
369+
(
370+
StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
371+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
372+
) if lhs_f == lhs_s
373+
&& f_c.const_.ty() == s_c.const_.ty()
374+
&& f_c.const_.ty().is_integral() =>
375+
{
376+
match (
377+
f_c.const_.try_eval_scalar_int(tcx, param_env),
378+
s_c.const_.try_eval_scalar_int(tcx, param_env),
379+
) {
380+
(Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
381+
(Some(f), Some(s))
382+
if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
383+
&& Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
384+
{
385+
CompareType::Discr(lhs_f, f_c.const_.ty())
386+
}
387+
_ => return false,
388+
}
389+
}
390+
391+
// Otherwise we cannot optimize. Try another block.
392+
_ => return false,
393+
};
394+
compare_types.push(compare_type);
395+
}
396+
397+
for (other_val, other_target) in iter {
398+
let other_stmts = &bbs[other_target].statements;
399+
if compare_types.len() != other_stmts.len() {
400+
return false;
401+
}
402+
for (f, s) in iter::zip(&compare_types, other_stmts) {
403+
match (*f, &s.kind) {
404+
(CompareType::Same(f_s), s_s) if f_s == s_s => {}
405+
(
406+
CompareType::Eq(lhs_f, f_ty, val),
407+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
408+
) if lhs_f == lhs_s
409+
&& s_c.const_.ty() == f_ty
410+
&& s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
411+
(
412+
CompareType::Discr(lhs_f, f_ty),
413+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
414+
) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
415+
let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
416+
return false;
417+
};
418+
if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
419+
return false;
420+
}
421+
}
422+
_ => return false,
423+
}
424+
}
425+
}
426+
self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect();
427+
true
428+
}
429+
430+
fn new_stmts(
431+
&self,
432+
_tcx: TyCtxt<'tcx>,
433+
targets: &SwitchTargets,
434+
_param_env: ParamEnv<'tcx>,
435+
bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
436+
discr_local: Local,
437+
discr_ty: Ty<'tcx>,
438+
) -> Vec<Statement<'tcx>> {
439+
let (_, first) = targets.iter().next().unwrap();
440+
let first = &bbs[first];
441+
442+
let new_stmts =
443+
iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) {
444+
(TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(),
445+
(
446+
TransfromType::Discr,
447+
StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
448+
) => {
449+
let operand = Operand::Copy(Place::from(discr_local));
450+
let r_val = if f_c.const_.ty() == discr_ty {
451+
Rvalue::Use(operand)
452+
} else {
453+
Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
454+
};
455+
Statement {
456+
source_info: s.source_info,
457+
kind: StatementKind::Assign(Box::new((*lhs, r_val))),
458+
}
459+
}
460+
_ => unreachable!(),
461+
});
462+
new_stmts.collect()
463+
}
464+
}

tests/codegen/match-optimized.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ pub fn exhaustive_match(e: E) -> u8 {
2626
// CHECK-NEXT: store i8 1, ptr %_0, align 1
2727
// CHECK-NEXT: br label %[[EXIT]]
2828
// CHECK: [[C]]:
29-
// CHECK-NEXT: store i8 2, ptr %_0, align 1
29+
// CHECK-NEXT: store i8 3, ptr %_0, align 1
3030
// CHECK-NEXT: br label %[[EXIT]]
3131
match e {
3232
E::A => 0,
3333
E::B => 1,
34-
E::C => 2,
34+
E::C => 3,
3535
}
3636
}
3737

tests/mir-opt/matches_reduce_branches.match_i128_u128.MatchBranchSimplification.diff

+33-28
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,42 @@
55
debug i => _1;
66
let mut _0: u128;
77
let mut _2: i128;
8+
+ let mut _3: i128;
89

910
bb0: {
1011
_2 = discriminant(_1);
11-
switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb1, otherwise: bb2];
12-
}
13-
14-
bb1: {
15-
_0 = const _;
16-
goto -> bb6;
17-
}
18-
19-
bb2: {
20-
unreachable;
21-
}
22-
23-
bb3: {
24-
_0 = const 1_u128;
25-
goto -> bb6;
26-
}
27-
28-
bb4: {
29-
_0 = const 2_u128;
30-
goto -> bb6;
31-
}
32-
33-
bb5: {
34-
_0 = const 3_u128;
35-
goto -> bb6;
36-
}
37-
38-
bb6: {
12+
- switchInt(move _2) -> [1: bb3, 2: bb4, 3: bb5, 340282366920938463463374607431768211455: bb1, otherwise: bb2];
13+
- }
14+
-
15+
- bb1: {
16+
- _0 = const _;
17+
- goto -> bb6;
18+
- }
19+
-
20+
- bb2: {
21+
- unreachable;
22+
- }
23+
-
24+
- bb3: {
25+
- _0 = const 1_u128;
26+
- goto -> bb6;
27+
- }
28+
-
29+
- bb4: {
30+
- _0 = const 2_u128;
31+
- goto -> bb6;
32+
- }
33+
-
34+
- bb5: {
35+
- _0 = const 3_u128;
36+
- goto -> bb6;
37+
- }
38+
-
39+
- bb6: {
40+
+ StorageLive(_3);
41+
+ _3 = move _2;
42+
+ _0 = _3 as u128 (IntToInt);
43+
+ StorageDead(_3);
3944
return;
4045
}
4146
}

0 commit comments

Comments
 (0)