Skip to content

Commit 7a0b78d

Browse files
committed
Reapply "Auto merge of rust-lang#129047 - DianQK:early_otherwise_branch_scalar, r=cjgillot"
This reverts commit 16a0266.
1 parent 3378a5e commit 7a0b78d

5 files changed

+422
-85
lines changed

compiler/rustc_mir_transform/src/early_otherwise_branch.rs

Lines changed: 175 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -129,18 +129,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
129129

130130
let mut patch = MirPatch::new(body);
131131

132-
// create temp to store second discriminant in, `_s` in example above
133-
let second_discriminant_temp =
134-
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
132+
let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
133+
// create temp to store second discriminant in, `_s` in example above
134+
let second_discriminant_temp =
135+
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
135136

136-
patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
137+
patch.add_statement(
138+
parent_end,
139+
StatementKind::StorageLive(second_discriminant_temp),
140+
);
137141

138-
// create assignment of discriminant
139-
patch.add_assign(
140-
parent_end,
141-
Place::from(second_discriminant_temp),
142-
Rvalue::Discriminant(opt_data.child_place),
143-
);
142+
// create assignment of discriminant
143+
patch.add_assign(
144+
parent_end,
145+
Place::from(second_discriminant_temp),
146+
Rvalue::Discriminant(opt_data.child_place),
147+
);
148+
(
149+
Some(second_discriminant_temp),
150+
Operand::Move(Place::from(second_discriminant_temp)),
151+
)
152+
} else {
153+
(None, Operand::Copy(opt_data.child_place))
154+
};
144155

145156
// create temp to store inequality comparison between the two discriminants, `_t` in
146157
// example above
@@ -149,11 +160,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
149160
let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
150161
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
151162

152-
// create inequality comparison between the two discriminants
153-
let comp_rvalue = Rvalue::BinaryOp(
154-
nequal,
155-
Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
156-
);
163+
// create inequality comparison
164+
let comp_rvalue =
165+
Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
157166
patch.add_statement(
158167
parent_end,
159168
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
@@ -189,8 +198,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
189198
TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
190199
);
191200

192-
// generate StorageDead for the second_discriminant_temp not in use anymore
193-
patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
201+
if let Some(second_discriminant_temp) = second_discriminant_temp {
202+
// generate StorageDead for the second_discriminant_temp not in use anymore
203+
patch.add_statement(
204+
parent_end,
205+
StatementKind::StorageDead(second_discriminant_temp),
206+
);
207+
}
194208

195209
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
196210
// the switch
@@ -218,6 +232,7 @@ struct OptimizationData<'tcx> {
218232
child_place: Place<'tcx>,
219233
child_ty: Ty<'tcx>,
220234
child_source: SourceInfo,
235+
need_hoist_discriminant: bool,
221236
}
222237

223238
fn evaluate_candidate<'tcx>(
@@ -231,70 +246,128 @@ fn evaluate_candidate<'tcx>(
231246
return None;
232247
};
233248
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
234-
if !bbs[targets.otherwise()].is_empty_unreachable() {
235-
// Someone could write code like this:
236-
// ```rust
237-
// let Q = val;
238-
// if discriminant(P) == otherwise {
239-
// let ptr = &mut Q as *mut _ as *mut u8;
240-
// // It may be difficult for us to effectively determine whether values are valid.
241-
// // Invalid values can come from all sorts of corners.
242-
// unsafe { *ptr = 10; }
243-
// }
244-
//
245-
// match P {
246-
// A => match Q {
247-
// A => {
248-
// // code
249-
// }
250-
// _ => {
251-
// // don't use Q
252-
// }
253-
// }
254-
// _ => {
255-
// // don't use Q
256-
// }
257-
// };
258-
// ```
259-
//
260-
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
261-
// of an invalid value, which is UB.
262-
// In order to fix this, **we would either need to show that the discriminant computation of
263-
// `place` is computed in all branches**.
264-
// FIXME(#95162) For the moment, we adopt a conservative approach and
265-
// consider only the `otherwise` branch has no statements and an unreachable terminator.
266-
return None;
267-
}
268249
let (_, child) = targets.iter().next()?;
269-
let child_terminator = &bbs[child].terminator();
270-
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
271-
&child_terminator.kind
250+
251+
let Terminator {
252+
kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
253+
source_info,
254+
} = bbs[child].terminator()
272255
else {
273256
return None;
274257
};
275258
let child_ty = child_discr.ty(body.local_decls(), tcx);
276259
if child_ty != parent_ty {
277260
return None;
278261
}
279-
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind) else {
262+
263+
// We only handle:
264+
// ```
265+
// bb4: {
266+
// _8 = discriminant((_3.1: Enum1));
267+
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
268+
// }
269+
// ```
270+
// and
271+
// ```
272+
// bb2: {
273+
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
274+
// }
275+
// ```
276+
if bbs[child].statements.len() > 1 {
280277
return None;
278+
}
279+
280+
// When this BB has exactly one statement, that statement should be a discriminant assignment.
281+
let need_hoist_discriminant = bbs[child].statements.len() == 1;
282+
let child_place = if need_hoist_discriminant {
283+
if !bbs[targets.otherwise()].is_empty_unreachable() {
284+
// Someone could write code like this:
285+
// ```rust
286+
// let Q = val;
287+
// if discriminant(P) == otherwise {
288+
// let ptr = &mut Q as *mut _ as *mut u8;
289+
// // It may be difficult for us to effectively determine whether values are valid.
290+
// // Invalid values can come from all sorts of corners.
291+
// unsafe { *ptr = 10; }
292+
// }
293+
//
294+
// match P {
295+
// A => match Q {
296+
// A => {
297+
// // code
298+
// }
299+
// _ => {
300+
// // don't use Q
301+
// }
302+
// }
303+
// _ => {
304+
// // don't use Q
305+
// }
306+
// };
307+
// ```
308+
//
309+
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
310+
// invalid value, which is UB.
311+
// In order to fix this, **we would need to show that the discriminant computation of
312+
// `place` is computed in all branches**.
313+
// FIXME(#95162) For the moment, we adopt a conservative approach and
314+
// only consider the case where the `otherwise` branch has no statements and an unreachable terminator.
315+
return None;
316+
}
317+
// Handle:
318+
// ```
319+
// bb4: {
320+
// _8 = discriminant((_3.1: Enum1));
321+
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
322+
// }
323+
// ```
324+
let [
325+
Statement {
326+
kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
327+
..
328+
},
329+
] = bbs[child].statements.as_slice()
330+
else {
331+
return None;
332+
};
333+
*child_place
334+
} else {
335+
// Handle:
336+
// ```
337+
// bb2: {
338+
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
339+
// }
340+
// ```
341+
let Operand::Copy(child_place) = child_discr else {
342+
return None;
343+
};
344+
*child_place
281345
};
282-
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
283-
return None;
346+
let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
347+
{
348+
child_targets.otherwise()
349+
} else {
350+
targets.otherwise()
284351
};
285-
let destination = child_targets.otherwise();
286352

287353
// Verify that the optimization is legal for each branch
288354
for (value, child) in targets.iter() {
289-
if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
355+
if !verify_candidate_branch(
356+
&bbs[child],
357+
value,
358+
child_place,
359+
destination,
360+
need_hoist_discriminant,
361+
) {
290362
return None;
291363
}
292364
}
293365
Some(OptimizationData {
294366
destination,
295-
child_place: *child_place,
367+
child_place,
296368
child_ty,
297-
child_source: child_terminator.source_info,
369+
child_source: *source_info,
370+
need_hoist_discriminant,
298371
})
299372
}
300373

@@ -303,31 +376,48 @@ fn verify_candidate_branch<'tcx>(
303376
value: u128,
304377
place: Place<'tcx>,
305378
destination: BasicBlock,
379+
need_hoist_discriminant: bool,
306380
) -> bool {
307-
// In order for the optimization to be correct, the branch must...
308-
// ...have exactly one statement
309-
if let [statement] = branch.statements.as_slice()
310-
// ...assign the discriminant of `place` in that statement
311-
&& let StatementKind::Assign(boxed) = &statement.kind
312-
&& let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed
313-
&& *from_place == place
314-
// ...make that assignment to a local
315-
&& discr_place.projection.is_empty()
316-
// ...terminate on a `SwitchInt` that invalidates that local
317-
&& let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } =
318-
&branch.terminator().kind
319-
&& *switch_op == Operand::Move(*discr_place)
320-
// ...fall through to `destination` if the switch misses
321-
&& destination == targets.otherwise()
322-
// ...have a branch for value `value`
323-
&& let mut iter = targets.iter()
324-
&& let Some((target_value, _)) = iter.next()
325-
&& target_value == value
326-
// ...and have no more branches
327-
&& iter.next().is_none()
328-
{
329-
true
381+
// In order for the optimization to be correct, the terminator must be a `SwitchInt`.
382+
let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
383+
return false;
384+
};
385+
if need_hoist_discriminant {
386+
// If we need to hoist the discriminant, the branch must have exactly one statement.
387+
let [statement] = branch.statements.as_slice() else {
388+
return false;
389+
};
390+
// The statement must assign the discriminant of `place`.
391+
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
392+
statement.kind
393+
else {
394+
return false;
395+
};
396+
if from_place != place {
397+
return false;
398+
}
399+
// The assignment must be to a local, and the `SwitchInt` must move out of (invalidate) that local.
400+
if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
401+
return false;
402+
}
330403
} else {
331-
false
404+
// If we don't need to hoist the discriminant, the branch must not have any statements.
405+
if !branch.statements.is_empty() {
406+
return false;
407+
}
408+
// The `SwitchInt` must switch on a copy of the same place.
409+
if *switch_op != Operand::Copy(place) {
410+
return false;
411+
}
332412
}
413+
// It must fall through to `destination` if the switch misses.
414+
if destination != targets.otherwise() {
415+
return false;
416+
}
417+
// It must have exactly one branch for value `value` and have no more branches.
418+
let mut iter = targets.iter();
419+
let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
420+
return false;
421+
};
422+
target_value == value
333423
}

0 commit comments

Comments
 (0)