Skip to content

Commit 4e4607b

Browse files
committed
Auto merge of #2341 - RalfJung:rustup, r=RalfJung
rustup; ptr atomics Adds support for the operations added in rust-lang/rust#96935. I made the pointer-binops always return the provenance of the *left* argument; `@thomcc` I hope that is what you intended. I have honestly no idea if it has anything to do with what LLVM does... I also simplified our pointer comparison code while I was at it -- now that *all* comparison operators support wide pointers, we can unify those branches.
2 parents 8c71148 + d5f1c26 commit 4e4607b

File tree

5 files changed

+107
-51
lines changed

5 files changed

+107
-51
lines changed

rust-version

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
049308cf8b48e9d67e54d6d0b01c10c79d1efc3a
1+
7665c3543079ebc3710b676d0fd6951bedfd4b29

src/helpers.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
195195

196196
/// Test if this pointer equals 0.
197197
fn ptr_is_null(&self, ptr: Pointer<Option<Tag>>) -> InterpResult<'tcx, bool> {
198-
let this = self.eval_context_ref();
199-
let null = Scalar::null_ptr(this);
200-
this.ptr_eq(Scalar::from_maybe_pointer(ptr, this), null)
198+
Ok(ptr.addr().bytes() == 0)
201199
}
202200

203201
/// Get the `Place` for a local

src/operator.rs

+32-34
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use log::trace;
22

33
use rustc_middle::{mir, ty::Ty};
4+
use rustc_target::abi::Size;
45

56
use crate::*;
67

@@ -11,8 +12,6 @@ pub trait EvalContextExt<'tcx> {
1112
left: &ImmTy<'tcx, Tag>,
1213
right: &ImmTy<'tcx, Tag>,
1314
) -> InterpResult<'tcx, (Scalar<Tag>, bool, Ty<'tcx>)>;
14-
15-
fn ptr_eq(&self, left: Scalar<Tag>, right: Scalar<Tag>) -> InterpResult<'tcx, bool>;
1615
}
1716

1817
impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriEvalContext<'mir, 'tcx> {
@@ -27,23 +26,8 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriEvalContext<'mir, 'tcx> {
2726
trace!("ptr_op: {:?} {:?} {:?}", *left, bin_op, *right);
2827

2928
Ok(match bin_op {
30-
Eq | Ne => {
31-
// This supports fat pointers.
32-
#[rustfmt::skip]
33-
let eq = match (**left, **right) {
34-
(Immediate::Scalar(left), Immediate::Scalar(right)) => {
35-
self.ptr_eq(left.check_init()?, right.check_init()?)?
36-
}
37-
(Immediate::ScalarPair(left1, left2), Immediate::ScalarPair(right1, right2)) => {
38-
self.ptr_eq(left1.check_init()?, right1.check_init()?)?
39-
&& self.ptr_eq(left2.check_init()?, right2.check_init()?)?
40-
}
41-
_ => bug!("Type system should not allow comparing Scalar with ScalarPair"),
42-
};
43-
(Scalar::from_bool(if bin_op == Eq { eq } else { !eq }), false, self.tcx.types.bool)
44-
}
45-
46-
Lt | Le | Gt | Ge => {
29+
Eq | Ne | Lt | Le | Gt | Ge => {
30+
assert_eq!(left.layout.abi, right.layout.abi); // types an differ, e.g. fn ptrs with different `for`
4731
let size = self.pointer_size();
4832
// Just compare the bits. ScalarPairs are compared lexicographically.
4933
// We thus always compare pairs and simply fill scalars up with 0.
@@ -58,35 +42,49 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriEvalContext<'mir, 'tcx> {
5842
(r1.check_init()?.to_bits(size)?, r2.check_init()?.to_bits(size)?),
5943
};
6044
let res = match bin_op {
45+
Eq => left == right,
46+
Ne => left != right,
6147
Lt => left < right,
6248
Le => left <= right,
6349
Gt => left > right,
6450
Ge => left >= right,
65-
_ => bug!("We already established it has to be one of these operators."),
51+
_ => bug!(),
6652
};
6753
(Scalar::from_bool(res), false, self.tcx.types.bool)
6854
}
6955

7056
Offset => {
57+
assert!(left.layout.ty.is_unsafe_ptr());
58+
let ptr = self.scalar_to_ptr(left.to_scalar()?)?;
59+
let offset = right.to_scalar()?.to_machine_isize(self)?;
60+
7161
let pointee_ty =
7262
left.layout.ty.builtin_deref(true).expect("Offset called on non-ptr type").ty;
73-
let ptr = self.ptr_offset_inbounds(
74-
self.scalar_to_ptr(left.to_scalar()?)?,
75-
pointee_ty,
76-
right.to_scalar()?.to_machine_isize(self)?,
77-
)?;
63+
let ptr = self.ptr_offset_inbounds(ptr, pointee_ty, offset)?;
7864
(Scalar::from_maybe_pointer(ptr, self), false, left.layout.ty)
7965
}
8066

81-
_ => bug!("Invalid operator on pointers: {:?}", bin_op),
82-
})
83-
}
67+
// Some more operations are possible with atomics.
68+
// The return value always has the provenance of the *left* operand.
69+
Add | Sub | BitOr | BitAnd | BitXor => {
70+
assert!(left.layout.ty.is_unsafe_ptr());
71+
assert!(right.layout.ty.is_unsafe_ptr());
72+
let ptr = self.scalar_to_ptr(left.to_scalar()?)?;
73+
// We do the actual operation with usize-typed scalars.
74+
let left = ImmTy::from_uint(ptr.addr().bytes(), self.machine.layouts.usize);
75+
let right = ImmTy::from_uint(
76+
right.to_scalar()?.to_machine_usize(self)?,
77+
self.machine.layouts.usize,
78+
);
79+
let (result, overflowing, _ty) =
80+
self.overflowing_binary_op(bin_op, &left, &right)?;
81+
// Construct a new pointer with the provenance of `ptr` (the LHS).
82+
let result_ptr =
83+
Pointer::new(ptr.provenance, Size::from_bytes(result.to_machine_usize(self)?));
84+
(Scalar::from_maybe_pointer(result_ptr, self), overflowing, left.layout.ty)
85+
}
8486

85-
fn ptr_eq(&self, left: Scalar<Tag>, right: Scalar<Tag>) -> InterpResult<'tcx, bool> {
86-
let size = self.pointer_size();
87-
// Just compare the integers.
88-
let left = left.to_bits(size)?;
89-
let right = right.to_bits(size)?;
90-
Ok(left == right)
87+
_ => span_bug!(self.cur_span(), "Invalid operator on pointers: {:?}", bin_op),
88+
})
9189
}
9290
}

src/shims/intrinsics.rs

+19-12
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
314314
ty::Float(FloatTy::F64) =>
315315
this.float_to_int_unchecked(val.to_scalar()?.to_f64()?, dest.layout.ty)?,
316316
_ =>
317-
bug!(
317+
span_bug!(
318+
this.cur_span(),
318319
"`float_to_int_unchecked` called with non-float input type {:?}",
319320
val.layout.ty
320321
),
@@ -371,7 +372,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
371372
Op::Abs => {
372373
// Works for f32 and f64.
373374
let ty::Float(float_ty) = op.layout.ty.kind() else {
374-
bug!("{} operand is not a float", intrinsic_name)
375+
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
375376
};
376377
let op = op.to_scalar()?;
377378
match float_ty {
@@ -381,7 +382,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
381382
}
382383
Op::HostOp(host_op) => {
383384
let ty::Float(float_ty) = op.layout.ty.kind() else {
384-
bug!("{} operand is not a float", intrinsic_name)
385+
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
385386
};
386387
// FIXME using host floats
387388
match float_ty {
@@ -546,7 +547,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
546547

547548
// Works for f32 and f64.
548549
let ty::Float(float_ty) = dest.layout.ty.kind() else {
549-
bug!("{} operand is not a float", intrinsic_name)
550+
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
550551
};
551552
let val = match float_ty {
552553
FloatTy::F32 =>
@@ -763,7 +764,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
763764

764765
// `index` is an array, not a SIMD type
765766
let ty::Array(_, index_len) = index.layout.ty.kind() else {
766-
bug!("simd_shuffle index argument has non-array type {}", index.layout.ty)
767+
span_bug!(this.cur_span(), "simd_shuffle index argument has non-array type {}", index.layout.ty)
767768
};
768769
let index_len = index_len.eval_usize(*this.tcx, this.param_env());
769770

@@ -785,10 +786,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
785786
&this.mplace_index(&right, src_index - left_len)?.into(),
786787
)?
787788
} else {
788-
bug!(
789-
"simd_shuffle index {} is out of bounds for 2 vectors of size {}",
790-
src_index,
791-
left_len
789+
span_bug!(
790+
this.cur_span(),
791+
"simd_shuffle index {src_index} is out of bounds for 2 vectors of size {left_len}",
792792
);
793793
};
794794
this.write_immediate(*val, &dest.into())?;
@@ -1187,8 +1187,11 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
11871187
let [place, rhs] = check_arg_count(args)?;
11881188
let place = this.deref_operand(place)?;
11891189

1190-
if !place.layout.ty.is_integral() {
1191-
bug!("Atomic arithmetic operations only work on integer types");
1190+
if !place.layout.ty.is_integral() && !place.layout.ty.is_unsafe_ptr() {
1191+
span_bug!(
1192+
this.cur_span(),
1193+
"atomic arithmetic operations only work on integer and raw pointer types",
1194+
);
11921195
}
11931196
let rhs = this.read_immediate(rhs)?;
11941197

@@ -1355,7 +1358,11 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
13551358
}
13561359
}
13571360
// Nothing else
1358-
_ => bug!("`float_to_int_unchecked` called with non-int output type {dest_ty:?}"),
1361+
_ =>
1362+
span_bug!(
1363+
this.cur_span(),
1364+
"`float_to_int_unchecked` called with non-int output type {dest_ty:?}"
1365+
),
13591366
})
13601367
}
13611368
}

tests/pass/atomic.rs

+54-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
use std::sync::atomic::{compiler_fence, fence, AtomicBool, AtomicIsize, AtomicU64, Ordering::*};
1+
// compile-flags: -Zmiri-strict-provenance
2+
#![feature(strict_provenance, strict_provenance_atomic_ptr)]
3+
use std::sync::atomic::{
4+
compiler_fence, fence, AtomicBool, AtomicIsize, AtomicPtr, AtomicU64, Ordering::*,
5+
};
26

37
fn main() {
48
atomic_bool();
59
atomic_all_ops();
610
atomic_u64();
711
atomic_fences();
12+
atomic_ptr();
813
weak_sometimes_fails();
914
}
1015

@@ -130,6 +135,54 @@ fn atomic_fences() {
130135
compiler_fence(AcqRel);
131136
}
132137

138+
fn atomic_ptr() {
139+
use std::ptr;
140+
let array: Vec<i32> = (0..100).into_iter().collect(); // a target to point to, to test provenance things
141+
let x = array.as_ptr() as *mut i32;
142+
143+
let ptr = AtomicPtr::<i32>::new(ptr::null_mut());
144+
assert!(ptr.load(Relaxed).addr() == 0);
145+
ptr.store(ptr::invalid_mut(13), SeqCst);
146+
assert!(ptr.swap(x, Relaxed).addr() == 13);
147+
unsafe { assert!(*ptr.load(Acquire) == 0) };
148+
149+
// comparison ignores provenance
150+
assert_eq!(
151+
ptr.compare_exchange(
152+
(&mut 0 as *mut i32).with_addr(x.addr()),
153+
ptr::invalid_mut(0),
154+
SeqCst,
155+
SeqCst
156+
)
157+
.unwrap()
158+
.addr(),
159+
x.addr(),
160+
);
161+
assert_eq!(
162+
ptr.compare_exchange(
163+
(&mut 0 as *mut i32).with_addr(x.addr()),
164+
ptr::invalid_mut(0),
165+
SeqCst,
166+
SeqCst
167+
)
168+
.unwrap_err()
169+
.addr(),
170+
0,
171+
);
172+
ptr.store(x, Relaxed);
173+
174+
assert_eq!(ptr.fetch_ptr_add(13, AcqRel).addr(), x.addr());
175+
unsafe { assert_eq!(*ptr.load(SeqCst), 13) }; // points to index 13 now
176+
assert_eq!(ptr.fetch_ptr_sub(4, AcqRel).addr(), x.addr() + 13 * 4);
177+
unsafe { assert_eq!(*ptr.load(SeqCst), 9) };
178+
assert_eq!(ptr.fetch_or(3, AcqRel).addr(), x.addr() + 9 * 4); // ptr is 4-aligned, so set the last 2 bits
179+
assert_eq!(ptr.fetch_and(!3, AcqRel).addr(), (x.addr() + 9 * 4) | 3); // and unset them again
180+
unsafe { assert_eq!(*ptr.load(SeqCst), 9) };
181+
assert_eq!(ptr.fetch_xor(0xdeadbeef, AcqRel).addr(), x.addr() + 9 * 4);
182+
assert_eq!(ptr.fetch_xor(0xdeadbeef, AcqRel).addr(), (x.addr() + 9 * 4) ^ 0xdeadbeef);
183+
unsafe { assert_eq!(*ptr.load(SeqCst), 9) }; // after XORing twice with the same thing, we get our ptr back
184+
}
185+
133186
fn weak_sometimes_fails() {
134187
let atomic = AtomicBool::new(false);
135188
let tries = 100;

0 commit comments

Comments
 (0)