Skip to content

Commit af9ce91

Browse files
authored
Simplify resumptions to avoid needing a resumption type and mark continuations as resumed directly in critical regions (apple#203)
1 parent f05e450 commit af9ce91

File tree

3 files changed

+61
-117
lines changed

3 files changed

+61
-117
lines changed

Sources/AsyncAlgorithms/AsyncChannel.swift

+30-40
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
111111
}
112112

113113
func cancelNext(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
114-
state.withCriticalRegion { state -> UnsafeContinuation<Element?, Never>? in
114+
state.withCriticalRegion { state in
115115
let continuation: UnsafeContinuation<Element?, Never>?
116116

117117
switch state.emission {
@@ -132,42 +132,38 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
132132
}
133133
}
134134

135-
return continuation
136-
}?.resume(returning: nil)
135+
continuation?.resume(returning: nil)
136+
}
137137
}
138138

139139
func next(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) async -> Element? {
140140
return await withUnsafeContinuation { (continuation: UnsafeContinuation<Element?, Never>) in
141141
var cancelled = false
142142
var terminal = false
143-
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
143+
state.withCriticalRegion { state in
144144

145145
if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
146146
cancelled = true
147-
return nil
148147
}
149148

150149
switch state.emission {
151150
case .idle:
152151
state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)])
153-
return nil
154152
case .pending(var sends):
155153
let send = sends.removeFirst()
156154
if sends.count == 0 {
157155
state.emission = .idle
158156
} else {
159157
state.emission = .pending(sends)
160158
}
161-
return UnsafeResumption(continuation: send.continuation, success: continuation)
159+
send.continuation?.resume(returning: continuation)
162160
case .awaiting(var nexts):
163161
nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation))
164162
state.emission = .awaiting(nexts)
165-
return nil
166163
case .finished:
167164
terminal = true
168-
return nil
169165
}
170-
}?.resume()
166+
}
171167

172168
if cancelled || terminal {
173169
continuation.resume(returning: nil)
@@ -176,7 +172,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
176172
}
177173

178174
func cancelSend(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
179-
state.withCriticalRegion { state -> UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>? in
175+
state.withCriticalRegion { state in
180176
let continuation: UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>?
181177

182178
switch state.emission {
@@ -198,38 +194,37 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
198194
}
199195
}
200196

201-
return continuation
202-
}?.resume(returning: nil)
197+
continuation?.resume(returning: nil)
198+
}
203199
}
204200

205201
func send(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int, _ element: Element) async {
206-
let continuation = await withUnsafeContinuation { continuation in
207-
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
202+
let continuation: UnsafeContinuation<Element?, Never>? = await withUnsafeContinuation { continuation in
203+
state.withCriticalRegion { state in
208204

209205
if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
210-
return UnsafeResumption(continuation: continuation, success: nil)
206+
continuation.resume(returning: nil)
207+
return
211208
}
212209

213210
switch state.emission {
214211
case .idle:
215212
state.emission = .pending([Pending(generation: generation, continuation: continuation)])
216-
return nil
217213
case .pending(var sends):
218214
sends.updateOrAppend(Pending(generation: generation, continuation: continuation))
219215
state.emission = .pending(sends)
220-
return nil
221216
case .awaiting(var nexts):
222217
let next = nexts.removeFirst().continuation
223218
if nexts.count == 0 {
224219
state.emission = .idle
225220
} else {
226221
state.emission = .awaiting(nexts)
227222
}
228-
return UnsafeResumption(continuation: continuation, success: next)
223+
continuation.resume(returning: next)
229224
case .finished:
230-
return UnsafeResumption(continuation: continuation, success: nil)
225+
continuation.resume(returning: nil)
231226
}
232-
}?.resume()
227+
}
233228
}
234229
continuation?.resume(returning: element)
235230
}
@@ -252,30 +247,25 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
252247
/// Send a finish to all awaiting iterations.
253248
/// All subsequent calls to `next(_:)` will resume immediately.
254249
public func finish() {
255-
let (sends, nexts) = state.withCriticalRegion { state -> (OrderedSet<Pending>, OrderedSet<Awaiting>) in
256-
let result: (OrderedSet<Pending>, OrderedSet<Awaiting>)
250+
state.withCriticalRegion { state in
257251

252+
defer { state.emission = .finished }
253+
258254
switch state.emission {
259-
case .idle:
260-
result = ([], [])
261-
case .pending(let nexts):
262-
result = (nexts, [])
255+
case .pending(let sends):
256+
for send in sends {
257+
send.continuation?.resume(returning: nil)
258+
}
263259
case .awaiting(let nexts):
264-
result = ([], nexts)
265-
case .finished:
266-
result = ([], [])
260+
for next in nexts {
261+
next.continuation?.resume(returning: nil)
262+
}
263+
default:
264+
break
267265
}
268-
269-
state.emission = .finished
270-
271-
return result
272-
}
273-
for send in sends {
274-
send.continuation?.resume(returning: nil)
275-
}
276-
for next in nexts {
277-
next.continuation?.resume(returning: nil)
278266
}
267+
268+
279269
}
280270

281271
/// Create an `Iterator` for iteration of an `AsyncChannel`

Sources/AsyncAlgorithms/AsyncThrowingChannel.swift

+31-40
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
119119
}
120120

121121
func cancelNext(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
122-
state.withCriticalRegion { state -> UnsafeContinuation<Element?, Error>? in
122+
state.withCriticalRegion { state in
123123
let continuation: UnsafeContinuation<Element?, Error>?
124124

125125
switch state.emission {
@@ -140,44 +140,41 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
140140
}
141141
}
142142

143-
return continuation
144-
}?.resume(returning: nil)
143+
continuation?.resume(returning: nil)
144+
}
145145
}
146146

147147
func next(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) async throws -> Element? {
148148
return try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation<Element?, Error>) in
149149
var cancelled = false
150150
var potentialTermination: Termination?
151151

152-
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
152+
state.withCriticalRegion { state in
153153

154154
if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
155155
cancelled = true
156-
return nil
156+
return
157157
}
158158

159159
switch state.emission {
160160
case .idle:
161161
state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)])
162-
return nil
163162
case .pending(var sends):
164163
let send = sends.removeFirst()
165164
if sends.count == 0 {
166165
state.emission = .idle
167166
} else {
168167
state.emission = .pending(sends)
169168
}
170-
return UnsafeResumption(continuation: send.continuation, success: continuation)
169+
send.continuation?.resume(returning: continuation)
171170
case .awaiting(var nexts):
172171
nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation))
173172
state.emission = .awaiting(nexts)
174-
return nil
175173
case .terminated(let termination):
176174
potentialTermination = termination
177175
state.emission = .terminated(.finished)
178-
return nil
179176
}
180-
}?.resume()
177+
}
181178

182179
if cancelled {
183180
continuation.resume(returning: nil)
@@ -198,7 +195,7 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
198195
}
199196

200197
func cancelSend(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
201-
state.withCriticalRegion { state -> UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>? in
198+
state.withCriticalRegion { state in
202199
let continuation: UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>?
203200

204201
switch state.emission {
@@ -220,44 +217,43 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
220217
}
221218
}
222219

223-
return continuation
224-
}?.resume(returning: nil)
220+
continuation?.resume(returning: nil)
221+
}
225222
}
226223

227224
func send(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int, _ element: Element) async {
228225
let continuation: UnsafeContinuation<Element?, Error>? = await withUnsafeContinuation { continuation in
229-
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
226+
state.withCriticalRegion { state in
230227

231228
if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
232-
return UnsafeResumption(continuation: continuation, success: nil)
229+
continuation.resume(returning: nil)
230+
return
233231
}
234232

235233
switch state.emission {
236234
case .idle:
237235
state.emission = .pending([Pending(generation: generation, continuation: continuation)])
238-
return nil
239236
case .pending(var sends):
240237
sends.updateOrAppend(Pending(generation: generation, continuation: continuation))
241238
state.emission = .pending(sends)
242-
return nil
243239
case .awaiting(var nexts):
244240
let next = nexts.removeFirst().continuation
245241
if nexts.count == 0 {
246242
state.emission = .idle
247243
} else {
248244
state.emission = .awaiting(nexts)
249245
}
250-
return UnsafeResumption(continuation: continuation, success: next)
246+
continuation.resume(returning: next)
251247
case .terminated:
252-
return UnsafeResumption(continuation: continuation, success: nil)
248+
continuation.resume(returning: nil)
253249
}
254-
}?.resume()
250+
}
255251
}
256252
continuation?.resume(returning: element)
257253
}
258254

259255
func terminateAll(error: Failure? = nil) {
260-
let (sends, nexts) = state.withCriticalRegion { state -> (OrderedSet<Pending>, OrderedSet<Awaiting>) in
256+
state.withCriticalRegion { state in
261257

262258
let nextState: Emission
263259
if let error = error {
@@ -269,29 +265,24 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
269265
switch state.emission {
270266
case .idle:
271267
state.emission = nextState
272-
return ([], [])
273-
case .pending(let nexts):
268+
case .pending(let sends):
274269
state.emission = nextState
275-
return (nexts, [])
270+
for send in sends {
271+
send.continuation?.resume(returning: nil)
272+
}
276273
case .awaiting(let nexts):
277274
state.emission = nextState
278-
return ([], nexts)
275+
if let error = error {
276+
for next in nexts {
277+
next.continuation?.resume(throwing: error)
278+
}
279+
} else {
280+
for next in nexts {
281+
next.continuation?.resume(returning: nil)
282+
}
283+
}
279284
case .terminated:
280-
return ([], [])
281-
}
282-
}
283-
284-
for send in sends {
285-
send.continuation?.resume(returning: nil)
286-
}
287-
288-
if let error = error {
289-
for next in nexts {
290-
next.continuation?.resume(throwing: error)
291-
}
292-
} else {
293-
for next in nexts {
294-
next.continuation?.resume(returning: nil)
285+
break
295286
}
296287
}
297288
}

Sources/AsyncAlgorithms/UnsafeResumption.swift

-37
This file was deleted.

0 commit comments

Comments
 (0)