Skip to content

Commit 6658b16

Browse files
committed
Apply EarlyOtherwiseBranch for all targets with the same value
1 parent 8b56099 commit 6658b16

4 files changed

+287
-37
lines changed

compiler/rustc_mir_transform/src/early_otherwise_branch.rs

+87-37
Original file line numberDiff line numberDiff line change
@@ -159,27 +159,12 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
159159
(None, Operand::Copy(opt_data.child_place))
160160
};
161161

162-
// create temp to store inequality comparison between the two discriminants, `_t` in
163-
// example above
164-
let nequal = BinOp::Ne;
165-
let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
166-
let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
167-
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
168-
169-
// create inequality comparison between the two discriminants
170-
let comp_rvalue =
171-
Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
172-
patch.add_statement(
173-
parent_end,
174-
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
175-
);
176-
177162
let eq_new_targets = parent_targets.iter().map(|(value, child)| {
178163
let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
179164
else {
180165
unreachable!()
181166
};
182-
(value, targets.target_for_value(value))
167+
(value, targets.all_targets()[0])
183168
});
184169
let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
185170

@@ -188,36 +173,77 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
188173
source_info: bbs[parent].terminator().source_info,
189174
kind: TerminatorKind::SwitchInt {
190175
// switch on the first discriminant, so we can mark the second one as dead
191-
discr: parent_op,
176+
discr: parent_op.clone(),
192177
targets: eq_targets,
193178
},
194179
}));
195180

196181
let eq_bb = patch.new_block(eq_switch);
197182

198-
// Jump to it on the basis of the inequality comparison
199-
let true_case = opt_data.destination;
200-
let false_case = eq_bb;
201-
patch.patch_terminator(
202-
parent,
203-
TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
204-
);
183+
if let Some(same_target_value) = opt_data.same_target_value {
184+
let t = TerminatorKind::SwitchInt {
185+
discr: second_operand,
186+
targets: SwitchTargets::static_if(
187+
same_target_value,
188+
eq_bb,
189+
opt_data.destination,
190+
),
191+
};
192+
patch.patch_terminator(parent, t);
193+
194+
if let Some(second_discriminant_temp) = second_discriminant_temp {
195+
// Generate a StorageDead for second_discriminant_temp in each of the targets, since we moved it into
196+
// the switch
197+
for bb in [eq_bb, opt_data.destination].iter() {
198+
patch.add_statement(
199+
Location { block: *bb, statement_index: 0 },
200+
StatementKind::StorageDead(second_discriminant_temp),
201+
);
202+
}
203+
}
204+
} else {
205+
// create temp to store inequality comparison between the two discriminants, `_t` in
206+
// example above
207+
let nequal = BinOp::Ne;
208+
let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
209+
let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
210+
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
205211

206-
if let Some(second_discriminant_temp) = second_discriminant_temp {
207-
// generate StorageDead for the second_discriminant_temp not in use anymore
212+
// create inequality comparison between the two discriminants
213+
let comp_rvalue = Rvalue::BinaryOp(nequal, Box::new((parent_op, second_operand)));
208214
patch.add_statement(
209215
parent_end,
210-
StatementKind::StorageDead(second_discriminant_temp),
216+
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
211217
);
212-
}
213218

214-
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
215-
// the switch
216-
for bb in [false_case, true_case].iter() {
217-
patch.add_statement(
218-
Location { block: *bb, statement_index: 0 },
219-
StatementKind::StorageDead(comp_temp),
219+
// Jump to it on the basis of the inequality comparison
220+
let true_case = opt_data.destination;
221+
let false_case = eq_bb;
222+
patch.patch_terminator(
223+
parent,
224+
TerminatorKind::if_(
225+
Operand::Move(Place::from(comp_temp)),
226+
true_case,
227+
false_case,
228+
),
220229
);
230+
231+
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
232+
// the switch
233+
for bb in [false_case, true_case].iter() {
234+
patch.add_statement(
235+
Location { block: *bb, statement_index: 0 },
236+
StatementKind::StorageDead(comp_temp),
237+
);
238+
}
239+
240+
if let Some(second_discriminant_temp) = second_discriminant_temp {
241+
// generate StorageDead for the second_discriminant_temp not in use anymore
242+
patch.add_statement(
243+
parent_end,
244+
StatementKind::StorageDead(second_discriminant_temp),
245+
);
246+
}
221247
}
222248

223249
patch.apply(body);
@@ -286,6 +312,7 @@ struct OptimizationData<'tcx> {
286312
child_ty: Ty<'tcx>,
287313
child_source: SourceInfo,
288314
hoist_discriminant: bool,
315+
same_target_value: Option<u128>,
289316
}
290317

291318
fn evaluate_candidate<'tcx>(
@@ -375,16 +402,38 @@ fn evaluate_candidate<'tcx>(
375402
targets.otherwise()
376403
};
377404

405+
let TerminatorKind::SwitchInt { targets: child_targets, .. } = &bbs[child].terminator().kind
406+
else {
407+
return None;
408+
};
378409
// Verify that the optimization is legal for each branch
379-
for (value, child) in targets.iter() {
410+
let Some((may_same_target_value, _)) = child_targets.iter().next() else {
411+
return None;
412+
};
413+
let mut same_target_value = Some(may_same_target_value);
414+
for (_, child) in targets.iter() {
380415
if !verify_candidate_branch(
381416
&bbs[child],
382-
value,
417+
may_same_target_value,
383418
child_place,
384419
destination,
385420
hoist_discriminant,
386421
) {
387-
return None;
422+
same_target_value = None;
423+
break;
424+
}
425+
}
426+
if same_target_value.is_none() {
427+
for (value, child) in targets.iter() {
428+
if !verify_candidate_branch(
429+
&bbs[child],
430+
value,
431+
child_place,
432+
destination,
433+
hoist_discriminant,
434+
) {
435+
return None;
436+
}
388437
}
389438
}
390439
Some(OptimizationData {
@@ -393,6 +442,7 @@ fn evaluate_candidate<'tcx>(
393442
child_ty,
394443
child_source: child_terminator.source_info,
395444
hoist_discriminant,
445+
same_target_value,
396446
})
397447
}
398448

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
- // MIR for `opt6` before EarlyOtherwiseBranch
2+
+ // MIR for `opt6` after EarlyOtherwiseBranch
3+
4+
fn opt6(_1: u32, _2: u32) -> u32 {
5+
debug x => _1;
6+
debug y => _2;
7+
let mut _0: u32;
8+
let mut _3: (u32, u32);
9+
let mut _4: u32;
10+
let mut _5: u32;
11+
12+
bb0: {
13+
StorageLive(_3);
14+
StorageLive(_4);
15+
_4 = _1;
16+
StorageLive(_5);
17+
_5 = _2;
18+
_3 = (move _4, move _5);
19+
StorageDead(_5);
20+
StorageDead(_4);
21+
- switchInt((_3.0: u32)) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
22+
+ switchInt((_3.1: u32)) -> [10: bb6, otherwise: bb1];
23+
}
24+
25+
bb1: {
26+
_0 = const 0_u32;
27+
- goto -> bb8;
28+
+ goto -> bb5;
29+
}
30+
31+
bb2: {
32+
- switchInt((_3.1: u32)) -> [10: bb5, otherwise: bb1];
33+
+ _0 = const 4_u32;
34+
+ goto -> bb5;
35+
}
36+
37+
bb3: {
38+
- switchInt((_3.1: u32)) -> [10: bb6, otherwise: bb1];
39+
+ _0 = const 5_u32;
40+
+ goto -> bb5;
41+
}
42+
43+
bb4: {
44+
- switchInt((_3.1: u32)) -> [10: bb7, otherwise: bb1];
45+
+ _0 = const 6_u32;
46+
+ goto -> bb5;
47+
}
48+
49+
bb5: {
50+
- _0 = const 4_u32;
51+
- goto -> bb8;
52+
+ StorageDead(_3);
53+
+ return;
54+
}
55+
56+
bb6: {
57+
- _0 = const 5_u32;
58+
- goto -> bb8;
59+
- }
60+
-
61+
- bb7: {
62+
- _0 = const 6_u32;
63+
- goto -> bb8;
64+
- }
65+
-
66+
- bb8: {
67+
- StorageDead(_3);
68+
- return;
69+
+ switchInt((_3.0: u32)) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
70+
}
71+
}
72+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
- // MIR for `opt7` before EarlyOtherwiseBranch
2+
+ // MIR for `opt7` after EarlyOtherwiseBranch
3+
4+
fn opt7(_1: Option<u32>, _2: Option<u32>) -> u32 {
5+
debug x => _1;
6+
debug y => _2;
7+
let mut _0: u32;
8+
let mut _3: (std::option::Option<u32>, std::option::Option<u32>);
9+
let mut _4: std::option::Option<u32>;
10+
let mut _5: std::option::Option<u32>;
11+
let mut _6: isize;
12+
let mut _7: isize;
13+
let mut _8: isize;
14+
let _9: u32;
15+
let _10: u32;
16+
let _11: u32;
17+
+ let mut _12: isize;
18+
scope 1 {
19+
debug a => _9;
20+
debug b => _10;
21+
}
22+
scope 2 {
23+
debug b => _11;
24+
}
25+
26+
bb0: {
27+
StorageLive(_3);
28+
StorageLive(_4);
29+
_4 = _1;
30+
StorageLive(_5);
31+
_5 = _2;
32+
_3 = (move _4, move _5);
33+
StorageDead(_5);
34+
StorageDead(_4);
35+
_8 = discriminant((_3.0: std::option::Option<u32>));
36+
- switchInt(move _8) -> [0: bb2, 1: bb3, otherwise: bb7];
37+
+ StorageLive(_12);
38+
+ _12 = discriminant((_3.1: std::option::Option<u32>));
39+
+ switchInt(move _12) -> [1: bb5, otherwise: bb1];
40+
}
41+
42+
bb1: {
43+
+ StorageDead(_12);
44+
_0 = const 1_u32;
45+
- goto -> bb6;
46+
+ goto -> bb4;
47+
}
48+
49+
bb2: {
50+
- _6 = discriminant((_3.1: std::option::Option<u32>));
51+
- switchInt(move _6) -> [1: bb5, otherwise: bb1];
52+
- }
53+
-
54+
- bb3: {
55+
- _7 = discriminant((_3.1: std::option::Option<u32>));
56+
- switchInt(move _7) -> [1: bb4, otherwise: bb1];
57+
- }
58+
-
59+
- bb4: {
60+
StorageLive(_10);
61+
_10 = (((_3.1: std::option::Option<u32>) as Some).0: u32);
62+
StorageLive(_9);
63+
_9 = (((_3.0: std::option::Option<u32>) as Some).0: u32);
64+
_0 = const 0_u32;
65+
StorageDead(_9);
66+
StorageDead(_10);
67+
- goto -> bb6;
68+
+ goto -> bb4;
69+
}
70+
71+
- bb5: {
72+
+ bb3: {
73+
StorageLive(_11);
74+
_11 = (((_3.1: std::option::Option<u32>) as Some).0: u32);
75+
_0 = const 2_u32;
76+
StorageDead(_11);
77+
- goto -> bb6;
78+
+ goto -> bb4;
79+
}
80+
81+
- bb6: {
82+
+ bb4: {
83+
StorageDead(_3);
84+
return;
85+
}
86+
87+
- bb7: {
88+
- unreachable;
89+
+ bb5: {
90+
+ StorageDead(_12);
91+
+ switchInt(_8) -> [0: bb3, 1: bb2, otherwise: bb1];
92+
}
93+
}
94+

tests/mir-opt/early_otherwise_branch.rs

+34
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,44 @@ fn opt5(x: u32, y: u32) -> u32 {
8383
}
8484
}
8585

86+
// EMIT_MIR early_otherwise_branch.opt6.EarlyOtherwiseBranch.diff
87+
fn opt6(x: u32, y: u32) -> u32 {
88+
// CHECK-LABEL: fn opt6(
89+
// CHECK: bb0: {
90+
// CHECK: switchInt((_{{.*}}: u32)) -> [10: [[SWITCH_BB:bb.*]], otherwise: [[OTHERWISE:bb.*]]];
91+
// CHECK-NEXT: }
92+
// CHECK: [[SWITCH_BB]]:
93+
// CHECK: switchInt((_{{.*}}: u32)) -> [1: bb{{.*}}, 2: bb{{.*}}, 3: bb{{.*}}, otherwise: [[OTHERWISE]]];
94+
// CHECK-NEXT: }
95+
match (x, y) {
96+
(1, 10) => 4,
97+
(2, 10) => 5,
98+
(3, 10) => 6,
99+
_ => 0,
100+
}
101+
}
102+
103+
// EMIT_MIR early_otherwise_branch.opt7.EarlyOtherwiseBranch.diff
104+
fn opt7(x: Option<u32>, y: Option<u32>) -> u32 {
105+
// CHECK-LABEL: fn opt7(
106+
// CHECK: bb0: {
107+
// CHECK: [[LOCAL1:_.*]] = discriminant({{.*}});
108+
// CHECK: [[LOCAL2:_.*]] = discriminant({{.*}});
109+
// CHECK: switchInt(move [[LOCAL2]]) -> [
110+
// CHECK-NEXT: }
111+
match (x, y) {
112+
(Some(a), Some(b)) => 0,
113+
(None, Some(b)) => 2,
114+
_ => 1,
115+
}
116+
}
117+
86118
fn main() {
87119
opt1(None, Some(0));
88120
opt2(None, Some(0));
89121
opt3(None, Some(false));
90122
opt4(0, 0);
91123
opt5(0, 0);
124+
opt6(0, 0);
125+
opt7(None, Some(0));
92126
}

0 commit comments

Comments
 (0)