Skip to content

Commit dde15ce

Browse files
committed
interpret: reset provenance on typed copies
1 parent 5611395 commit dde15ce

23 files changed

+489
-135
lines changed

compiler/rustc_const_eval/src/const_eval/eval_queries.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ fn eval_body_using_ecx<'tcx, R: InterpretationResult<'tcx>>(
9494
let intern_result = intern_const_alloc_recursive(ecx, intern_kind, &ret);
9595

9696
// Since evaluation had no errors, validate the resulting constant.
97-
const_validate_mplace(&ecx, &ret, cid)?;
97+
const_validate_mplace(ecx, &ret, cid)?;
9898

9999
// Only report this after validation, as validaiton produces much better diagnostics.
100100
// FIXME: ensure validation always reports this and stop making interning care about it.
@@ -391,7 +391,7 @@ fn eval_in_interpreter<'tcx, R: InterpretationResult<'tcx>>(
391391

392392
#[inline(always)]
393393
fn const_validate_mplace<'tcx>(
394-
ecx: &InterpCx<'tcx, CompileTimeMachine<'tcx>>,
394+
ecx: &mut InterpCx<'tcx, CompileTimeMachine<'tcx>>,
395395
mplace: &MPlaceTy<'tcx>,
396396
cid: GlobalId<'tcx>,
397397
) -> Result<(), ErrorHandled> {

compiler/rustc_const_eval/src/interpret/memory.rs

+26-12
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
99
use std::assert_matches::assert_matches;
1010
use std::borrow::Cow;
11-
use std::cell::Cell;
1211
use std::collections::VecDeque;
13-
use std::{fmt, ptr};
12+
use std::{fmt, mem, ptr};
1413

1514
use rustc_ast::Mutability;
1615
use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
@@ -118,7 +117,7 @@ pub struct Memory<'tcx, M: Machine<'tcx>> {
118117
/// This stores whether we are currently doing reads purely for the purpose of validation.
119118
/// Those reads do not trigger the machine's hooks for memory reads.
120119
/// Needless to say, this must only be set with great care!
121-
validation_in_progress: Cell<bool>,
120+
validation_in_progress: bool,
122121
}
123122

124123
/// A reference to some allocation that was already bounds-checked for the given region
@@ -145,7 +144,7 @@ impl<'tcx, M: Machine<'tcx>> Memory<'tcx, M> {
145144
alloc_map: M::MemoryMap::default(),
146145
extra_fn_ptr_map: FxIndexMap::default(),
147146
dead_alloc_map: FxIndexMap::default(),
148-
validation_in_progress: Cell::new(false),
147+
validation_in_progress: false,
149148
}
150149
}
151150

@@ -682,15 +681,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
682681
// We want to call the hook on *all* accesses that involve an AllocId, including zero-sized
683682
// accesses. That means we cannot rely on the closure above or the `Some` branch below. We
684683
// do this after `check_and_deref_ptr` to ensure some basic sanity has already been checked.
685-
if !self.memory.validation_in_progress.get() {
684+
if !self.memory.validation_in_progress {
686685
if let Ok((alloc_id, ..)) = self.ptr_try_get_alloc_id(ptr, size_i64) {
687686
M::before_alloc_read(self, alloc_id)?;
688687
}
689688
}
690689

691690
if let Some((alloc_id, offset, prov, alloc)) = ptr_and_alloc {
692691
let range = alloc_range(offset, size);
693-
if !self.memory.validation_in_progress.get() {
692+
if !self.memory.validation_in_progress {
694693
M::before_memory_read(
695694
self.tcx,
696695
&self.machine,
@@ -766,11 +765,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
766765
let parts = self.get_ptr_access(ptr, size)?;
767766
if let Some((alloc_id, offset, prov)) = parts {
768767
let tcx = self.tcx;
768+
let validation_in_progress = self.memory.validation_in_progress;
769769
// FIXME: can we somehow avoid looking up the allocation twice here?
770770
// We cannot call `get_raw_mut` inside `check_and_deref_ptr` as that would duplicate `&mut self`.
771771
let (alloc, machine) = self.get_alloc_raw_mut(alloc_id)?;
772772
let range = alloc_range(offset, size);
773-
M::before_memory_write(tcx, machine, &mut alloc.extra, (alloc_id, prov), range)?;
773+
if !validation_in_progress {
774+
M::before_memory_write(tcx, machine, &mut alloc.extra, (alloc_id, prov), range)?;
775+
}
774776
Ok(Some(AllocRefMut { alloc, range, tcx: *tcx, alloc_id }))
775777
} else {
776778
Ok(None)
@@ -1014,16 +1016,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
10141016
///
10151017
/// We do this so Miri's allocation access tracking does not show the validation
10161018
/// reads as spurious accesses.
1017-
pub fn run_for_validation<R>(&self, f: impl FnOnce() -> R) -> R {
1019+
pub fn run_for_validation<R>(&mut self, f: impl FnOnce(&mut Self) -> R) -> R {
10181020
// This deliberately uses `==` on `bool` to follow the pattern
10191021
// `assert!(val.replace(new) == old)`.
10201022
assert!(
1021-
self.memory.validation_in_progress.replace(true) == false,
1023+
mem::replace(&mut self.memory.validation_in_progress, true) == false,
10221024
"`validation_in_progress` was already set"
10231025
);
1024-
let res = f();
1026+
let res = f(self);
10251027
assert!(
1026-
self.memory.validation_in_progress.replace(false) == true,
1028+
mem::replace(&mut self.memory.validation_in_progress, false) == true,
10271029
"`validation_in_progress` was unset by someone else"
10281030
);
10291031
res
@@ -1115,6 +1117,10 @@ impl<'a, 'tcx, M: Machine<'tcx>> std::fmt::Debug for DumpAllocs<'a, 'tcx, M> {
11151117
impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes>
11161118
AllocRefMut<'a, 'tcx, Prov, Extra, Bytes>
11171119
{
1120+
pub fn as_ref<'b>(&'b self) -> AllocRef<'b, 'tcx, Prov, Extra, Bytes> {
1121+
AllocRef { alloc: self.alloc, range: self.range, tcx: self.tcx, alloc_id: self.alloc_id }
1122+
}
1123+
11181124
/// `range` is relative to this allocation reference, not the base of the allocation.
11191125
pub fn write_scalar(&mut self, range: AllocRange, val: Scalar<Prov>) -> InterpResult<'tcx> {
11201126
let range = self.range.subrange(range);
@@ -1137,6 +1143,14 @@ impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes>
11371143
.write_uninit(&self.tcx, self.range)
11381144
.map_err(|e| e.to_interp_error(self.alloc_id))?)
11391145
}
1146+
1147+
/// Remove all provenance in the reference range.
1148+
pub fn clear_provenance(&mut self) -> InterpResult<'tcx> {
1149+
Ok(self
1150+
.alloc
1151+
.clear_provenance(&self.tcx, self.range)
1152+
.map_err(|e| e.to_interp_error(self.alloc_id))?)
1153+
}
11401154
}
11411155

11421156
impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes> AllocRef<'a, 'tcx, Prov, Extra, Bytes> {
@@ -1278,7 +1292,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
12781292
};
12791293
let src_alloc = self.get_alloc_raw(src_alloc_id)?;
12801294
let src_range = alloc_range(src_offset, size);
1281-
assert!(!self.memory.validation_in_progress.get(), "we can't be copying during validation");
1295+
assert!(!self.memory.validation_in_progress, "we can't be copying during validation");
12821296
M::before_memory_read(
12831297
tcx,
12841298
&self.machine,

compiler/rustc_const_eval/src/interpret/operand.rs

+14
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ impl<Prov: Provenance> Immediate<Prov> {
137137
}
138138
}
139139
}
140+
141+
pub fn clear_provenance<'tcx>(&mut self) -> InterpResult<'tcx> {
142+
match self {
143+
Immediate::Scalar(s) => {
144+
s.clear_provenance()?;
145+
}
146+
Immediate::ScalarPair(a, b) => {
147+
a.clear_provenance()?;
148+
b.clear_provenance()?;
149+
}
150+
Immediate::Uninit => {}
151+
}
152+
Ok(())
153+
}
140154
}
141155

142156
// ScalarPair needs a type to interpret, so we often have an immediate and a type together

compiler/rustc_const_eval/src/interpret/place.rs

+44-25
Original file line numberDiff line numberDiff line change
@@ -603,8 +603,9 @@ where
603603
if M::enforce_validity(self, dest.layout()) {
604604
// Data got changed, better make sure it matches the type!
605605
self.validate_operand(
606-
&dest.to_op(self)?,
606+
&dest.to_place(),
607607
M::enforce_validity_recursively(self, dest.layout()),
608+
/*reset_provenance*/ true,
608609
)?;
609610
}
610611

@@ -634,7 +635,7 @@ where
634635
/// Write an immediate to a place.
635636
/// If you use this you are responsible for validating that things got copied at the
636637
/// right type.
637-
fn write_immediate_no_validate(
638+
pub(super) fn write_immediate_no_validate(
638639
&mut self,
639640
src: Immediate<M::Provenance>,
640641
dest: &impl Writeable<'tcx, M::Provenance>,
@@ -682,15 +683,7 @@ where
682683

683684
match value {
684685
Immediate::Scalar(scalar) => {
685-
let Abi::Scalar(s) = layout.abi else {
686-
span_bug!(
687-
self.cur_span(),
688-
"write_immediate_to_mplace: invalid Scalar layout: {layout:#?}",
689-
)
690-
};
691-
let size = s.size(&tcx);
692-
assert_eq!(size, layout.size, "abi::Scalar size does not match layout size");
693-
alloc.write_scalar(alloc_range(Size::ZERO, size), scalar)
686+
alloc.write_scalar(alloc_range(Size::ZERO, scalar.size()), scalar)
694687
}
695688
Immediate::ScalarPair(a_val, b_val) => {
696689
let Abi::ScalarPair(a, b) = layout.abi else {
@@ -700,16 +693,15 @@ where
700693
layout
701694
)
702695
};
703-
let (a_size, b_size) = (a.size(&tcx), b.size(&tcx));
704-
let b_offset = a_size.align_to(b.align(&tcx).abi);
696+
let b_offset = a.size(&tcx).align_to(b.align(&tcx).abi);
705697
assert!(b_offset.bytes() > 0); // in `operand_field` we use the offset to tell apart the fields
706698

707699
// It is tempting to verify `b_offset` against `layout.fields.offset(1)`,
708700
// but that does not work: We could be a newtype around a pair, then the
709701
// fields do not match the `ScalarPair` components.
710702

711-
alloc.write_scalar(alloc_range(Size::ZERO, a_size), a_val)?;
712-
alloc.write_scalar(alloc_range(b_offset, b_size), b_val)
703+
alloc.write_scalar(alloc_range(Size::ZERO, a_val.size()), a_val)?;
704+
alloc.write_scalar(alloc_range(b_offset, b_val.size()), b_val)
713705
}
714706
Immediate::Uninit => alloc.write_uninit(),
715707
}
@@ -734,6 +726,26 @@ where
734726
Ok(())
735727
}
736728

729+
/// Remove all provenance in the given place.
730+
pub fn clear_provenance(
731+
&mut self,
732+
dest: &impl Writeable<'tcx, M::Provenance>,
733+
) -> InterpResult<'tcx> {
734+
match self.as_mplace_or_mutable_local(&dest.to_place())? {
735+
Right((local_val, _local_layout)) => {
736+
local_val.clear_provenance()?;
737+
}
738+
Left(mplace) => {
739+
let Some(mut alloc) = self.get_place_alloc_mut(&mplace)? else {
740+
// Zero-sized access
741+
return Ok(());
742+
};
743+
alloc.clear_provenance()?;
744+
}
745+
}
746+
Ok(())
747+
}
748+
737749
/// Copies the data from an operand to a place.
738750
/// The layouts of the `src` and `dest` may disagree.
739751
/// Does not perform validation of the destination.
@@ -787,23 +799,30 @@ where
787799
allow_transmute: bool,
788800
validate_dest: bool,
789801
) -> InterpResult<'tcx> {
790-
// Generally for transmutation, data must be valid both at the old and new type.
791-
// But if the types are the same, the 2nd validation below suffices.
792-
if src.layout().ty != dest.layout().ty && M::enforce_validity(self, src.layout()) {
793-
self.validate_operand(
794-
&src.to_op(self)?,
795-
M::enforce_validity_recursively(self, src.layout()),
796-
)?;
797-
}
802+
// These are technically *two* typed copies: `src` is a not-yet-loaded value,
803+
// so we're going a typed copy at `src` type from there to some intermediate storage.
804+
// And then we're doing a second typed copy from that intermediate storage to `dest`.
805+
// But as an optimization, we only make a single direct copy here.
798806

799807
// Do the actual copy.
800808
self.copy_op_no_validate(src, dest, allow_transmute)?;
801809

802810
if validate_dest && M::enforce_validity(self, dest.layout()) {
803-
// Data got changed, better make sure it matches the type!
811+
let dest = dest.to_place();
812+
// Given that there were two typed copies, we have to ensure this is valid at both types,
813+
// and we have to ensure this loses provenance and padding according to both types.
814+
// But if the types are identical, we only do one pass.
815+
if src.layout().ty != dest.layout().ty {
816+
self.validate_operand(
817+
&dest.transmute(src.layout(), self)?,
818+
M::enforce_validity_recursively(self, src.layout()),
819+
/*reset_provenance*/ true,
820+
)?;
821+
}
804822
self.validate_operand(
805-
&dest.to_op(self)?,
823+
&dest,
806824
M::enforce_validity_recursively(self, dest.layout()),
825+
/*reset_provenance*/ true,
807826
)?;
808827
}
809828

0 commit comments

Comments
 (0)