Skip to content

Commit 18e6b5a

Browse files
committed
Use a ConstValue instead.
1 parent 6d5a46b commit 18e6b5a

21 files changed

+419
-152
lines changed

compiler/rustc_mir_transform/src/dataflow_const_prop.rs

Lines changed: 137 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@
33
//! Currently, this pass only propagates scalar values.
44
55
use rustc_const_eval::const_eval::CheckAlignment;
6-
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
6+
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
77
use rustc_data_structures::fx::FxHashMap;
88
use rustc_hir::def::DefKind;
9-
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, ConstValue, InterpResult, Scalar};
9+
use rustc_middle::mir::interpret::{
10+
AllocId, ConstAllocation, ConstValue, GlobalAlloc, InterpResult, Scalar,
11+
};
1012
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
1113
use rustc_middle::mir::*;
12-
use rustc_middle::ty::layout::TyAndLayout;
14+
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1315
use rustc_middle::ty::{self, Ty, TyCtxt};
1416
use rustc_mir_dataflow::value_analysis::{
1517
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
1618
};
1719
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
1820
use rustc_span::def_id::DefId;
1921
use rustc_span::DUMMY_SP;
20-
use rustc_target::abi::{Align, FieldIdx, VariantIdx};
22+
use rustc_target::abi::{Align, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
2123

2224
use crate::MirPass;
2325

@@ -546,110 +548,130 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
546548

547549
fn try_make_constant(
548550
&self,
551+
ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
549552
place: Place<'tcx>,
550553
state: &State<FlatSet<Scalar>>,
551554
map: &Map,
552555
) -> Option<ConstantKind<'tcx>> {
553556
let ty = place.ty(self.local_decls, self.patch.tcx).ty;
554557
let place = map.find(place.as_ref())?;
555-
if let FlatSet::Elem(Scalar::Int(value)) = state.get_idx(place, map) {
556-
Some(ConstantKind::Val(ConstValue::Scalar(value.into()), ty))
558+
let layout = ecx.layout_of(ty).ok()?;
559+
if layout.abi.is_scalar() {
560+
let value = propagatable_scalar(*ecx.tcx, place, state, map)?;
561+
Some(ConstantKind::Val(ConstValue::Scalar(value), ty))
557562
} else {
558-
let valtree = self.try_make_valtree(place, ty, state, map)?;
559-
let constant = ty::Const::new_value(self.patch.tcx, valtree, ty);
560-
Some(ConstantKind::Ty(constant))
563+
let alloc_id = ecx
564+
.intern_with_temp_alloc(layout, |ecx, dest| {
565+
try_write_constant(ecx, dest, place, ty, state, map)
566+
})
567+
.ok()?;
568+
Some(ConstantKind::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty))
561569
}
562570
}
571+
}
563572

564-
fn try_make_valtree(
565-
&self,
566-
place: PlaceIndex,
567-
ty: Ty<'tcx>,
568-
state: &State<FlatSet<Scalar>>,
569-
map: &Map,
570-
) -> Option<ty::ValTree<'tcx>> {
571-
let tcx = self.patch.tcx;
572-
match ty.kind() {
573-
// ZSTs.
574-
ty::FnDef(..) => Some(ty::ValTree::zst()),
575-
576-
// Scalars.
577-
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
578-
if let FlatSet::Elem(Scalar::Int(value)) = state.get_idx(place, map) {
579-
Some(ty::ValTree::Leaf(value))
580-
} else {
581-
None
582-
}
573+
fn propagatable_scalar<'tcx>(
574+
tcx: TyCtxt<'tcx>,
575+
place: PlaceIndex,
576+
state: &State<FlatSet<Scalar>>,
577+
map: &Map,
578+
) -> Option<Scalar> {
579+
if let FlatSet::Elem(value) = state.get_idx(place, map) {
580+
if let Scalar::Ptr(pointer, _) = value {
581+
let (alloc_id, _) = pointer.into_parts();
582+
match tcx.global_alloc(alloc_id) {
583+
// Do not propagate pointers to functions and vtables as they may
584+
// lose identify during codegen, which is a miscompilation.
585+
GlobalAlloc::Function(_) | GlobalAlloc::VTable(..) => return None,
586+
GlobalAlloc::Memory(_) | GlobalAlloc::Static(_) => {}
583587
}
588+
}
589+
Some(value)
590+
} else {
591+
None
592+
}
593+
}
584594

585-
// Unsupported for now.
586-
ty::Array(_, _) => None,
587-
588-
ty::Tuple(elem_tys) => {
589-
let branches = elem_tys
590-
.iter()
591-
.enumerate()
592-
.map(|(i, ty)| {
593-
let field = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i)))?;
594-
self.try_make_valtree(field, ty, state, map)
595-
})
596-
.collect::<Option<Vec<_>>>()?;
597-
Some(ty::ValTree::Branch(tcx.arena.alloc_from_iter(branches.into_iter())))
595+
#[instrument(level = "trace", skip(ecx, state, map))]
596+
fn try_write_constant<'tcx>(
597+
ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
598+
dest: &PlaceTy<'tcx>,
599+
place: PlaceIndex,
600+
ty: Ty<'tcx>,
601+
state: &State<FlatSet<Scalar>>,
602+
map: &Map,
603+
) -> InterpResult<'tcx> {
604+
let layout = ecx.layout_of(ty)?;
605+
match ty.kind() {
606+
// ZSTs. Nothing to do.
607+
ty::FnDef(..) => {}
608+
609+
// Scalars.
610+
_ if layout.abi.is_scalar() => {
611+
let value = propagatable_scalar(*ecx.tcx, place, state, map).ok_or(err_inval!(ConstPropNonsense))?;
612+
ecx.write_immediate(Immediate::Scalar(value), dest)?;
613+
}
614+
// Those are scalars, must be handled above.
615+
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => bug!(),
616+
617+
ty::Tuple(elem_tys) => {
618+
for (i, elem) in elem_tys.iter().enumerate() {
619+
let field = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))).ok_or(err_inval!(ConstPropNonsense))?;
620+
let field_dest = ecx.project_field(dest, i)?;
621+
try_write_constant(ecx, &field_dest, field, elem, state, map)?;
598622
}
623+
}
599624

600-
ty::Adt(def, args) => {
601-
if def.is_union() {
602-
return None;
603-
}
625+
ty::Adt(def, args) => {
626+
if def.is_union() {
627+
throw_inval!(ConstPropNonsense)
628+
}
604629

605-
let (variant_idx, variant_def, variant_place) = if def.is_enum() {
606-
let discr = map.apply(place, TrackElem::Discriminant)?;
607-
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
608-
return None;
609-
};
610-
let discr_bits = discr.assert_bits(discr.size());
611-
let (variant, _) =
612-
def.discriminants(tcx).find(|(_, var)| discr_bits == var.val)?;
613-
let variant_place = map.apply(place, TrackElem::Variant(variant))?;
614-
let variant_int = ty::ValTree::Leaf(variant.as_u32().into());
615-
(Some(variant_int), def.variant(variant), variant_place)
616-
} else {
617-
(None, def.non_enum_variant(), place)
630+
let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
631+
let discr = map.apply(place, TrackElem::Discriminant).ok_or(err_inval!(ConstPropNonsense))?;
632+
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
633+
throw_inval!(ConstPropNonsense)
618634
};
619-
620-
let branches = variant_def
621-
.fields
622-
.iter_enumerated()
623-
.map(|(i, field)| {
624-
let ty = field.ty(tcx, args);
625-
let field = map.apply(variant_place, TrackElem::Field(i))?;
626-
self.try_make_valtree(field, ty, state, map)
627-
})
628-
.collect::<Option<Vec<_>>>()?;
629-
Some(ty::ValTree::Branch(
630-
tcx.arena.alloc_from_iter(variant_idx.into_iter().chain(branches)),
631-
))
635+
let discr_bits = discr.assert_bits(discr.size());
636+
let (variant, _) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val).ok_or(err_inval!(ConstPropNonsense))?;
637+
let variant_place = map.apply(place, TrackElem::Variant(variant)).ok_or(err_inval!(ConstPropNonsense))?;
638+
let variant_dest = ecx.project_downcast(dest, variant)?;
639+
(variant, def.variant(variant), variant_place, variant_dest)
640+
} else {
641+
(FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
642+
};
643+
644+
for (i, field) in variant_def.fields.iter_enumerated() {
645+
let ty = field.ty(*ecx.tcx, args);
646+
let field = map.apply(variant_place, TrackElem::Field(i)).ok_or(err_inval!(ConstPropNonsense))?;
647+
let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
648+
try_write_constant(ecx, &field_dest, field, ty, state, map)?;
632649
}
650+
ecx.write_discriminant(variant_idx, dest)?;
651+
}
652+
653+
// Unsupported for now.
654+
ty::Array(_, _)
633655

634-
// Do not attempt to support indirection in constants.
635-
ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_) => None,
636-
637-
ty::Never
638-
| ty::Foreign(..)
639-
| ty::Alias(..)
640-
| ty::Param(_)
641-
| ty::Bound(..)
642-
| ty::Placeholder(..)
643-
| ty::Closure(..)
644-
| ty::Generator(..)
645-
| ty::Dynamic(..) => None,
646-
647-
ty::Error(_)
648-
| ty::Infer(..)
649-
| ty::GeneratorWitness(..)
650-
| ty::GeneratorWitnessMIR(..) => bug!(),
656+
// Do not attempt to support indirection in constants.
657+
| ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)
658+
659+
| ty::Never
660+
| ty::Foreign(..)
661+
| ty::Alias(..)
662+
| ty::Param(_)
663+
| ty::Bound(..)
664+
| ty::Placeholder(..)
665+
| ty::Closure(..)
666+
| ty::Generator(..)
667+
| ty::Dynamic(..) => throw_inval!(ConstPropNonsense),
668+
669+
ty::Error(_) | ty::Infer(..) | ty::GeneratorWitness(..) | ty::GeneratorWitnessMIR(..) => {
670+
bug!()
651671
}
652672
}
673+
674+
Ok(())
653675
}
654676

655677
impl<'mir, 'tcx>
@@ -667,8 +689,13 @@ impl<'mir, 'tcx>
667689
) {
668690
match &statement.kind {
669691
StatementKind::Assign(box (_, rvalue)) => {
670-
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
671-
.visit_rvalue(rvalue, location);
692+
OperandCollector {
693+
state,
694+
visitor: self,
695+
ecx: &mut results.analysis.0.ecx,
696+
map: &results.analysis.0.map,
697+
}
698+
.visit_rvalue(rvalue, location);
672699
}
673700
_ => (),
674701
}
@@ -686,7 +713,12 @@ impl<'mir, 'tcx>
686713
// Don't overwrite the assignment if it already uses a constant (to keep the span).
687714
}
688715
StatementKind::Assign(box (place, _)) => {
689-
if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
716+
if let Some(value) = self.try_make_constant(
717+
&mut results.analysis.0.ecx,
718+
place,
719+
state,
720+
&results.analysis.0.map,
721+
) {
690722
self.patch.assignments.insert(location, value);
691723
}
692724
}
@@ -701,8 +733,13 @@ impl<'mir, 'tcx>
701733
terminator: &'mir Terminator<'tcx>,
702734
location: Location,
703735
) {
704-
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
705-
.visit_terminator(terminator, location);
736+
OperandCollector {
737+
state,
738+
visitor: self,
739+
ecx: &mut results.analysis.0.ecx,
740+
map: &results.analysis.0.map,
741+
}
742+
.visit_terminator(terminator, location);
706743
}
707744
}
708745

@@ -757,6 +794,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
757794
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
758795
state: &'a State<FlatSet<Scalar>>,
759796
visitor: &'a mut Collector<'tcx, 'locals>,
797+
ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
760798
map: &'map Map,
761799
}
762800

@@ -769,15 +807,17 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
769807
location: Location,
770808
) {
771809
if let PlaceElem::Index(local) = elem
772-
&& let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
810+
&& let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
773811
{
774812
self.visitor.patch.before_effect.insert((location, local.into()), value);
775813
}
776814
}
777815

778816
fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
779817
if let Some(place) = operand.place() {
780-
if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
818+
if let Some(value) =
819+
self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
820+
{
781821
self.visitor.patch.before_effect.insert((location, place), value);
782822
} else if !place.projection.is_empty() {
783823
// Try to propagate into `Index` projections.
@@ -802,8 +842,9 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
802842
}
803843

804844
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
805-
unimplemented!()
845+
false
806846
}
847+
807848
fn alignment_check_failed(
808849
_ecx: &InterpCx<'mir, 'tcx, Self>,
809850
_has: Align,

tests/mir-opt/const_debuginfo.main.ConstDebugInfo.diff

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
let _10: std::option::Option<u16>;
4343
scope 7 {
4444
- debug o => _10;
45-
+ debug o => const Option::<u16>::Some(99);
45+
+ debug o => const Option::<u16>::Some(99_u16);
4646
let _17: u32;
4747
let _18: u32;
4848
scope 8 {
@@ -82,7 +82,7 @@
8282
_15 = const false;
8383
_16 = const 123_u32;
8484
StorageLive(_10);
85-
_10 = const Option::<u16>::Some(99);
85+
_10 = const Option::<u16>::Some(99_u16);
8686
_17 = const 32_u32;
8787
_18 = const 32_u32;
8888
StorageLive(_11);
@@ -98,3 +98,7 @@
9898
}
9999
}
100100

101+
alloc10 (size: 4, align: 2) {
102+
01 00 63 00 │ ..c.
103+
}
104+

tests/mir-opt/dataflow-const-prop/checked.main.DataflowConstProp.panic-abort.diff

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
- _6 = CheckedAdd(_4, _5);
4444
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind unreachable];
4545
+ _5 = const 2_i32;
46-
+ _6 = const (3, false);
46+
+ _6 = const (3_i32, false);
4747
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind unreachable];
4848
}
4949

@@ -76,5 +76,13 @@
7676
StorageDead(_1);
7777
return;
7878
}
79+
+ }
80+
+
81+
+ alloc5 (size: 8, align: 4) {
82+
+ 00 00 00 80 01 __ __ __ │ .....░░░
83+
+ }
84+
+
85+
+ alloc4 (size: 8, align: 4) {
86+
+ 03 00 00 00 00 __ __ __ │ .....░░░
7987
}
8088

tests/mir-opt/dataflow-const-prop/checked.main.DataflowConstProp.panic-unwind.diff

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
- _6 = CheckedAdd(_4, _5);
4444
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind continue];
4545
+ _5 = const 2_i32;
46-
+ _6 = const (3, false);
46+
+ _6 = const (3_i32, false);
4747
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind continue];
4848
}
4949

@@ -76,5 +76,13 @@
7676
StorageDead(_1);
7777
return;
7878
}
79+
+ }
80+
+
81+
+ alloc5 (size: 8, align: 4) {
82+
+ 00 00 00 80 01 __ __ __ │ .....░░░
83+
+ }
84+
+
85+
+ alloc4 (size: 8, align: 4) {
86+
+ 03 00 00 00 00 __ __ __ │ .....░░░
7987
}
8088

0 commit comments

Comments
 (0)