Skip to content

Commit 6cd26cc

Browse files
authored
Merge pull request #458 from reitermarkus/treiber-64
Implement `pool` for any 32/64-bit architecture that supports the corresponding atomics.
2 parents 483862b + d4cc41e commit 6cd26cc

File tree

7 files changed

+167
-62
lines changed

7 files changed

+167
-62
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
1717
- Added `Extend` impls for `Deque`.
1818
- Added `Deque::make_contiguous`.
1919
- Added `VecView`, the `!Sized` version of `Vec`.
20+
- Added pool implementations for 64-bit architectures.
2021

2122
### Changed
2223

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ defmt-03 = ["dep:defmt"]
3939
# Enable larger MPMC sizes.
4040
mpmc_large = []
4141

42+
nightly = []
43+
4244
[dependencies]
4345
portable-atomic = { version = "1.0", optional = true }
4446
hash32 = "0.3.0"
@@ -47,7 +49,7 @@ ufmt-write = { version = "0.1", optional = true }
4749
defmt = { version = ">=0.2.0,<0.4", optional = true }
4850

4951
# for the pool module
50-
[target.'cfg(any(target_arch = "arm", target_arch = "x86"))'.dependencies]
52+
[target.'cfg(any(target_arch = "arm", target_pointer_width = "32", target_pointer_width = "64"))'.dependencies]
5153
stable_deref_trait = { version = "1", default-features = false }
5254

5355
[dev-dependencies]

src/lib.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@
4343
//!
4444
//! List of currently implemented data structures:
4545
#![cfg_attr(
46-
any(arm_llsc, target_arch = "x86"),
46+
any(arm_llsc, target_pointer_width = "32", target_pointer_width = "64"),
4747
doc = "- [`Arc`](pool::arc::Arc) -- like `std::sync::Arc` but backed by a lock-free memory pool rather than `#[global_allocator]`"
4848
)]
4949
#![cfg_attr(
50-
any(arm_llsc, target_arch = "x86"),
50+
any(arm_llsc, target_pointer_width = "32", target_pointer_width = "64"),
5151
doc = "- [`Box`](pool::boxed::Box) -- like `std::boxed::Box` but backed by a lock-free memory pool rather than `#[global_allocator]`"
5252
)]
5353
//! - [`BinaryHeap`] -- priority queue
@@ -57,7 +57,7 @@
5757
//! - [`IndexSet`] -- hash set
5858
//! - [`LinearMap`]
5959
#![cfg_attr(
60-
any(arm_llsc, target_arch = "x86"),
60+
any(arm_llsc, target_pointer_width = "32", target_pointer_width = "64"),
6161
doc = "- [`Object`](pool::object::Object) -- objects managed by an object pool"
6262
)]
6363
//! - [`sorted_linked_list::SortedLinkedList`]
@@ -76,6 +76,14 @@
7676
#![cfg_attr(docsrs, feature(doc_cfg), feature(doc_auto_cfg))]
7777
#![cfg_attr(not(test), no_std)]
7878
#![deny(missing_docs)]
79+
#![cfg_attr(
80+
all(
81+
feature = "nightly",
82+
target_pointer_width = "64",
83+
target_has_atomic = "128"
84+
),
85+
feature(integer_atomics)
86+
)]
7987

8088
pub use binary_heap::BinaryHeap;
8189
pub use deque::Deque;
@@ -125,7 +133,20 @@ mod defmt;
125133
all(not(feature = "mpmc_large"), target_has_atomic = "8")
126134
))]
127135
pub mod mpmc;
128-
#[cfg(any(arm_llsc, target_arch = "x86"))]
136+
#[cfg(any(
137+
arm_llsc,
138+
all(
139+
target_pointer_width = "32",
140+
any(target_has_atomic = "64", feature = "portable-atomic")
141+
),
142+
all(
143+
target_pointer_width = "64",
144+
any(
145+
all(target_has_atomic = "128", feature = "nightly"),
146+
feature = "portable-atomic"
147+
)
148+
)
149+
))]
129150
pub mod pool;
130151
pub mod sorted_linked_list;
131152
#[cfg(any(

src/pool/arc.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,15 @@ use core::{
7272
hash::{Hash, Hasher},
7373
mem::{ManuallyDrop, MaybeUninit},
7474
ops, ptr,
75-
sync::atomic::{self, AtomicUsize, Ordering},
7675
};
7776

77+
#[cfg(not(feature = "portable-atomic"))]
78+
use core::sync::atomic;
79+
#[cfg(feature = "portable-atomic")]
80+
use portable_atomic as atomic;
81+
82+
use atomic::{AtomicUsize, Ordering};
83+
7884
use super::treiber::{NonNullPtr, Stack, UnionNode};
7985

8086
/// Creates a new `ArcPool` singleton with the given `$name` that manages the specified `$data_type`

src/pool/treiber.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use core::mem::ManuallyDrop;
22

3-
#[cfg_attr(target_arch = "x86", path = "treiber/cas.rs")]
3+
#[cfg_attr(not(arm_llsc), path = "treiber/cas.rs")]
44
#[cfg_attr(arm_llsc, path = "treiber/llsc.rs")]
55
mod impl_;
66

src/pool/treiber/cas.rs

Lines changed: 123 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,56 @@
1-
use core::{
2-
marker::PhantomData,
3-
num::{NonZeroU32, NonZeroU64},
4-
ptr::NonNull,
5-
sync::atomic::{AtomicU64, Ordering},
6-
};
1+
use core::{marker::PhantomData, ptr::NonNull};
2+
3+
#[cfg(not(feature = "portable-atomic"))]
4+
use core::sync::atomic;
5+
#[cfg(feature = "portable-atomic")]
6+
use portable_atomic as atomic;
7+
8+
use atomic::Ordering;
79

810
use super::{Node, Stack};
911

12+
#[cfg(target_pointer_width = "32")]
13+
mod types {
14+
use super::atomic;
15+
16+
pub type Inner = u64;
17+
pub type InnerAtomic = atomic::AtomicU64;
18+
pub type InnerNonZero = core::num::NonZeroU64;
19+
20+
pub type Tag = core::num::NonZeroU32;
21+
pub type Address = u32;
22+
}
23+
24+
#[cfg(target_pointer_width = "64")]
25+
mod types {
26+
use super::atomic;
27+
28+
pub type Inner = u128;
29+
pub type InnerAtomic = atomic::AtomicU128;
30+
pub type InnerNonZero = core::num::NonZeroU128;
31+
32+
pub type Tag = core::num::NonZeroU64;
33+
pub type Address = u64;
34+
}
35+
36+
use types::*;
37+
1038
pub struct AtomicPtr<N>
1139
where
1240
N: Node,
1341
{
14-
inner: AtomicU64,
42+
inner: InnerAtomic,
1543
_marker: PhantomData<*mut N>,
1644
}
1745

1846
impl<N> AtomicPtr<N>
1947
where
2048
N: Node,
2149
{
50+
#[inline]
2251
pub const fn null() -> Self {
2352
Self {
24-
inner: AtomicU64::new(0),
53+
inner: InnerAtomic::new(0),
2554
_marker: PhantomData,
2655
}
2756
}
@@ -35,37 +64,38 @@ where
3564
) -> Result<(), Option<NonNullPtr<N>>> {
3665
self.inner
3766
.compare_exchange_weak(
38-
current
39-
.map(|pointer| pointer.into_u64())
40-
.unwrap_or_default(),
41-
new.map(|pointer| pointer.into_u64()).unwrap_or_default(),
67+
current.map(NonNullPtr::into_inner).unwrap_or_default(),
68+
new.map(NonNullPtr::into_inner).unwrap_or_default(),
4269
success,
4370
failure,
4471
)
4572
.map(drop)
46-
.map_err(NonNullPtr::from_u64)
73+
.map_err(|value| {
74+
// SAFETY: `value` cam from a `NonNullPtr::into_inner` call.
75+
unsafe { NonNullPtr::from_inner(value) }
76+
})
4777
}
4878

79+
#[inline]
4980
fn load(&self, order: Ordering) -> Option<NonNullPtr<N>> {
50-
NonZeroU64::new(self.inner.load(order)).map(|inner| NonNullPtr {
51-
inner,
81+
Some(NonNullPtr {
82+
inner: InnerNonZero::new(self.inner.load(order))?,
5283
_marker: PhantomData,
5384
})
5485
}
5586

87+
#[inline]
5688
fn store(&self, value: Option<NonNullPtr<N>>, order: Ordering) {
57-
self.inner.store(
58-
value.map(|pointer| pointer.into_u64()).unwrap_or_default(),
59-
order,
60-
)
89+
self.inner
90+
.store(value.map(NonNullPtr::into_inner).unwrap_or_default(), order)
6191
}
6292
}
6393

6494
pub struct NonNullPtr<N>
6595
where
6696
N: Node,
6797
{
68-
inner: NonZeroU64,
98+
inner: InnerNonZero,
6999
_marker: PhantomData<*mut N>,
70100
}
71101

@@ -84,65 +114,72 @@ impl<N> NonNullPtr<N>
84114
where
85115
N: Node,
86116
{
117+
#[inline]
87118
pub fn as_ptr(&self) -> *mut N {
88119
self.inner.get() as *mut N
89120
}
90121

91-
pub fn from_static_mut_ref(ref_: &'static mut N) -> NonNullPtr<N> {
92-
let non_null = NonNull::from(ref_);
93-
Self::from_non_null(non_null)
122+
#[inline]
123+
pub fn from_static_mut_ref(reference: &'static mut N) -> NonNullPtr<N> {
124+
// SAFETY: `reference` is a static mutable reference, i.e. a valid pointer.
125+
unsafe { Self::new_unchecked(initial_tag(), NonNull::from(reference)) }
94126
}
95127

96-
fn from_non_null(ptr: NonNull<N>) -> Self {
97-
let address = ptr.as_ptr() as u32;
98-
let tag = initial_tag().get();
99-
100-
let value = (u64::from(tag) << 32) | u64::from(address);
128+
/// # Safety
129+
///
130+
/// - `ptr` must be a valid pointer.
131+
#[inline]
132+
unsafe fn new_unchecked(tag: Tag, ptr: NonNull<N>) -> Self {
133+
let value =
134+
(Inner::from(tag.get()) << Address::BITS) | Inner::from(ptr.as_ptr() as Address);
101135

102136
Self {
103-
inner: unsafe { NonZeroU64::new_unchecked(value) },
137+
// SAFETY: `value` is constructed from a `Tag` which is non-zero and half the
138+
// size of the `InnerNonZero` type, and a `NonNull<N>` pointer.
139+
inner: unsafe { InnerNonZero::new_unchecked(value) },
104140
_marker: PhantomData,
105141
}
106142
}
107143

108-
fn from_u64(value: u64) -> Option<Self> {
109-
NonZeroU64::new(value).map(|inner| Self {
110-
inner,
144+
/// # Safety
145+
///
146+
/// - `value` must come from a `Self::into_inner` call.
147+
#[inline]
148+
unsafe fn from_inner(value: Inner) -> Option<Self> {
149+
Some(Self {
150+
inner: InnerNonZero::new(value)?,
111151
_marker: PhantomData,
112152
})
113153
}
114154

155+
#[inline]
115156
fn non_null(&self) -> NonNull<N> {
116-
unsafe { NonNull::new_unchecked(self.inner.get() as *mut N) }
157+
// SAFETY: `Self` can only be constructed using a `NonNull<N>`.
158+
unsafe { NonNull::new_unchecked(self.as_ptr()) }
117159
}
118160

119-
fn tag(&self) -> NonZeroU32 {
120-
unsafe { NonZeroU32::new_unchecked((self.inner.get() >> 32) as u32) }
121-
}
122-
123-
fn into_u64(self) -> u64 {
161+
#[inline]
162+
fn into_inner(self) -> Inner {
124163
self.inner.get()
125164
}
126165

127-
fn increase_tag(&mut self) {
128-
let address = self.as_ptr() as u32;
129-
130-
let new_tag = self
131-
.tag()
132-
.get()
133-
.checked_add(1)
134-
.map(|val| unsafe { NonZeroU32::new_unchecked(val) })
135-
.unwrap_or_else(initial_tag)
136-
.get();
166+
#[inline]
167+
fn tag(&self) -> Tag {
168+
// SAFETY: `self.inner` was constructed from a non-zero `Tag`.
169+
unsafe { Tag::new_unchecked((self.inner.get() >> Address::BITS) as Address) }
170+
}
137171

138-
let value = (u64::from(new_tag) << 32) | u64::from(address);
172+
fn increment_tag(&mut self) {
173+
let new_tag = self.tag().checked_add(1).unwrap_or_else(initial_tag);
139174

140-
self.inner = unsafe { NonZeroU64::new_unchecked(value) };
175+
// SAFETY: `self.non_null()` is a valid pointer.
176+
*self = unsafe { Self::new_unchecked(new_tag, self.non_null()) };
141177
}
142178
}
143179

144-
fn initial_tag() -> NonZeroU32 {
145-
unsafe { NonZeroU32::new_unchecked(1) }
180+
#[inline]
181+
const fn initial_tag() -> Tag {
182+
Tag::MIN
146183
}
147184

148185
pub unsafe fn push<N>(stack: &Stack<N>, new_top: NonNullPtr<N>)
@@ -184,7 +221,40 @@ where
184221
.compare_and_exchange_weak(Some(top), next, Ordering::Release, Ordering::Relaxed)
185222
.is_ok()
186223
{
187-
top.increase_tag();
224+
// Prevent the ABA problem (https://en.wikipedia.org/wiki/Treiber_stack#Correctness).
225+
//
226+
// Without this, the following would be possible:
227+
//
228+
// | Thread 1 | Thread 2 | Stack |
229+
// |-------------------------------|-------------------------|------------------------------|
230+
// | push((1, 1)) | | (1, 1) |
231+
// | push((1, 2)) | | (1, 2) -> (1, 1) |
232+
// | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
233+
// | | p = try_pop() // (1, 2) | (1, 1) |
234+
// | | push((1, 3)) | (1, 3) -> (1, 1) |
235+
// | | push(p) | (1, 2) -> (1, 3) -> (1, 1) |
236+
// | try_pop()::cas(p, p.next) | | (1, 1) |
237+
//
238+
// As can be seen, the `cas` operation succeeds, wrongly removing pointer `3` from the stack.
239+
//
240+
// By incrementing the tag before returning the pointer, it cannot be pushed again with the,
241+
// same tag, preventing the `try_pop()::cas(p, p.next)` operation from succeeding.
242+
//
243+
// With this fix, `try_pop()` in thread 2 returns `(2, 2)` and the comparison between
244+
// `(1, 2)` and `(2, 2)` fails, restarting the loop and correctly removing the new top:
245+
//
246+
// | Thread 1 | Thread 2 | Stack |
247+
// |-------------------------------|-------------------------|------------------------------|
248+
// | push((1, 1)) | | (1, 1) |
249+
// | push((1, 2)) | | (1, 2) -> (1, 1) |
250+
// | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
251+
// | | p = try_pop() // (2, 2) | (1, 1) |
252+
// | | push((1, 3)) | (1, 3) -> (1, 1) |
253+
// | | push(p) | (2, 2) -> (1, 3) -> (1, 1) |
254+
// | try_pop()::cas(p, p.next) | | (2, 2) -> (1, 3) -> (1, 1) |
255+
// | p = try_pop()::load // (2, 2) | | (2, 2) -> (1, 3) -> (1, 1) |
256+
// | try_pop()::cas(p, p.next) | | (1, 3) -> (1, 1) |
257+
top.increment_tag();
188258

189259
return Some(top);
190260
}

0 commit comments

Comments
 (0)