@@ -135,18 +135,29 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
135
135
136
136
let mut patch = MirPatch :: new ( body) ;
137
137
138
- // create temp to store second discriminant in, `_s` in example above
139
- let second_discriminant_temp =
140
- patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
138
+ let ( second_discriminant_temp, second_operand) = if opt_data. hoist_discriminant {
139
+ // create temp to store second discriminant in, `_s` in example above
140
+ let second_discriminant_temp =
141
+ patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
141
142
142
- patch. add_statement ( parent_end, StatementKind :: StorageLive ( second_discriminant_temp) ) ;
143
+ patch. add_statement (
144
+ parent_end,
145
+ StatementKind :: StorageLive ( second_discriminant_temp) ,
146
+ ) ;
143
147
144
- // create assignment of discriminant
145
- patch. add_assign (
146
- parent_end,
147
- Place :: from ( second_discriminant_temp) ,
148
- Rvalue :: Discriminant ( opt_data. child_place ) ,
149
- ) ;
148
+ // create assignment of discriminant
149
+ patch. add_assign (
150
+ parent_end,
151
+ Place :: from ( second_discriminant_temp) ,
152
+ Rvalue :: Discriminant ( opt_data. child_place ) ,
153
+ ) ;
154
+ (
155
+ Some ( second_discriminant_temp) ,
156
+ Operand :: Move ( Place :: from ( second_discriminant_temp) ) ,
157
+ )
158
+ } else {
159
+ ( None , Operand :: Copy ( opt_data. child_place ) )
160
+ } ;
150
161
151
162
// create temp to store inequality comparison between the two discriminants, `_t` in
152
163
// example above
@@ -156,10 +167,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
156
167
patch. add_statement ( parent_end, StatementKind :: StorageLive ( comp_temp) ) ;
157
168
158
169
// create inequality comparison between the two discriminants
159
- let comp_rvalue = Rvalue :: BinaryOp (
160
- nequal,
161
- Box :: new ( ( parent_op. clone ( ) , Operand :: Move ( Place :: from ( second_discriminant_temp) ) ) ) ,
162
- ) ;
170
+ let comp_rvalue =
171
+ Rvalue :: BinaryOp ( nequal, Box :: new ( ( parent_op. clone ( ) , second_operand) ) ) ;
163
172
patch. add_statement (
164
173
parent_end,
165
174
StatementKind :: Assign ( Box :: new ( ( Place :: from ( comp_temp) , comp_rvalue) ) ) ,
@@ -194,8 +203,13 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
194
203
TerminatorKind :: if_ ( Operand :: Move ( Place :: from ( comp_temp) ) , true_case, false_case) ,
195
204
) ;
196
205
197
- // generate StorageDead for the second_discriminant_temp not in use anymore
198
- patch. add_statement ( parent_end, StatementKind :: StorageDead ( second_discriminant_temp) ) ;
206
+ if let Some ( second_discriminant_temp) = second_discriminant_temp {
207
+ // generate StorageDead for the second_discriminant_temp not in use anymore
208
+ patch. add_statement (
209
+ parent_end,
210
+ StatementKind :: StorageDead ( second_discriminant_temp) ,
211
+ ) ;
212
+ }
199
213
200
214
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
201
215
// the switch
@@ -271,6 +285,7 @@ struct OptimizationData<'tcx> {
271
285
child_place : Place < ' tcx > ,
272
286
child_ty : Ty < ' tcx > ,
273
287
child_source : SourceInfo ,
288
+ hoist_discriminant : bool ,
274
289
}
275
290
276
291
fn evaluate_candidate < ' tcx > (
@@ -284,38 +299,6 @@ fn evaluate_candidate<'tcx>(
284
299
return None ;
285
300
} ;
286
301
let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
287
- if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
288
- // Someone could write code like this:
289
- // ```rust
290
- // let Q = val;
291
- // if discriminant(P) == otherwise {
292
- // let ptr = &mut Q as *mut _ as *mut u8;
293
- // // Any invalid value for the type. It is possible to be opaque, such as in other functions.
294
- // unsafe { *ptr = 10; }
295
- // }
296
- //
297
- // match P {
298
- // A => match Q {
299
- // A => {
300
- // // code
301
- // }
302
- // _ => {
303
- // // don't use Q
304
- // }
305
- // }
306
- // _ => {
307
- // // don't use Q
308
- // }
309
- // };
310
- // ```
311
- //
312
- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
313
- // invalid value, which is UB.
314
- // In order to fix this, we would either need to show that the discriminant computation of
315
- // `place` is computed in all branches.
316
- // So we need the `otherwise` branch has no statements and an unreachable terminator.
317
- return None ;
318
- }
319
302
let ( _, child) = targets. iter ( ) . next ( ) ?;
320
303
let child_terminator = & bbs[ child] . terminator ( ) ;
321
304
let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
@@ -327,31 +310,89 @@ fn evaluate_candidate<'tcx>(
327
310
if child_ty != parent_ty {
328
311
return None ;
329
312
}
330
- let Some ( StatementKind :: Assign ( boxed ) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x . kind ) else {
313
+ if bbs[ child] . statements . len ( ) > 1 {
331
314
return None ;
315
+ }
316
+ let hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
317
+ let child_place = if hoist_discriminant {
318
+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
319
+ // Someone could write code like this:
320
+ // ```rust
321
+ // let Q = val;
322
+ // if discriminant(P) == otherwise {
323
+ // let ptr = &mut Q as *mut _ as *mut u8;
324
+ // // Any invalid value for the type. It is possible to be opaque, such as in other functions.
325
+ // unsafe { *ptr = 10; }
326
+ // }
327
+ //
328
+ // match P {
329
+ // A => match Q {
330
+ // A => {
331
+ // // code
332
+ // }
333
+ // _ => {
334
+ // // don't use Q
335
+ // }
336
+ // }
337
+ // _ => {
338
+ // // don't use Q
339
+ // }
340
+ // };
341
+ // ```
342
+ //
343
+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
344
+ // invalid value, which is UB.
345
+ // In order to fix this, we would either need to show that the discriminant computation of
346
+ // `place` is computed in all branches.
347
+ // So we need the `otherwise` branch has no statements and an unreachable terminator.
348
+ return None ;
349
+ }
350
+ let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind )
351
+ else {
352
+ return None ;
353
+ } ;
354
+ let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
355
+ return None ;
356
+ } ;
357
+ // Verify that the optimization is legal in general
358
+ // We can hoist evaluating the child discriminant out of the branch
359
+ if !may_hoist ( tcx, body, * child_place) {
360
+ return None ;
361
+ }
362
+ * child_place
363
+ } else {
364
+ let TerminatorKind :: SwitchInt { discr, .. } = & bbs[ child] . terminator ( ) . kind else {
365
+ return None ;
366
+ } ;
367
+ let Operand :: Copy ( child_place) = discr else {
368
+ return None ;
369
+ } ;
370
+ * child_place
332
371
} ;
333
- let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
334
- return None ;
372
+ let destination = if hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
373
+ child_targets. otherwise ( )
374
+ } else {
375
+ targets. otherwise ( )
335
376
} ;
336
- let destination = child_targets. otherwise ( ) ;
337
-
338
- // Verify that the optimization is legal in general
339
- // We can hoist evaluating the child discriminant out of the branch
340
- if !may_hoist ( tcx, body, * child_place) {
341
- return None ;
342
- }
343
377
344
378
// Verify that the optimization is legal for each branch
345
379
for ( value, child) in targets. iter ( ) {
346
- if !verify_candidate_branch ( & bbs[ child] , value, * child_place, destination) {
380
+ if !verify_candidate_branch (
381
+ & bbs[ child] ,
382
+ value,
383
+ child_place,
384
+ destination,
385
+ hoist_discriminant,
386
+ ) {
347
387
return None ;
348
388
}
349
389
}
350
390
Some ( OptimizationData {
351
391
destination,
352
- child_place : * child_place ,
392
+ child_place,
353
393
child_ty,
354
394
child_source : child_terminator. source_info ,
395
+ hoist_discriminant,
355
396
} )
356
397
}
357
398
@@ -360,29 +401,38 @@ fn verify_candidate_branch<'tcx>(
360
401
value : u128 ,
361
402
place : Place < ' tcx > ,
362
403
destination : BasicBlock ,
404
+ hoist_discriminant : bool ,
363
405
) -> bool {
364
406
// In order for the optimization to be correct, the branch must...
365
407
// ...have exactly one statement
366
- if branch. statements . len ( ) != 1 {
367
- return false ;
368
- }
369
- // ...assign the discriminant of `place` in that statement
370
- let StatementKind :: Assign ( boxed) = & branch. statements [ 0 ] . kind else { return false } ;
371
- let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed else { return false } ;
372
- if * from_place != place {
373
- return false ;
374
- }
375
- // ...make that assignment to a local
376
- if discr_place. projection . len ( ) != 0 {
408
+ if ( hoist_discriminant && branch. statements . len ( ) != 1 )
409
+ || ( !hoist_discriminant && !branch. statements . is_empty ( ) )
410
+ {
377
411
return false ;
378
412
}
379
413
// ...terminate on a `SwitchInt` that invalidates that local
380
414
let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } = & branch. terminator ( ) . kind
381
415
else {
382
416
return false ;
383
417
} ;
384
- if * switch_op != Operand :: Move ( * discr_place) {
385
- return false ;
418
+ if hoist_discriminant {
419
+ // ...assign the discriminant of `place` in that statement
420
+ let StatementKind :: Assign ( boxed) = & branch. statements [ 0 ] . kind else { return false } ;
421
+ let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed else { return false } ;
422
+ if * from_place != place {
423
+ return false ;
424
+ }
425
+ // ...make that assignment to a local
426
+ if discr_place. projection . len ( ) != 0 {
427
+ return false ;
428
+ }
429
+ if * switch_op != Operand :: Move ( * discr_place) {
430
+ return false ;
431
+ }
432
+ } else {
433
+ if * switch_op != Operand :: Copy ( place) {
434
+ return false ;
435
+ }
386
436
}
387
437
// ...fall through to `destination` if the switch misses
388
438
if destination != targets. otherwise ( ) {
@@ -397,7 +447,7 @@ fn verify_candidate_branch<'tcx>(
397
447
return false ;
398
448
}
399
449
// ...and have no more branches
400
- if let Some ( _ ) = iter. next ( ) {
450
+ if iter. next ( ) . is_some ( ) {
401
451
return false ;
402
452
}
403
453
true
0 commit comments