Skip to content

Commit 73ef0c0

Browse files
committed
Move some cheaper checks earlier
1 parent 834683b commit 73ef0c0

File tree

1 file changed

+54
-52
lines changed

1 file changed

+54
-52
lines changed

compiler/rustc_mir_transform/src/early_otherwise_branch.rs

+54-52
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
9898
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
9999
trace!("running EarlyOtherwiseBranch on {:?}", body.source);
100100

101-
let mut should_cleanup = false;
101+
let mut should_apply_patch = false;
102+
let mut patch = MirPatch::new(body);
102103

103104
// Also consider newly generated bbs in the same pass
104105
for i in 0..body.basic_blocks.len() {
@@ -112,7 +113,7 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
112113

113114
trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data);
114115

115-
should_cleanup = true;
116+
should_apply_patch = true;
116117

117118
let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
118119
&bbs[parent].terminator().kind
@@ -129,8 +130,6 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
129130
let statements_before = bbs[parent].statements.len();
130131
let parent_end = Location { block: parent, statement_index: statements_before };
131132

132-
let mut patch = MirPatch::new(body);
133-
134133
let (second_discriminant_temp, second_operand) = if opt_data.hoist_discriminant {
135134
// create temp to store second discriminant in, `_s` in example above
136135
let second_discriminant_temp =
@@ -242,13 +241,12 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
242241
);
243242
}
244243
}
245-
246-
patch.apply(body);
247244
}
248245

249246
// Since this optimization adds new basic blocks and invalidates others,
250247
// clean up the cfg to make it nicer for other passes
251-
if should_cleanup {
248+
if should_apply_patch {
249+
patch.apply(body);
252250
simplify_cfg(body);
253251
}
254252
}
@@ -275,19 +273,15 @@ fn evaluate_candidate<'tcx>(
275273
return None;
276274
};
277275
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
278-
let (_, child) = targets.iter().next()?;
279-
let child_terminator = &bbs[child].terminator();
280-
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
281-
&child_terminator.kind
276+
let mut targets_iter = targets.iter();
277+
let (_, first_child) = targets_iter.next()?;
278+
let first_child_terminator = &bbs[first_child].terminator();
279+
let TerminatorKind::SwitchInt { targets: first_child_targets, discr: first_child_discr } =
280+
&first_child_terminator.kind
282281
else {
283282
return None;
284283
};
285-
let child_ty = child_discr.ty(body.local_decls(), tcx);
286-
if bbs[child].statements.len() > 1 {
287-
return None;
288-
}
289-
let hoist_discriminant = bbs[child].statements.len() == 1;
290-
let child_place = if hoist_discriminant {
284+
let hoist_discriminant = if bbs[first_child].statements.len() == 1 {
291285
if !bbs[targets.otherwise()].is_empty_unreachable() {
292286
// Someone could write code like this:
293287
// ```rust
@@ -320,7 +314,44 @@ fn evaluate_candidate<'tcx>(
320314
// So we need the `otherwise` branch has no statements and an unreachable terminator.
321315
return None;
322316
}
323-
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind)
317+
true
318+
} else if bbs[first_child].statements.is_empty() {
319+
false
320+
} else {
321+
return None;
322+
};
323+
let destination = if hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable() {
324+
first_child_targets.otherwise()
325+
} else {
326+
if first_child_targets.otherwise() != targets.otherwise() {
327+
return None;
328+
}
329+
targets.otherwise()
330+
};
331+
while let Some((_, child)) = targets_iter.next() {
332+
let child_branch = &bbs[child];
333+
// In order for the optimization to be correct, the branch must...
334+
// ...have exactly one or empty statement
335+
if (hoist_discriminant && child_branch.statements.len() != 1)
336+
|| (!hoist_discriminant && !child_branch.statements.is_empty())
337+
{
338+
return None;
339+
}
340+
// ...terminate on a `SwitchInt` that invalidates that local
341+
let TerminatorKind::SwitchInt { targets: child_targets, .. } =
342+
&child_branch.terminator().kind
343+
else {
344+
return None;
345+
};
346+
if child_targets.otherwise() != destination {
347+
return None;
348+
}
349+
// Make sure there are only two branches.
350+
}
351+
let child_ty = first_child_discr.ty(body.local_decls(), tcx);
352+
let child_place = if hoist_discriminant {
353+
let Some(StatementKind::Assign(boxed)) =
354+
&bbs[first_child].statements.first().map(|x| &x.kind)
324355
else {
325356
return None;
326357
};
@@ -329,26 +360,17 @@ fn evaluate_candidate<'tcx>(
329360
};
330361
*child_place
331362
} else {
332-
let TerminatorKind::SwitchInt { discr, .. } = &bbs[child].terminator().kind else {
363+
let TerminatorKind::SwitchInt { discr, .. } = &bbs[first_child].terminator().kind else {
333364
return None;
334365
};
335366
let Operand::Copy(child_place) = discr else {
336367
return None;
337368
};
338369
*child_place
339370
};
340-
let destination = if hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable() {
341-
child_targets.otherwise()
342-
} else {
343-
targets.otherwise()
344-
};
345371

346-
let TerminatorKind::SwitchInt { targets: child_targets, .. } = &bbs[child].terminator().kind
347-
else {
348-
return None;
349-
};
350372
// Verify that the optimization is legal for each branch
351-
let Some((may_same_target_value, _)) = child_targets.iter().next() else {
373+
let Some((may_same_target_value, _)) = first_child_targets.iter().next() else {
352374
return None;
353375
};
354376
let mut same_target_value = Some(may_same_target_value);
@@ -357,7 +379,6 @@ fn evaluate_candidate<'tcx>(
357379
&bbs[child],
358380
may_same_target_value,
359381
child_place,
360-
destination,
361382
hoist_discriminant,
362383
) {
363384
same_target_value = None;
@@ -369,13 +390,7 @@ fn evaluate_candidate<'tcx>(
369390
return None;
370391
}
371392
for (value, child) in targets.iter() {
372-
if !verify_candidate_branch(
373-
&bbs[child],
374-
value,
375-
child_place,
376-
destination,
377-
hoist_discriminant,
378-
) {
393+
if !verify_candidate_branch(&bbs[child], value, child_place, hoist_discriminant) {
379394
return None;
380395
}
381396
}
@@ -384,7 +399,7 @@ fn evaluate_candidate<'tcx>(
384399
destination,
385400
child_place,
386401
child_ty,
387-
child_source: child_terminator.source_info,
402+
child_source: first_child_terminator.source_info,
388403
hoist_discriminant,
389404
same_target_value,
390405
})
@@ -394,20 +409,11 @@ fn verify_candidate_branch<'tcx>(
394409
branch: &BasicBlockData<'tcx>,
395410
value: u128,
396411
place: Place<'tcx>,
397-
destination: BasicBlock,
398412
hoist_discriminant: bool,
399413
) -> bool {
400-
// In order for the optimization to be correct, the branch must...
401-
// ...have exactly one statement
402-
if (hoist_discriminant && branch.statements.len() != 1)
403-
|| (!hoist_discriminant && !branch.statements.is_empty())
404-
{
405-
return false;
406-
}
407-
// ...terminate on a `SwitchInt` that invalidates that local
408414
let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } = &branch.terminator().kind
409415
else {
410-
return false;
416+
unreachable!()
411417
};
412418
if hoist_discriminant {
413419
// ...assign the discriminant of `place` in that statement
@@ -428,10 +434,6 @@ fn verify_candidate_branch<'tcx>(
428434
return false;
429435
}
430436
}
431-
// ...fall through to `destination` if the switch misses
432-
if destination != targets.otherwise() {
433-
return false;
434-
}
435437
// ...have a branch for value `value`
436438
let mut iter = targets.iter();
437439
let Some((target_value, _)) = iter.next() else {

0 commit comments

Comments
 (0)