@@ -3,8 +3,10 @@ use std::iter;
3
3
use rustc_index:: IndexSlice ;
4
4
use rustc_middle:: mir:: patch:: MirPatch ;
5
5
use rustc_middle:: mir:: * ;
6
+ use rustc_middle:: ty:: layout:: { IntegerExt , TyAndLayout } ;
6
7
use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
7
- use rustc_target:: abi:: Size ;
8
+ use rustc_target:: abi:: Integer ;
9
+ use rustc_type_ir:: TyKind :: * ;
8
10
9
11
use super :: simplify:: simplify_cfg;
10
12
@@ -42,10 +44,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
42
44
should_cleanup = true ;
43
45
continue ;
44
46
}
45
- // unsound: https://github.com/rust-lang/rust/issues/124150
46
- if tcx. sess . opts . unstable_opts . unsound_mir_opts
47
- && SimplifyToExp :: default ( ) . simplify ( tcx, body, bb_idx, param_env) . is_some ( )
48
- {
47
+ if SimplifyToExp :: default ( ) . simplify ( tcx, body, bb_idx, param_env) . is_some ( ) {
49
48
should_cleanup = true ;
50
49
continue ;
51
50
}
@@ -264,33 +263,56 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
264
263
}
265
264
}
266
265
266
+ /// Check if the cast constant using `IntToInt` is equal to the target constant.
267
+ fn can_cast (
268
+ tcx : TyCtxt < ' _ > ,
269
+ src_val : impl Into < u128 > ,
270
+ src_layout : TyAndLayout < ' _ > ,
271
+ cast_ty : Ty < ' _ > ,
272
+ target_scalar : ScalarInt ,
273
+ ) -> bool {
274
+ let from_scalar = ScalarInt :: try_from_uint ( src_val. into ( ) , src_layout. size ) . unwrap ( ) ;
275
+ let v = match src_layout. ty . kind ( ) {
276
+ Uint ( _) => from_scalar. to_uint ( src_layout. size ) ,
277
+ Int ( _) => from_scalar. to_int ( src_layout. size ) as u128 ,
278
+ _ => unreachable ! ( "invalid int" ) ,
279
+ } ;
280
+ let size = match * cast_ty. kind ( ) {
281
+ Int ( t) => Integer :: from_int_ty ( & tcx, t) . size ( ) ,
282
+ Uint ( t) => Integer :: from_uint_ty ( & tcx, t) . size ( ) ,
283
+ _ => unreachable ! ( "invalid int" ) ,
284
+ } ;
285
+ let v = size. truncate ( v) ;
286
+ let cast_scalar = ScalarInt :: try_from_uint ( v, size) . unwrap ( ) ;
287
+ cast_scalar == target_scalar
288
+ }
289
+
267
290
#[ derive( Default ) ]
268
291
struct SimplifyToExp {
269
- transfrom_types : Vec < TransfromType > ,
292
+ transfrom_kinds : Vec < TransfromKind > ,
270
293
}
271
294
272
295
#[ derive( Clone , Copy ) ]
273
- enum CompareType < ' tcx , ' a > {
296
+ enum ExpectedTransformKind < ' tcx , ' a > {
274
297
/// Identical statements.
275
298
Same ( & ' a StatementKind < ' tcx > ) ,
276
299
/// Assignment statements have the same value.
277
- Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
300
+ SameByEq { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , scalar : ScalarInt } ,
278
301
/// Enum variant comparison type.
279
- Discr { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , is_signed : bool } ,
302
+ Cast { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > } ,
280
303
}
281
304
282
- enum TransfromType {
305
+ enum TransfromKind {
283
306
Same ,
284
- Eq ,
285
- Discr ,
307
+ Cast ,
286
308
}
287
309
288
- impl From < CompareType < ' _ , ' _ > > for TransfromType {
289
- fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
310
+ impl From < ExpectedTransformKind < ' _ , ' _ > > for TransfromKind {
311
+ fn from ( compare_type : ExpectedTransformKind < ' _ , ' _ > ) -> Self {
290
312
match compare_type {
291
- CompareType :: Same ( _) => TransfromType :: Same ,
292
- CompareType :: Eq ( _ , _ , _ ) => TransfromType :: Eq ,
293
- CompareType :: Discr { .. } => TransfromType :: Discr ,
313
+ ExpectedTransformKind :: Same ( _) => TransfromKind :: Same ,
314
+ ExpectedTransformKind :: SameByEq { .. } => TransfromKind :: Same ,
315
+ ExpectedTransformKind :: Cast { .. } => TransfromKind :: Cast ,
294
316
}
295
317
}
296
318
}
@@ -354,7 +376,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
354
376
return None ;
355
377
}
356
378
let mut target_iter = targets. iter ( ) ;
357
- let ( first_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
379
+ let ( first_case_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
358
380
let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
359
381
// Check that destinations are identical, and if not, then don't optimize this block
360
382
if !targets
@@ -364,24 +386,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
364
386
return None ;
365
387
}
366
388
367
- let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
389
+ let discr_layout = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) ;
368
390
let first_stmts = & bbs[ first_target] . statements ;
369
- let ( second_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
391
+ let ( second_case_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
370
392
let second_stmts = & bbs[ second_target] . statements ;
371
393
if first_stmts. len ( ) != second_stmts. len ( ) {
372
394
return None ;
373
395
}
374
396
375
- fn int_equal ( l : ScalarInt , r : impl Into < u128 > , size : Size ) -> bool {
376
- l. to_bits_unchecked ( ) == ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . to_bits_unchecked ( )
377
- }
378
-
379
397
// We first compare the two branches, and then the other branches need to fulfill the same conditions.
380
- let mut compare_types = Vec :: new ( ) ;
398
+ let mut expected_transform_kinds = Vec :: new ( ) ;
381
399
for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
382
400
let compare_type = match ( & f. kind , & s. kind ) {
383
401
// If two statements are exactly the same, we can optimize.
384
- ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
402
+ ( f_s, s_s) if f_s == s_s => ExpectedTransformKind :: Same ( f_s) ,
385
403
386
404
// If two statements are assignments with the match values to the same place, we can optimize.
387
405
(
@@ -395,22 +413,29 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
395
413
f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
396
414
s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
397
415
) {
398
- ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
399
- // Enum variants can also be simplified to an assignment statement if their values are equal.
400
- // We need to consider both unsigned and signed scenarios here.
416
+ ( Some ( f) , Some ( s) ) if f == s => ExpectedTransformKind :: SameByEq {
417
+ place : lhs_f,
418
+ ty : f_c. const_ . ty ( ) ,
419
+ scalar : f,
420
+ } ,
421
+ // Enum variants can also be simplified to an assignment statement,
422
+ // if we can use `IntToInt` cast to get an equal value.
401
423
( Some ( f) , Some ( s) )
402
- if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
403
- && int_equal ( f, first_val, discr_size)
404
- && int_equal ( s, second_val, discr_size) )
405
- || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
406
- && Some ( s)
407
- == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
424
+ if ( can_cast (
425
+ tcx,
426
+ first_case_val,
427
+ discr_layout,
428
+ f_c. const_ . ty ( ) ,
429
+ f,
430
+ ) && can_cast (
431
+ tcx,
432
+ second_case_val,
433
+ discr_layout,
434
+ f_c. const_ . ty ( ) ,
435
+ s,
436
+ ) ) =>
408
437
{
409
- CompareType :: Discr {
410
- place : lhs_f,
411
- ty : f_c. const_ . ty ( ) ,
412
- is_signed : f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
413
- }
438
+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_c. const_ . ty ( ) }
414
439
}
415
440
_ => {
416
441
return None ;
@@ -421,47 +446,36 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
421
446
// Otherwise we cannot optimize. Try another block.
422
447
_ => return None ,
423
448
} ;
424
- compare_types . push ( compare_type) ;
449
+ expected_transform_kinds . push ( compare_type) ;
425
450
}
426
451
427
452
// All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
428
453
for ( other_val, other_target) in target_iter {
429
454
let other_stmts = & bbs[ other_target] . statements ;
430
- if compare_types . len ( ) != other_stmts. len ( ) {
455
+ if expected_transform_kinds . len ( ) != other_stmts. len ( ) {
431
456
return None ;
432
457
}
433
- for ( f, s) in iter:: zip ( & compare_types , other_stmts) {
458
+ for ( f, s) in iter:: zip ( & expected_transform_kinds , other_stmts) {
434
459
match ( * f, & s. kind ) {
435
- ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
460
+ ( ExpectedTransformKind :: Same ( f_s) , s_s) if f_s == s_s => { }
436
461
(
437
- CompareType :: Eq ( lhs_f, f_ty, val ) ,
462
+ ExpectedTransformKind :: SameByEq { place : lhs_f, ty : f_ty, scalar } ,
438
463
StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
439
464
) if lhs_f == lhs_s
440
465
&& s_c. const_ . ty ( ) == f_ty
441
- && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val ) => { }
466
+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( scalar ) => { }
442
467
(
443
- CompareType :: Discr { place : lhs_f, ty : f_ty, is_signed } ,
468
+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_ty } ,
444
469
StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
445
- ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
446
- let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
447
- return None ;
448
- } ;
449
- if is_signed
450
- && s_c. const_ . ty ( ) . is_signed ( )
451
- && int_equal ( f, other_val, discr_size)
452
- {
453
- continue ;
454
- }
455
- if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
456
- continue ;
457
- }
458
- return None ;
459
- }
470
+ ) if let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env)
471
+ && lhs_f == lhs_s
472
+ && s_c. const_ . ty ( ) == f_ty
473
+ && can_cast ( tcx, other_val, discr_layout, f_ty, f) => { }
460
474
_ => return None ,
461
475
}
462
476
}
463
477
}
464
- self . transfrom_types = compare_types . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
478
+ self . transfrom_kinds = expected_transform_kinds . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
465
479
Some ( ( ) )
466
480
}
467
481
@@ -479,13 +493,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
479
493
let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
480
494
let first = & bbs[ first] ;
481
495
482
- for ( t, s) in iter:: zip ( & self . transfrom_types , & first. statements ) {
496
+ for ( t, s) in iter:: zip ( & self . transfrom_kinds , & first. statements ) {
483
497
match ( t, & s. kind ) {
484
- ( TransfromType :: Same , _ ) | ( TransfromType :: Eq , _) => {
498
+ ( TransfromKind :: Same , _) => {
485
499
patch. add_statement ( parent_end, s. kind . clone ( ) ) ;
486
500
}
487
501
(
488
- TransfromType :: Discr ,
502
+ TransfromKind :: Cast ,
489
503
StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
490
504
) => {
491
505
let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
0 commit comments