@@ -129,18 +129,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
129
129
130
130
let mut patch = MirPatch :: new ( body) ;
131
131
132
- // create temp to store second discriminant in, `_s` in example above
133
- let second_discriminant_temp =
134
- patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
132
+ let ( second_discriminant_temp, second_operand) = if opt_data. need_hoist_discriminant {
133
+ // create temp to store second discriminant in, `_s` in example above
134
+ let second_discriminant_temp =
135
+ patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
135
136
136
- patch. add_statement ( parent_end, StatementKind :: StorageLive ( second_discriminant_temp) ) ;
137
+ patch. add_statement (
138
+ parent_end,
139
+ StatementKind :: StorageLive ( second_discriminant_temp) ,
140
+ ) ;
137
141
138
- // create assignment of discriminant
139
- patch. add_assign (
140
- parent_end,
141
- Place :: from ( second_discriminant_temp) ,
142
- Rvalue :: Discriminant ( opt_data. child_place ) ,
143
- ) ;
142
+ // create assignment of discriminant
143
+ patch. add_assign (
144
+ parent_end,
145
+ Place :: from ( second_discriminant_temp) ,
146
+ Rvalue :: Discriminant ( opt_data. child_place ) ,
147
+ ) ;
148
+ (
149
+ Some ( second_discriminant_temp) ,
150
+ Operand :: Move ( Place :: from ( second_discriminant_temp) ) ,
151
+ )
152
+ } else {
153
+ ( None , Operand :: Copy ( opt_data. child_place ) )
154
+ } ;
144
155
145
156
// create temp to store inequality comparison between the two discriminants, `_t` in
146
157
// example above
@@ -149,11 +160,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
149
160
let comp_temp = patch. new_temp ( comp_res_type, opt_data. child_source . span ) ;
150
161
patch. add_statement ( parent_end, StatementKind :: StorageLive ( comp_temp) ) ;
151
162
152
- // create inequality comparison between the two discriminants
153
- let comp_rvalue = Rvalue :: BinaryOp (
154
- nequal,
155
- Box :: new ( ( parent_op. clone ( ) , Operand :: Move ( Place :: from ( second_discriminant_temp) ) ) ) ,
156
- ) ;
163
+ // create inequality comparison
164
+ let comp_rvalue =
165
+ Rvalue :: BinaryOp ( nequal, Box :: new ( ( parent_op. clone ( ) , second_operand) ) ) ;
157
166
patch. add_statement (
158
167
parent_end,
159
168
StatementKind :: Assign ( Box :: new ( ( Place :: from ( comp_temp) , comp_rvalue) ) ) ,
@@ -189,8 +198,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
189
198
TerminatorKind :: if_ ( Operand :: Move ( Place :: from ( comp_temp) ) , true_case, false_case) ,
190
199
) ;
191
200
192
- // generate StorageDead for the second_discriminant_temp not in use anymore
193
- patch. add_statement ( parent_end, StatementKind :: StorageDead ( second_discriminant_temp) ) ;
201
+ if let Some ( second_discriminant_temp) = second_discriminant_temp {
202
+ // generate StorageDead for the second_discriminant_temp not in use anymore
203
+ patch. add_statement (
204
+ parent_end,
205
+ StatementKind :: StorageDead ( second_discriminant_temp) ,
206
+ ) ;
207
+ }
194
208
195
209
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
196
210
// the switch
@@ -218,6 +232,7 @@ struct OptimizationData<'tcx> {
218
232
child_place : Place < ' tcx > ,
219
233
child_ty : Ty < ' tcx > ,
220
234
child_source : SourceInfo ,
235
+ need_hoist_discriminant : bool ,
221
236
}
222
237
223
238
fn evaluate_candidate < ' tcx > (
@@ -231,70 +246,128 @@ fn evaluate_candidate<'tcx>(
231
246
return None ;
232
247
} ;
233
248
let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
234
- if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
235
- // Someone could write code like this:
236
- // ```rust
237
- // let Q = val;
238
- // if discriminant(P) == otherwise {
239
- // let ptr = &mut Q as *mut _ as *mut u8;
240
- // // It may be difficult for us to effectively determine whether values are valid.
241
- // // Invalid values can come from all sorts of corners.
242
- // unsafe { *ptr = 10; }
243
- // }
244
- //
245
- // match P {
246
- // A => match Q {
247
- // A => {
248
- // // code
249
- // }
250
- // _ => {
251
- // // don't use Q
252
- // }
253
- // }
254
- // _ => {
255
- // // don't use Q
256
- // }
257
- // };
258
- // ```
259
- //
260
- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
261
- // of an invalid value, which is UB.
262
- // In order to fix this, **we would either need to show that the discriminant computation of
263
- // `place` is computed in all branches**.
264
- // FIXME(#95162) For the moment, we adopt a conservative approach and
265
- // consider only the `otherwise` branch has no statements and an unreachable terminator.
266
- return None ;
267
- }
268
249
let ( _, child) = targets. iter ( ) . next ( ) ?;
269
- let child_terminator = & bbs[ child] . terminator ( ) ;
270
- let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
271
- & child_terminator. kind
250
+
251
+ let Terminator {
252
+ kind : TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } ,
253
+ source_info,
254
+ } = bbs[ child] . terminator ( )
272
255
else {
273
256
return None ;
274
257
} ;
275
258
let child_ty = child_discr. ty ( body. local_decls ( ) , tcx) ;
276
259
if child_ty != parent_ty {
277
260
return None ;
278
261
}
279
- let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind ) else {
262
+
263
+ // We only handle:
264
+ // ```
265
+ // bb4: {
266
+ // _8 = discriminant((_3.1: Enum1));
267
+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
268
+ // }
269
+ // ```
270
+ // and
271
+ // ```
272
+ // bb2: {
273
+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
274
+ // }
275
+ // ```
276
+ if bbs[ child] . statements . len ( ) > 1 {
280
277
return None ;
278
+ }
279
+
280
+ // When thie BB has exactly one statement, this statement should be discriminant.
281
+ let need_hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
282
+ let child_place = if need_hoist_discriminant {
283
+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
284
+ // Someone could write code like this:
285
+ // ```rust
286
+ // let Q = val;
287
+ // if discriminant(P) == otherwise {
288
+ // let ptr = &mut Q as *mut _ as *mut u8;
289
+ // // It may be difficult for us to effectively determine whether values are valid.
290
+ // // Invalid values can come from all sorts of corners.
291
+ // unsafe { *ptr = 10; }
292
+ // }
293
+ //
294
+ // match P {
295
+ // A => match Q {
296
+ // A => {
297
+ // // code
298
+ // }
299
+ // _ => {
300
+ // // don't use Q
301
+ // }
302
+ // }
303
+ // _ => {
304
+ // // don't use Q
305
+ // }
306
+ // };
307
+ // ```
308
+ //
309
+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
310
+ // invalid value, which is UB.
311
+ // In order to fix this, **we would either need to show that the discriminant computation of
312
+ // `place` is computed in all branches**.
313
+ // FIXME(#95162) For the moment, we adopt a conservative approach and
314
+ // consider only the `otherwise` branch has no statements and an unreachable terminator.
315
+ return None ;
316
+ }
317
+ // Handle:
318
+ // ```
319
+ // bb4: {
320
+ // _8 = discriminant((_3.1: Enum1));
321
+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
322
+ // }
323
+ // ```
324
+ let [
325
+ Statement {
326
+ kind : StatementKind :: Assign ( box ( _, Rvalue :: Discriminant ( child_place) ) ) ,
327
+ ..
328
+ } ,
329
+ ] = bbs[ child] . statements . as_slice ( )
330
+ else {
331
+ return None ;
332
+ } ;
333
+ * child_place
334
+ } else {
335
+ // Handle:
336
+ // ```
337
+ // bb2: {
338
+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
339
+ // }
340
+ // ```
341
+ let Operand :: Copy ( child_place) = child_discr else {
342
+ return None ;
343
+ } ;
344
+ * child_place
281
345
} ;
282
- let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
283
- return None ;
346
+ let destination = if need_hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( )
347
+ {
348
+ child_targets. otherwise ( )
349
+ } else {
350
+ targets. otherwise ( )
284
351
} ;
285
- let destination = child_targets. otherwise ( ) ;
286
352
287
353
// Verify that the optimization is legal for each branch
288
354
for ( value, child) in targets. iter ( ) {
289
- if !verify_candidate_branch ( & bbs[ child] , value, * child_place, destination) {
355
+ if !verify_candidate_branch (
356
+ & bbs[ child] ,
357
+ value,
358
+ child_place,
359
+ destination,
360
+ need_hoist_discriminant,
361
+ ) {
290
362
return None ;
291
363
}
292
364
}
293
365
Some ( OptimizationData {
294
366
destination,
295
- child_place : * child_place ,
367
+ child_place,
296
368
child_ty,
297
- child_source : child_terminator. source_info ,
369
+ child_source : * source_info,
370
+ need_hoist_discriminant,
298
371
} )
299
372
}
300
373
@@ -303,31 +376,48 @@ fn verify_candidate_branch<'tcx>(
303
376
value : u128 ,
304
377
place : Place < ' tcx > ,
305
378
destination : BasicBlock ,
379
+ need_hoist_discriminant : bool ,
306
380
) -> bool {
307
- // In order for the optimization to be correct, the branch must...
308
- // ...have exactly one statement
309
- if let [ statement] = branch. statements . as_slice ( )
310
- // ...assign the discriminant of `place` in that statement
311
- && let StatementKind :: Assign ( boxed) = & statement. kind
312
- && let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed
313
- && * from_place == place
314
- // ...make that assignment to a local
315
- && discr_place. projection . is_empty ( )
316
- // ...terminate on a `SwitchInt` that invalidates that local
317
- && let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } =
318
- & branch. terminator ( ) . kind
319
- && * switch_op == Operand :: Move ( * discr_place)
320
- // ...fall through to `destination` if the switch misses
321
- && destination == targets. otherwise ( )
322
- // ...have a branch for value `value`
323
- && let mut iter = targets. iter ( )
324
- && let Some ( ( target_value, _) ) = iter. next ( )
325
- && target_value == value
326
- // ...and have no more branches
327
- && iter. next ( ) . is_none ( )
328
- {
329
- true
381
+ // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
382
+ let TerminatorKind :: SwitchInt { discr : switch_op, targets } = & branch. terminator ( ) . kind else {
383
+ return false ;
384
+ } ;
385
+ if need_hoist_discriminant {
386
+ // If we need hoist discriminant, the branch must have exactly one statement.
387
+ let [ statement] = branch. statements . as_slice ( ) else {
388
+ return false ;
389
+ } ;
390
+ // The statement must assign the discriminant of `place`.
391
+ let StatementKind :: Assign ( box ( discr_place, Rvalue :: Discriminant ( from_place) ) ) =
392
+ statement. kind
393
+ else {
394
+ return false ;
395
+ } ;
396
+ if from_place != place {
397
+ return false ;
398
+ }
399
+ // The assignment must invalidate a local that terminate on a `SwitchInt`.
400
+ if !discr_place. projection . is_empty ( ) || * switch_op != Operand :: Move ( discr_place) {
401
+ return false ;
402
+ }
330
403
} else {
331
- false
404
+ // If we don't need hoist discriminant, the branch must not have any statements.
405
+ if !branch. statements . is_empty ( ) {
406
+ return false ;
407
+ }
408
+ // The place on `SwitchInt` must be the same.
409
+ if * switch_op != Operand :: Copy ( place) {
410
+ return false ;
411
+ }
332
412
}
413
+ // It must fall through to `destination` if the switch misses.
414
+ if destination != targets. otherwise ( ) {
415
+ return false ;
416
+ }
417
+ // It must have exactly one branch for value `value` and have no more branches.
418
+ let mut iter = targets. iter ( ) ;
419
+ let ( Some ( ( target_value, _) ) , None ) = ( iter. next ( ) , iter. next ( ) ) else {
420
+ return false ;
421
+ } ;
422
+ target_value == value
333
423
}
0 commit comments