1
- use core:: {
2
- marker:: PhantomData ,
3
- num:: { NonZeroU32 , NonZeroU64 } ,
4
- ptr:: NonNull ,
5
- sync:: atomic:: { AtomicU64 , Ordering } ,
6
- } ;
1
+ use core:: { marker:: PhantomData , ptr:: NonNull } ;
2
+
3
+ #[ cfg( not( feature = "portable-atomic" ) ) ]
4
+ use core:: sync:: atomic;
5
+ #[ cfg( feature = "portable-atomic" ) ]
6
+ use portable_atomic as atomic;
7
+
8
+ use atomic:: Ordering ;
7
9
8
10
use super :: { Node , Stack } ;
9
11
12
+ #[ cfg( target_pointer_width = "32" ) ]
13
+ mod types {
14
+ use super :: atomic;
15
+
16
+ pub type Inner = u64 ;
17
+ pub type InnerAtomic = atomic:: AtomicU64 ;
18
+ pub type InnerNonZero = core:: num:: NonZeroU64 ;
19
+
20
+ pub type Tag = core:: num:: NonZeroU32 ;
21
+ pub type Address = u32 ;
22
+ }
23
+
24
+ #[ cfg( target_pointer_width = "64" ) ]
25
+ mod types {
26
+ use super :: atomic;
27
+
28
+ pub type Inner = u128 ;
29
+ pub type InnerAtomic = atomic:: AtomicU128 ;
30
+ pub type InnerNonZero = core:: num:: NonZeroU128 ;
31
+
32
+ pub type Tag = core:: num:: NonZeroU64 ;
33
+ pub type Address = u64 ;
34
+ }
35
+
36
+ use types:: * ;
37
+
10
38
pub struct AtomicPtr < N >
11
39
where
12
40
N : Node ,
13
41
{
14
- inner : AtomicU64 ,
42
+ inner : InnerAtomic ,
15
43
_marker : PhantomData < * mut N > ,
16
44
}
17
45
18
46
impl < N > AtomicPtr < N >
19
47
where
20
48
N : Node ,
21
49
{
50
+ #[ inline]
22
51
pub const fn null ( ) -> Self {
23
52
Self {
24
- inner : AtomicU64 :: new ( 0 ) ,
53
+ inner : InnerAtomic :: new ( 0 ) ,
25
54
_marker : PhantomData ,
26
55
}
27
56
}
@@ -35,37 +64,38 @@ where
35
64
) -> Result < ( ) , Option < NonNullPtr < N > > > {
36
65
self . inner
37
66
. compare_exchange_weak (
38
- current
39
- . map ( |pointer| pointer. into_u64 ( ) )
40
- . unwrap_or_default ( ) ,
41
- new. map ( |pointer| pointer. into_u64 ( ) ) . unwrap_or_default ( ) ,
67
+ current. map ( NonNullPtr :: into_inner) . unwrap_or_default ( ) ,
68
+ new. map ( NonNullPtr :: into_inner) . unwrap_or_default ( ) ,
42
69
success,
43
70
failure,
44
71
)
45
72
. map ( drop)
46
- . map_err ( NonNullPtr :: from_u64)
73
+ . map_err ( |value| {
74
+ // SAFETY: `value` cam from a `NonNullPtr::into_inner` call.
75
+ unsafe { NonNullPtr :: from_inner ( value) }
76
+ } )
47
77
}
48
78
79
+ #[ inline]
49
80
fn load ( & self , order : Ordering ) -> Option < NonNullPtr < N > > {
50
- NonZeroU64 :: new ( self . inner . load ( order ) ) . map ( |inner| NonNullPtr {
51
- inner,
81
+ Some ( NonNullPtr {
82
+ inner : InnerNonZero :: new ( self . inner . load ( order ) ) ? ,
52
83
_marker : PhantomData ,
53
84
} )
54
85
}
55
86
87
+ #[ inline]
56
88
fn store ( & self , value : Option < NonNullPtr < N > > , order : Ordering ) {
57
- self . inner . store (
58
- value. map ( |pointer| pointer. into_u64 ( ) ) . unwrap_or_default ( ) ,
59
- order,
60
- )
89
+ self . inner
90
+ . store ( value. map ( NonNullPtr :: into_inner) . unwrap_or_default ( ) , order)
61
91
}
62
92
}
63
93
64
94
pub struct NonNullPtr < N >
65
95
where
66
96
N : Node ,
67
97
{
68
- inner : NonZeroU64 ,
98
+ inner : InnerNonZero ,
69
99
_marker : PhantomData < * mut N > ,
70
100
}
71
101
@@ -84,65 +114,72 @@ impl<N> NonNullPtr<N>
84
114
where
85
115
N : Node ,
86
116
{
117
+ #[ inline]
87
118
pub fn as_ptr ( & self ) -> * mut N {
88
119
self . inner . get ( ) as * mut N
89
120
}
90
121
91
- pub fn from_static_mut_ref ( ref_ : & ' static mut N ) -> NonNullPtr < N > {
92
- let non_null = NonNull :: from ( ref_) ;
93
- Self :: from_non_null ( non_null)
122
+ #[ inline]
123
+ pub fn from_static_mut_ref ( reference : & ' static mut N ) -> NonNullPtr < N > {
124
+ // SAFETY: `reference` is a static mutable reference, i.e. a valid pointer.
125
+ unsafe { Self :: new_unchecked ( initial_tag ( ) , NonNull :: from ( reference) ) }
94
126
}
95
127
96
- fn from_non_null ( ptr : NonNull < N > ) -> Self {
97
- let address = ptr. as_ptr ( ) as u32 ;
98
- let tag = initial_tag ( ) . get ( ) ;
99
-
100
- let value = ( u64:: from ( tag) << 32 ) | u64:: from ( address) ;
128
+ /// # Safety
129
+ ///
130
+ /// - `ptr` must be a valid pointer.
131
+ #[ inline]
132
+ unsafe fn new_unchecked ( tag : Tag , ptr : NonNull < N > ) -> Self {
133
+ let value =
134
+ ( Inner :: from ( tag. get ( ) ) << Address :: BITS ) | Inner :: from ( ptr. as_ptr ( ) as Address ) ;
101
135
102
136
Self {
103
- inner : unsafe { NonZeroU64 :: new_unchecked ( value) } ,
137
+ // SAFETY: `value` is constructed from a `Tag` which is non-zero and half the
138
+ // size of the `InnerNonZero` type, and a `NonNull<N>` pointer.
139
+ inner : unsafe { InnerNonZero :: new_unchecked ( value) } ,
104
140
_marker : PhantomData ,
105
141
}
106
142
}
107
143
108
- fn from_u64 ( value : u64 ) -> Option < Self > {
109
- NonZeroU64 :: new ( value) . map ( |inner| Self {
110
- inner,
144
+ /// # Safety
145
+ ///
146
+ /// - `value` must come from a `Self::into_inner` call.
147
+ #[ inline]
148
+ unsafe fn from_inner ( value : Inner ) -> Option < Self > {
149
+ Some ( Self {
150
+ inner : InnerNonZero :: new ( value) ?,
111
151
_marker : PhantomData ,
112
152
} )
113
153
}
114
154
155
+ #[ inline]
115
156
fn non_null ( & self ) -> NonNull < N > {
116
- unsafe { NonNull :: new_unchecked ( self . inner . get ( ) as * mut N ) }
157
+ // SAFETY: `Self` can only be constructed using a `NonNull<N>`.
158
+ unsafe { NonNull :: new_unchecked ( self . as_ptr ( ) ) }
117
159
}
118
160
119
- fn tag ( & self ) -> NonZeroU32 {
120
- unsafe { NonZeroU32 :: new_unchecked ( ( self . inner . get ( ) >> 32 ) as u32 ) }
121
- }
122
-
123
- fn into_u64 ( self ) -> u64 {
161
+ #[ inline]
162
+ fn into_inner ( self ) -> Inner {
124
163
self . inner . get ( )
125
164
}
126
165
127
- fn increase_tag ( & mut self ) {
128
- let address = self . as_ptr ( ) as u32 ;
129
-
130
- let new_tag = self
131
- . tag ( )
132
- . get ( )
133
- . checked_add ( 1 )
134
- . map ( |val| unsafe { NonZeroU32 :: new_unchecked ( val) } )
135
- . unwrap_or_else ( initial_tag)
136
- . get ( ) ;
166
+ #[ inline]
167
+ fn tag ( & self ) -> Tag {
168
+ // SAFETY: `self.inner` was constructed from a non-zero `Tag`.
169
+ unsafe { Tag :: new_unchecked ( ( self . inner . get ( ) >> Address :: BITS ) as Address ) }
170
+ }
137
171
138
- let value = ( u64:: from ( new_tag) << 32 ) | u64:: from ( address) ;
172
+ fn increment_tag ( & mut self ) {
173
+ let new_tag = self . tag ( ) . checked_add ( 1 ) . unwrap_or_else ( initial_tag) ;
139
174
140
- self . inner = unsafe { NonZeroU64 :: new_unchecked ( value) } ;
175
+ // SAFETY: `self.non_null()` is a valid pointer.
176
+ * self = unsafe { Self :: new_unchecked ( new_tag, self . non_null ( ) ) } ;
141
177
}
142
178
}
143
179
144
- fn initial_tag ( ) -> NonZeroU32 {
145
- unsafe { NonZeroU32 :: new_unchecked ( 1 ) }
180
+ #[ inline]
181
+ const fn initial_tag ( ) -> Tag {
182
+ Tag :: MIN
146
183
}
147
184
148
185
pub unsafe fn push < N > ( stack : & Stack < N > , new_top : NonNullPtr < N > )
@@ -184,7 +221,40 @@ where
184
221
. compare_and_exchange_weak ( Some ( top) , next, Ordering :: Release , Ordering :: Relaxed )
185
222
. is_ok ( )
186
223
{
187
- top. increase_tag ( ) ;
224
+ // Prevent the ABA problem (https://en.wikipedia.org/wiki/Treiber_stack#Correctness).
225
+ //
226
+ // Without this, the following would be possible:
227
+ //
228
+ // | Thread 1 | Thread 2 | Stack |
229
+ // |-------------------------------|-------------------------|------------------------------|
230
+ // | push((1, 1)) | | (1, 1) |
231
+ // | push((1, 2)) | | (1, 2) -> (1, 1) |
232
+ // | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
233
+ // | | p = try_pop() // (1, 2) | (1, 1) |
234
+ // | | push((1, 3)) | (1, 3) -> (1, 1) |
235
+ // | | push(p) | (1, 2) -> (1, 3) -> (1, 1) |
236
+ // | try_pop()::cas(p, p.next) | | (1, 1) |
237
+ //
238
+ // As can be seen, the `cas` operation succeeds, wrongly removing pointer `3` from the stack.
239
+ //
240
+ // By incrementing the tag before returning the pointer, it cannot be pushed again with the,
241
+ // same tag, preventing the `try_pop()::cas(p, p.next)` operation from succeeding.
242
+ //
243
+ // With this fix, `try_pop()` in thread 2 returns `(2, 2)` and the comparison between
244
+ // `(1, 2)` and `(2, 2)` fails, restarting the loop and correctly removing the new top:
245
+ //
246
+ // | Thread 1 | Thread 2 | Stack |
247
+ // |-------------------------------|-------------------------|------------------------------|
248
+ // | push((1, 1)) | | (1, 1) |
249
+ // | push((1, 2)) | | (1, 2) -> (1, 1) |
250
+ // | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
251
+ // | | p = try_pop() // (2, 2) | (1, 1) |
252
+ // | | push((1, 3)) | (1, 3) -> (1, 1) |
253
+ // | | push(p) | (2, 2) -> (1, 3) -> (1, 1) |
254
+ // | try_pop()::cas(p, p.next) | | (2, 2) -> (1, 3) -> (1, 1) |
255
+ // | p = try_pop()::load // (2, 2) | | (2, 2) -> (1, 3) -> (1, 1) |
256
+ // | try_pop()::cas(p, p.next) | | (1, 3) -> (1, 1) |
257
+ top. increment_tag ( ) ;
188
258
189
259
return Some ( top) ;
190
260
}
0 commit comments