Skip to content

Commit cb008a0

Browse files
committed
Resolve unsound hoisting of discriminant in EarlyOtherwiseBranch
1 parent 41c79ba commit cb008a0

12 files changed

+185
-158
lines changed

compiler/rustc_mir_transform/src/early_otherwise_branch.rs

+47-57
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use super::simplify::simplify_cfg;
1111
/// let y: Option<()>;
1212
/// match (x,y) {
1313
/// (Some(_), Some(_)) => {0},
14+
/// (None, None) => {2},
1415
/// _ => {1}
1516
/// }
1617
/// ```
@@ -23,10 +24,10 @@ use super::simplify::simplify_cfg;
2324
/// if discriminant_x == discriminant_y {
2425
/// match x {
2526
/// Some(_) => 0,
26-
/// _ => 1, // <----
27-
/// } // | Actually the same bb
28-
/// } else { // |
29-
/// 1 // <--------------
27+
/// None => 2,
28+
/// }
29+
/// } else {
30+
/// 1
3031
/// }
3132
/// ```
3233
///
@@ -47,18 +48,18 @@ use super::simplify::simplify_cfg;
4748
/// | | |
4849
/// ================= | | |
4950
/// | BBU | <-| | | ============================
50-
/// |---------------| | \-------> | BBD |
51-
/// |---------------| | | |--------------------------|
52-
/// | unreachable | | | | _dl = discriminant(P) |
53-
/// ================= | | |--------------------------|
54-
/// | | | switchInt(_dl) |
55-
/// ================= | | | d | ---> BBD.2
51+
/// |---------------| \-------> | BBD |
52+
/// |---------------| | |--------------------------|
53+
/// | unreachable | | | _dl = discriminant(P) |
54+
/// ================= | |--------------------------|
55+
/// | | switchInt(_dl) |
56+
/// ================= | | d | ---> BBD.2
5657
/// | BB9 | <--------------- | otherwise |
5758
/// |---------------| ============================
5859
/// | ... |
5960
/// =================
6061
/// ```
61-
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
62+
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
6263
/// code:
6364
/// - `BB1` is `parent` and `BBC, BBD` are children
6465
/// - `P` is `child_place`
@@ -78,7 +79,7 @@ use super::simplify::simplify_cfg;
7879
/// |---------------------| | | switchInt(Q) |
7980
/// | switchInt(_t) | | | c | ---> BBC.2
8081
/// | false | --------/ | d | ---> BBD.2
81-
/// | otherwise | ---------------- | otherwise |
82+
/// | otherwise | /--------- | otherwise |
8283
/// ======================= | ============================
8384
/// |
8485
/// ================= |
@@ -219,37 +220,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
219220

220221
/// Returns true if computing the discriminant of `place` may be hoisted out of the branch
221222
fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
222-
// FIXME(JakobDegen): This is unsound. Someone could write code like this:
223-
// ```rust
224-
// let Q = val;
225-
// if discriminant(P) == otherwise {
226-
// let ptr = &mut Q as *mut _ as *mut u8;
227-
// unsafe { *ptr = 10; } // Any invalid value for the type
228-
// }
229-
//
230-
// match P {
231-
// A => match Q {
232-
// A => {
233-
// // code
234-
// }
235-
// _ => {
236-
// // don't use Q
237-
// }
238-
// }
239-
// _ => {
240-
// // don't use Q
241-
// }
242-
// };
243-
// ```
244-
//
245-
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
246-
// invalid value, which is UB.
247-
//
248-
// In order to fix this, we would either need to show that the discriminant computation of
249-
// `place` is computed in all branches, including the `otherwise` branch, or we would need
250-
// another analysis pass to determine that the place is fully initialized. It might even be best
251-
// to have the hoisting be performed in a different pass and just do the CFG changing in this
252-
// pass.
253223
for (place, proj) in place.iter_projections() {
254224
match proj {
255225
// Dereferencing in the computation of `place` might cause issues from one of two
@@ -315,18 +285,38 @@ fn evaluate_candidate<'tcx>(
315285
return None;
316286
};
317287
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
318-
let parent_dest = {
319-
let poss = targets.otherwise();
320-
// If the fallthrough on the parent is trivially unreachable, we can let the
321-
// children choose the destination
322-
if bbs[poss].statements.len() == 0
323-
&& bbs[poss].terminator().kind == TerminatorKind::Unreachable
324-
{
325-
None
326-
} else {
327-
Some(poss)
328-
}
329-
};
288+
if !bbs[targets.otherwise()].is_empty_unreachable() {
289+
// Someone could write code like this:
290+
// ```rust
291+
// let Q = val;
292+
// if discriminant(P) == otherwise {
293+
// let ptr = &mut Q as *mut _ as *mut u8;
294+
// // Any invalid value for the type. It is possible to be opaque, such as in other functions.
295+
// unsafe { *ptr = 10; }
296+
// }
297+
//
298+
// match P {
299+
// A => match Q {
300+
// A => {
301+
// // code
302+
// }
303+
// _ => {
304+
// // don't use Q
305+
// }
306+
// }
307+
// _ => {
308+
// // don't use Q
309+
// }
310+
// };
311+
// ```
312+
//
313+
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
314+
// invalid value, which is UB.
315+
// In order to fix this, we would either need to show that the discriminant computation of
316+
// `place` is computed in all branches.
317+
// So we need the `otherwise` branch has no statements and an unreachable terminator.
318+
return None;
319+
}
330320
let (_, child) = targets.iter().next()?;
331321
let child_terminator = &bbs[child].terminator();
332322
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
@@ -344,7 +334,7 @@ fn evaluate_candidate<'tcx>(
344334
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
345335
return None;
346336
};
347-
let destination = parent_dest.unwrap_or(child_targets.otherwise());
337+
let destination = child_targets.otherwise();
348338

349339
// Verify that the optimization is legal in general
350340
// We can hoist evaluating the child discriminant out of the branch
@@ -411,5 +401,5 @@ fn verify_candidate_branch<'tcx>(
411401
if let Some(_) = iter.next() {
412402
return false;
413403
}
414-
return true;
404+
true
415405
}

tests/mir-opt/early_otherwise_branch.opt1.EarlyOtherwiseBranch.diff

+9-26
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
let mut _7: isize;
1313
let _8: u32;
1414
let _9: u32;
15-
+ let mut _10: isize;
16-
+ let mut _11: bool;
1715
scope 1 {
1816
debug a => _8;
1917
debug b => _9;
@@ -29,48 +27,33 @@
2927
StorageDead(_5);
3028
StorageDead(_4);
3129
_7 = discriminant((_3.0: std::option::Option<u32>));
32-
- switchInt(move _7) -> [1: bb2, otherwise: bb1];
33-
+ StorageLive(_10);
34-
+ _10 = discriminant((_3.1: std::option::Option<u32>));
35-
+ StorageLive(_11);
36-
+ _11 = Ne(_7, move _10);
37-
+ StorageDead(_10);
38-
+ switchInt(move _11) -> [0: bb4, otherwise: bb1];
30+
switchInt(move _7) -> [1: bb2, otherwise: bb1];
3931
}
4032

4133
bb1: {
42-
+ StorageDead(_11);
4334
_0 = const 1_u32;
44-
- goto -> bb4;
45-
+ goto -> bb3;
35+
goto -> bb4;
4636
}
4737

4838
bb2: {
49-
- _6 = discriminant((_3.1: std::option::Option<u32>));
50-
- switchInt(move _6) -> [1: bb3, otherwise: bb1];
51-
- }
52-
-
53-
- bb3: {
39+
_6 = discriminant((_3.1: std::option::Option<u32>));
40+
switchInt(move _6) -> [1: bb3, otherwise: bb1];
41+
}
42+
43+
bb3: {
5444
StorageLive(_9);
5545
_9 = (((_3.1: std::option::Option<u32>) as Some).0: u32);
5646
StorageLive(_8);
5747
_8 = (((_3.0: std::option::Option<u32>) as Some).0: u32);
5848
_0 = const 0_u32;
5949
StorageDead(_8);
6050
StorageDead(_9);
61-
- goto -> bb4;
62-
+ goto -> bb3;
51+
goto -> bb4;
6352
}
6453

65-
- bb4: {
66-
+ bb3: {
54+
bb4: {
6755
StorageDead(_3);
6856
return;
69-
+ }
70-
+
71-
+ bb4: {
72-
+ StorageDead(_11);
73-
+ switchInt(_7) -> [1: bb2, otherwise: bb1];
7457
}
7558
}
7659

tests/mir-opt/early_otherwise_branch.opt2.EarlyOtherwiseBranch.diff

+6-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
StorageDead(_5);
3131
StorageDead(_4);
3232
_8 = discriminant((_3.0: std::option::Option<u32>));
33-
- switchInt(move _8) -> [0: bb2, 1: bb3, otherwise: bb1];
33+
- switchInt(move _8) -> [0: bb2, 1: bb3, otherwise: bb7];
3434
+ StorageLive(_11);
3535
+ _11 = discriminant((_3.1: std::option::Option<u32>));
3636
+ StorageLive(_12);
@@ -70,7 +70,7 @@
7070

7171
- bb5: {
7272
+ bb3: {
73-
_0 = const 0_u32;
73+
_0 = const 2_u32;
7474
- goto -> bb6;
7575
+ goto -> bb4;
7676
}
@@ -79,8 +79,10 @@
7979
+ bb4: {
8080
StorageDead(_3);
8181
return;
82-
+ }
83-
+
82+
}
83+
84+
- bb7: {
85+
- unreachable;
8486
+ bb5: {
8587
+ StorageDead(_12);
8688
+ switchInt(_8) -> [0: bb3, 1: bb2, otherwise: bb1];

tests/mir-opt/early_otherwise_branch.opt3.EarlyOtherwiseBranch.diff

+44-29
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
let mut _5: std::option::Option<bool>;
1111
let mut _6: isize;
1212
let mut _7: isize;
13-
let _8: u32;
14-
let _9: bool;
15-
+ let mut _10: isize;
16-
+ let mut _11: bool;
13+
let mut _8: isize;
14+
let _9: u32;
15+
let _10: bool;
16+
+ let mut _11: isize;
17+
+ let mut _12: bool;
1718
scope 1 {
18-
debug a => _8;
19-
debug b => _9;
19+
debug a => _9;
20+
debug b => _10;
2021
}
2122

2223
bb0: {
@@ -28,49 +29,63 @@
2829
_3 = (move _4, move _5);
2930
StorageDead(_5);
3031
StorageDead(_4);
31-
_7 = discriminant((_3.0: std::option::Option<u32>));
32-
- switchInt(move _7) -> [1: bb2, otherwise: bb1];
33-
+ StorageLive(_10);
34-
+ _10 = discriminant((_3.1: std::option::Option<bool>));
32+
_8 = discriminant((_3.0: std::option::Option<u32>));
33+
- switchInt(move _8) -> [0: bb2, 1: bb3, otherwise: bb7];
3534
+ StorageLive(_11);
36-
+ _11 = Ne(_7, move _10);
37-
+ StorageDead(_10);
38-
+ switchInt(move _11) -> [0: bb4, otherwise: bb1];
35+
+ _11 = discriminant((_3.1: std::option::Option<bool>));
36+
+ StorageLive(_12);
37+
+ _12 = Ne(_8, move _11);
38+
+ StorageDead(_11);
39+
+ switchInt(move _12) -> [0: bb5, otherwise: bb1];
3940
}
4041

4142
bb1: {
42-
+ StorageDead(_11);
43+
+ StorageDead(_12);
4344
_0 = const 1_u32;
44-
- goto -> bb4;
45-
+ goto -> bb3;
45+
- goto -> bb6;
46+
+ goto -> bb4;
4647
}
4748

4849
bb2: {
4950
- _6 = discriminant((_3.1: std::option::Option<bool>));
50-
- switchInt(move _6) -> [1: bb3, otherwise: bb1];
51+
- switchInt(move _6) -> [0: bb5, otherwise: bb1];
5152
- }
5253
-
5354
- bb3: {
55+
- _7 = discriminant((_3.1: std::option::Option<bool>));
56+
- switchInt(move _7) -> [1: bb4, otherwise: bb1];
57+
- }
58+
-
59+
- bb4: {
60+
StorageLive(_10);
61+
_10 = (((_3.1: std::option::Option<bool>) as Some).0: bool);
5462
StorageLive(_9);
55-
_9 = (((_3.1: std::option::Option<bool>) as Some).0: bool);
56-
StorageLive(_8);
57-
_8 = (((_3.0: std::option::Option<u32>) as Some).0: u32);
63+
_9 = (((_3.0: std::option::Option<u32>) as Some).0: u32);
5864
_0 = const 0_u32;
59-
StorageDead(_8);
6065
StorageDead(_9);
61-
- goto -> bb4;
62-
+ goto -> bb3;
66+
StorageDead(_10);
67+
- goto -> bb6;
68+
+ goto -> bb4;
6369
}
6470

65-
- bb4: {
71+
- bb5: {
6672
+ bb3: {
73+
_0 = const 2_u32;
74+
- goto -> bb6;
75+
+ goto -> bb4;
76+
}
77+
78+
- bb6: {
79+
+ bb4: {
6780
StorageDead(_3);
6881
return;
69-
+ }
70-
+
71-
+ bb4: {
72-
+ StorageDead(_11);
73-
+ switchInt(_7) -> [1: bb2, otherwise: bb1];
82+
}
83+
84+
- bb7: {
85+
- unreachable;
86+
+ bb5: {
87+
+ StorageDead(_12);
88+
+ switchInt(_8) -> [0: bb3, 1: bb2, otherwise: bb1];
7489
}
7590
}
7691

tests/mir-opt/early_otherwise_branch.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// skip-filecheck
22
//@ unit-test: EarlyOtherwiseBranch
3+
//@ compile-flags: -Zmir-enable-passes=+UninhabitedEnumBranching
4+
5+
// We can't optimize it because y may be an invalid value.
36
// EMIT_MIR early_otherwise_branch.opt1.EarlyOtherwiseBranch.diff
47
fn opt1(x: Option<u32>, y: Option<u32>) -> u32 {
58
match (x, y) {
@@ -12,7 +15,7 @@ fn opt1(x: Option<u32>, y: Option<u32>) -> u32 {
1215
fn opt2(x: Option<u32>, y: Option<u32>) -> u32 {
1316
match (x, y) {
1417
(Some(a), Some(b)) => 0,
15-
(None, None) => 0,
18+
(None, None) => 2,
1619
_ => 1,
1720
}
1821
}
@@ -22,6 +25,7 @@ fn opt2(x: Option<u32>, y: Option<u32>) -> u32 {
2225
fn opt3(x: Option<u32>, y: Option<bool>) -> u32 {
2326
match (x, y) {
2427
(Some(a), Some(b)) => 0,
28+
(None, None) => 2,
2529
_ => 1,
2630
}
2731
}

0 commit comments

Comments
 (0)