Skip to content

Commit 8b56099

Browse files
committed
Apply EarlyOtherwiseBranch to scalar values
1 parent f65946e commit 8b56099

4 files changed

+295
-74
lines changed

compiler/rustc_mir_transform/src/early_otherwise_branch.rs

+124-74
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,29 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
135135

136136
let mut patch = MirPatch::new(body);
137137

138-
// create temp to store second discriminant in, `_s` in example above
139-
let second_discriminant_temp =
140-
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
138+
let (second_discriminant_temp, second_operand) = if opt_data.hoist_discriminant {
139+
// create temp to store second discriminant in, `_s` in example above
140+
let second_discriminant_temp =
141+
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
141142

142-
patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
143+
patch.add_statement(
144+
parent_end,
145+
StatementKind::StorageLive(second_discriminant_temp),
146+
);
143147

144-
// create assignment of discriminant
145-
patch.add_assign(
146-
parent_end,
147-
Place::from(second_discriminant_temp),
148-
Rvalue::Discriminant(opt_data.child_place),
149-
);
148+
// create assignment of discriminant
149+
patch.add_assign(
150+
parent_end,
151+
Place::from(second_discriminant_temp),
152+
Rvalue::Discriminant(opt_data.child_place),
153+
);
154+
(
155+
Some(second_discriminant_temp),
156+
Operand::Move(Place::from(second_discriminant_temp)),
157+
)
158+
} else {
159+
(None, Operand::Copy(opt_data.child_place))
160+
};
150161

151162
// create temp to store inequality comparison between the two discriminants, `_t` in
152163
// example above
@@ -156,10 +167,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
156167
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
157168

158169
// create inequality comparison between the two discriminants
159-
let comp_rvalue = Rvalue::BinaryOp(
160-
nequal,
161-
Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
162-
);
170+
let comp_rvalue =
171+
Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
163172
patch.add_statement(
164173
parent_end,
165174
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
@@ -194,8 +203,13 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
194203
TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
195204
);
196205

197-
// generate StorageDead for the second_discriminant_temp not in use anymore
198-
patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
206+
if let Some(second_discriminant_temp) = second_discriminant_temp {
207+
// generate StorageDead for the second_discriminant_temp not in use anymore
208+
patch.add_statement(
209+
parent_end,
210+
StatementKind::StorageDead(second_discriminant_temp),
211+
);
212+
}
199213

200214
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
201215
// the switch
@@ -271,6 +285,7 @@ struct OptimizationData<'tcx> {
271285
child_place: Place<'tcx>,
272286
child_ty: Ty<'tcx>,
273287
child_source: SourceInfo,
288+
hoist_discriminant: bool,
274289
}
275290

276291
fn evaluate_candidate<'tcx>(
@@ -284,38 +299,6 @@ fn evaluate_candidate<'tcx>(
284299
return None;
285300
};
286301
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
287-
if !bbs[targets.otherwise()].is_empty_unreachable() {
288-
// Someone could write code like this:
289-
// ```rust
290-
// let Q = val;
291-
// if discriminant(P) == otherwise {
292-
// let ptr = &mut Q as *mut _ as *mut u8;
293-
// // Any invalid value for the type. It is possible to be opaque, such as in other functions.
294-
// unsafe { *ptr = 10; }
295-
// }
296-
//
297-
// match P {
298-
// A => match Q {
299-
// A => {
300-
// // code
301-
// }
302-
// _ => {
303-
// // don't use Q
304-
// }
305-
// }
306-
// _ => {
307-
// // don't use Q
308-
// }
309-
// };
310-
// ```
311-
//
312-
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
313-
// invalid value, which is UB.
314-
// In order to fix this, we would either need to show that the discriminant computation of
315-
// `place` is computed in all branches.
316-
// So we need the `otherwise` branch has no statements and an unreachable terminator.
317-
return None;
318-
}
319302
let (_, child) = targets.iter().next()?;
320303
let child_terminator = &bbs[child].terminator();
321304
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
@@ -327,31 +310,89 @@ fn evaluate_candidate<'tcx>(
327310
if child_ty != parent_ty {
328311
return None;
329312
}
330-
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind) else {
313+
if bbs[child].statements.len() > 1 {
331314
return None;
315+
}
316+
let hoist_discriminant = bbs[child].statements.len() == 1;
317+
let child_place = if hoist_discriminant {
318+
if !bbs[targets.otherwise()].is_empty_unreachable() {
319+
// Someone could write code like this:
320+
// ```rust
321+
// let Q = val;
322+
// if discriminant(P) == otherwise {
323+
// let ptr = &mut Q as *mut _ as *mut u8;
324+
// // Any invalid value for the type. It is possible to be opaque, such as in other functions.
325+
// unsafe { *ptr = 10; }
326+
// }
327+
//
328+
// match P {
329+
// A => match Q {
330+
// A => {
331+
// // code
332+
// }
333+
// _ => {
334+
// // don't use Q
335+
// }
336+
// }
337+
// _ => {
338+
// // don't use Q
339+
// }
340+
// };
341+
// ```
342+
//
343+
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
344+
// invalid value, which is UB.
345+
// In order to fix this, we would either need to show that the discriminant computation of
346+
// `place` is computed in all branches.
347+
// So we need the `otherwise` branch has no statements and an unreachable terminator.
348+
return None;
349+
}
350+
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind)
351+
else {
352+
return None;
353+
};
354+
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
355+
return None;
356+
};
357+
// Verify that the optimization is legal in general
358+
// We can hoist evaluating the child discriminant out of the branch
359+
if !may_hoist(tcx, body, *child_place) {
360+
return None;
361+
}
362+
*child_place
363+
} else {
364+
let TerminatorKind::SwitchInt { discr, .. } = &bbs[child].terminator().kind else {
365+
return None;
366+
};
367+
let Operand::Copy(child_place) = discr else {
368+
return None;
369+
};
370+
*child_place
332371
};
333-
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
334-
return None;
372+
let destination = if hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable() {
373+
child_targets.otherwise()
374+
} else {
375+
targets.otherwise()
335376
};
336-
let destination = child_targets.otherwise();
337-
338-
// Verify that the optimization is legal in general
339-
// We can hoist evaluating the child discriminant out of the branch
340-
if !may_hoist(tcx, body, *child_place) {
341-
return None;
342-
}
343377

344378
// Verify that the optimization is legal for each branch
345379
for (value, child) in targets.iter() {
346-
if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
380+
if !verify_candidate_branch(
381+
&bbs[child],
382+
value,
383+
child_place,
384+
destination,
385+
hoist_discriminant,
386+
) {
347387
return None;
348388
}
349389
}
350390
Some(OptimizationData {
351391
destination,
352-
child_place: *child_place,
392+
child_place,
353393
child_ty,
354394
child_source: child_terminator.source_info,
395+
hoist_discriminant,
355396
})
356397
}
357398

@@ -360,29 +401,38 @@ fn verify_candidate_branch<'tcx>(
360401
value: u128,
361402
place: Place<'tcx>,
362403
destination: BasicBlock,
404+
hoist_discriminant: bool,
363405
) -> bool {
364406
// In order for the optimization to be correct, the branch must...
365407
// ...have exactly one statement
366-
if branch.statements.len() != 1 {
367-
return false;
368-
}
369-
// ...assign the discriminant of `place` in that statement
370-
let StatementKind::Assign(boxed) = &branch.statements[0].kind else { return false };
371-
let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else { return false };
372-
if *from_place != place {
373-
return false;
374-
}
375-
// ...make that assignment to a local
376-
if discr_place.projection.len() != 0 {
408+
if (hoist_discriminant && branch.statements.len() != 1)
409+
|| (!hoist_discriminant && !branch.statements.is_empty())
410+
{
377411
return false;
378412
}
379413
// ...terminate on a `SwitchInt` that invalidates that local
380414
let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } = &branch.terminator().kind
381415
else {
382416
return false;
383417
};
384-
if *switch_op != Operand::Move(*discr_place) {
385-
return false;
418+
if hoist_discriminant {
419+
// ...assign the discriminant of `place` in that statement
420+
let StatementKind::Assign(boxed) = &branch.statements[0].kind else { return false };
421+
let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else { return false };
422+
if *from_place != place {
423+
return false;
424+
}
425+
// ...make that assignment to a local
426+
if discr_place.projection.len() != 0 {
427+
return false;
428+
}
429+
if *switch_op != Operand::Move(*discr_place) {
430+
return false;
431+
}
432+
} else {
433+
if *switch_op != Operand::Copy(place) {
434+
return false;
435+
}
386436
}
387437
// ...fall through to `destination` if the switch misses
388438
if destination != targets.otherwise() {
@@ -397,7 +447,7 @@ fn verify_candidate_branch<'tcx>(
397447
return false;
398448
}
399449
// ...and have no more branches
400-
if let Some(_) = iter.next() {
450+
if iter.next().is_some() {
401451
return false;
402452
}
403453
true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
- // MIR for `opt4` before EarlyOtherwiseBranch
2+
+ // MIR for `opt4` after EarlyOtherwiseBranch
3+
4+
fn opt4(_1: u32, _2: u32) -> u32 {
5+
debug x => _1;
6+
debug y => _2;
7+
let mut _0: u32;
8+
let mut _3: (u32, u32);
9+
let mut _4: u32;
10+
let mut _5: u32;
11+
+ let mut _6: bool;
12+
13+
bb0: {
14+
StorageLive(_3);
15+
StorageLive(_4);
16+
_4 = _1;
17+
StorageLive(_5);
18+
_5 = _2;
19+
_3 = (move _4, move _5);
20+
StorageDead(_5);
21+
StorageDead(_4);
22+
- switchInt((_3.0: u32)) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
23+
+ StorageLive(_6);
24+
+ _6 = Ne((_3.0: u32), (_3.1: u32));
25+
+ switchInt(move _6) -> [0: bb6, otherwise: bb1];
26+
}
27+
28+
bb1: {
29+
+ StorageDead(_6);
30+
_0 = const 0_u32;
31+
- goto -> bb8;
32+
+ goto -> bb5;
33+
}
34+
35+
bb2: {
36+
- switchInt((_3.1: u32)) -> [1: bb5, otherwise: bb1];
37+
+ _0 = const 4_u32;
38+
+ goto -> bb5;
39+
}
40+
41+
bb3: {
42+
- switchInt((_3.1: u32)) -> [2: bb6, otherwise: bb1];
43+
+ _0 = const 5_u32;
44+
+ goto -> bb5;
45+
}
46+
47+
bb4: {
48+
- switchInt((_3.1: u32)) -> [3: bb7, otherwise: bb1];
49+
+ _0 = const 6_u32;
50+
+ goto -> bb5;
51+
}
52+
53+
bb5: {
54+
- _0 = const 4_u32;
55+
- goto -> bb8;
56+
+ StorageDead(_3);
57+
+ return;
58+
}
59+
60+
bb6: {
61+
- _0 = const 5_u32;
62+
- goto -> bb8;
63+
- }
64+
-
65+
- bb7: {
66+
- _0 = const 6_u32;
67+
- goto -> bb8;
68+
- }
69+
-
70+
- bb8: {
71+
- StorageDead(_3);
72+
- return;
73+
+ StorageDead(_6);
74+
+ switchInt((_3.0: u32)) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
75+
}
76+
}
77+

0 commit comments

Comments
 (0)