Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EventStreams: Customisable Terminating Byte Sequence #115

Merged
merged 26 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
4325329
EventStreams: add ability to customise the terminating byte sequence of
paulhdk Sep 17, 2024
0fa72a8
EventStreams: add tests for custom terminating byte sequences
paulhdk Sep 17, 2024
5777ae0
EventStreams: remove non-closure overload incl. its tests
paulhdk Sep 20, 2024
b329be3
EventStreams: rename "terminate" to "while"
paulhdk Sep 20, 2024
625b013
EventStreams: make closure non-optional
paulhdk Sep 20, 2024
a7d3791
Deprecated: mark `asDecodedServerSentEvents` as deprecated
paulhdk Sep 20, 2024
514bbd0
Deprecated: mark `asDecodedServerSentEventsWithJSONData` as deprecated
paulhdk Sep 20, 2024
de8aeca
Update doc comments in Sources/OpenAPIRuntime/EventStreams/ServerSent…
paulhdk Oct 1, 2024
0057abe
Deprecated: mark ServerSentEventsDeserializationSequence's
paulhdk Oct 1, 2024
a346c61
Update doc comment in Sources/OpenAPIRuntime/EventStreams/ServerSentE…
paulhdk Oct 1, 2024
54b1251
Update doc comment in Sources/OpenAPIRuntime/EventStreams/ServerSentE…
paulhdk Oct 1, 2024
feaba23
Update doc comment in Sources/OpenAPIRuntime/EventStreams/ServerSentE…
paulhdk Oct 1, 2024
d286aec
Remove redundant forced type cast in Sources/OpenAPIRuntime/EventStre…
paulhdk Oct 1, 2024
36df44b
EventStreams: formatting
paulhdk Oct 1, 2024
6de6ec2
EventStreams: add doc comment for `ServerSentEventsDeserializationSeq…
paulhdk Oct 1, 2024
4e7bdb5
EventStreams: check for terminating byte sequence after removing trai…
paulhdk Oct 1, 2024
ad8f055
fixup! Deprecated: mark ServerSentEventsDeserializationSequence's `in…
paulhdk Oct 1, 2024
1552be8
EventStreams: store `predicate` closure as an associated value in
paulhdk Oct 1, 2024
2c1b3cb
EventStreams: remove `eventCountOffset` in `Test_ServerSentEventsDeco…
paulhdk Oct 1, 2024
97bda70
Update doc comment in Sources/OpenAPIRuntime/EventStreams/ServerSentE…
paulhdk Oct 2, 2024
0d4345e
EventStream: remove `predicate` property in `ServerSentEventsDeserial…
paulhdk Oct 3, 2024
0b841db
Update doc comment in Sources/OpenAPIRuntime/EventStreams/ServerSentE…
paulhdk Oct 3, 2024
30c9b81
EventStreams: address CI warnings re: doc comments
paulhdk Oct 3, 2024
31462e6
EventStream: only assign state once in each pass in `next()`
paulhdk Oct 3, 2024
86d0966
Merge branch 'main' into customisable-terminating-byte-sequence
czechboy0 Oct 3, 2024
78e19da
Format
paulhdk Oct 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions Sources/OpenAPIRuntime/Deprecated/Deprecated.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,37 @@ extension Configuration {
)
}
}

extension AsyncSequence where Element == ArraySlice<UInt8>, Self: Sendable {
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
///
/// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`.
/// - Returns: A sequence that provides the events.
@available(*, deprecated, renamed: "asDecodedServerSentEvents(while:)") @_disfavoredOverload
public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence<
ServerSentEventsLineDeserializationSequence<Self>
> { asDecodedServerSentEvents(while: { _ in true }) }
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
///
/// Use this method if the event's `data` field is JSON.
/// - Parameters:
/// - dataType: The type to decode the JSON data into.
/// - decoder: The JSON decoder to use.
/// - Returns: A sequence that provides the events with the decoded JSON data.
@available(*, deprecated, renamed: "asDecodedServerSentEventsWithJSONData(of:decoder:while:)") @_disfavoredOverload
public func asDecodedServerSentEventsWithJSONData<JSONDataType: Decodable>(
of dataType: JSONDataType.Type = JSONDataType.self,
decoder: JSONDecoder = .init()
) -> AsyncThrowingMapSequence<
ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>>,
ServerSentEventWithJSONData<JSONDataType>
> { asDecodedServerSentEventsWithJSONData(of: dataType, decoder: decoder, while: { _ in true }) }
}

extension ServerSentEventsDeserializationSequence {
/// Creates a new sequence.
/// - Parameter upstream: The upstream sequence of arbitrary byte chunks.
@available(*, deprecated, renamed: "init(upstream:while:)") @_disfavoredOverload public init(upstream: Upstream) {
self.init(upstream: upstream, while: { _ in true })
}
}
71 changes: 50 additions & 21 deletions Sources/OpenAPIRuntime/EventStreams/ServerSentEventsDecoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,19 @@ where Upstream.Element == ArraySlice<UInt8> {
/// The upstream sequence.
private let upstream: Upstream

/// A closure that determines whether the given byte chunk should be forwarded to the consumer.
/// - Parameter: A byte chunk.
/// - Returns: `true` if the byte chunk should be forwarded, `false` if this byte chunk is the terminating sequence.
private let predicate: @Sendable (ArraySlice<UInt8>) -> Bool

/// Creates a new sequence.
/// - Parameter upstream: The upstream sequence of arbitrary byte chunks.
public init(upstream: Upstream) { self.upstream = upstream }
paulhdk marked this conversation as resolved.
Show resolved Hide resolved
/// - Parameters:
/// - upstream: The upstream sequence of arbitrary byte chunks.
/// - predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer.
public init(upstream: Upstream, while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool) {
self.upstream = upstream
self.predicate = predicate
}
}

extension ServerSentEventsDeserializationSequence: AsyncSequence {
Expand All @@ -46,7 +56,16 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence {
var upstream: UpstreamIterator

/// The state machine of the iterator.
var stateMachine: StateMachine = .init()
var stateMachine: StateMachine

/// Creates a new sequence.
/// - Parameters:
/// - upstream: The upstream sequence of arbitrary byte chunks.
/// - predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer.
init(upstream: UpstreamIterator, while predicate: @escaping ((ArraySlice<UInt8>) -> Bool)) {
self.upstream = upstream
self.stateMachine = .init(while: predicate)
}

/// Asynchronously advances to the next element and returns it, or ends the
/// sequence if there is no next element.
Expand All @@ -70,7 +89,7 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence {
/// Creates the asynchronous iterator that produces elements of this
/// asynchronous sequence.
public func makeAsyncIterator() -> Iterator<Upstream.AsyncIterator> {
Iterator(upstream: upstream.makeAsyncIterator())
Iterator(upstream: upstream.makeAsyncIterator(), while: predicate)
}
}

Expand All @@ -79,26 +98,30 @@ extension AsyncSequence where Element == ArraySlice<UInt8>, Self: Sendable {
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
///
/// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`.
/// - Parameter: A closure that determines whether the given byte chunk should be forwarded to the consumer.
/// - Returns: A sequence that provides the events.
public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence<
ServerSentEventsLineDeserializationSequence<Self>
> { .init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self)) }

public func asDecodedServerSentEvents(
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
) -> ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>> {
.init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self), while: predicate)
}
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
///
/// Use this method if the event's `data` field is JSON.
/// - Parameters:
/// - dataType: The type to decode the JSON data into.
/// - decoder: The JSON decoder to use.
/// - predicate: A closure that determines whether the given byte sequence is the terminating byte sequence defined by the API.
/// - Returns: A sequence that provides the events with the decoded JSON data.
public func asDecodedServerSentEventsWithJSONData<JSONDataType: Decodable>(
of dataType: JSONDataType.Type = JSONDataType.self,
decoder: JSONDecoder = .init()
decoder: JSONDecoder = .init(),
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
) -> AsyncThrowingMapSequence<
ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>>,
ServerSentEventWithJSONData<JSONDataType>
> {
asDecodedServerSentEvents()
asDecodedServerSentEvents(while: predicate)
.map { event in
ServerSentEventWithJSONData(
event: event.event,
Expand All @@ -118,10 +141,10 @@ extension ServerSentEventsDeserializationSequence.Iterator {
struct StateMachine {

/// The possible states of the state machine.
enum State: Hashable {
enum State {

/// Accumulating an event, which hasn't been emitted yet.
case accumulatingEvent(ServerSentEvent, buffer: [ArraySlice<UInt8>])
case accumulatingEvent(ServerSentEvent, buffer: [ArraySlice<UInt8>], predicate: (ArraySlice<UInt8>) -> Bool)

/// Finished, the terminal state.
case finished
Expand All @@ -134,7 +157,9 @@ extension ServerSentEventsDeserializationSequence.Iterator {
private(set) var state: State

/// Creates a new state machine.
init() { self.state = .accumulatingEvent(.init(), buffer: []) }
init(while predicate: @escaping (ArraySlice<UInt8>) -> Bool) {
self.state = .accumulatingEvent(.init(), buffer: [], predicate: predicate)
}

/// An action returned by the `next` method.
enum NextAction {
Expand All @@ -156,20 +181,24 @@ extension ServerSentEventsDeserializationSequence.Iterator {
/// - Returns: An action to perform.
mutating func next() -> NextAction {
switch state {
case .accumulatingEvent(var event, var buffer):
case .accumulatingEvent(var event, var buffer, let predicate):
guard let line = buffer.first else { return .needsMore }
state = .mutating
buffer.removeFirst()
if line.isEmpty {
// Dispatch the accumulated event.
state = .accumulatingEvent(.init(), buffer: buffer)
// If the last character of data is a newline, strip it.
if event.data?.hasSuffix("\n") ?? false { event.data?.removeLast() }
if let data = event.data, !predicate(ArraySlice(data.utf8)) {
state = .finished
return .returnNil
}
state = .accumulatingEvent(.init(), buffer: buffer, predicate: predicate)
return .emitEvent(event)
}
if line.first! == ASCII.colon {
// A comment, skip this line.
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
}
// Parse the field name and value.
Expand All @@ -193,7 +222,7 @@ extension ServerSentEventsDeserializationSequence.Iterator {
}
guard let value else {
// An unknown type of event, skip.
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
}
// Process the field.
Expand All @@ -214,11 +243,11 @@ extension ServerSentEventsDeserializationSequence.Iterator {
}
default:
// An unknown or invalid field, skip.
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
}
// Processed the field, continue.
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
case .finished: return .returnNil
case .mutating: preconditionFailure("Invalid state")
Expand All @@ -240,11 +269,11 @@ extension ServerSentEventsDeserializationSequence.Iterator {
/// - Returns: An action to perform.
mutating func receivedValue(_ value: ArraySlice<UInt8>?) -> ReceivedValueAction {
switch state {
case .accumulatingEvent(let event, var buffer):
case .accumulatingEvent(let event, var buffer, let predicate):
if let value {
state = .mutating
buffer.append(value)
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
} else {
// If no value is received, drop the existing event on the floor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@ import XCTest
import Foundation

final class Test_ServerSentEventsDecoding: Test_Runtime {
func _test(input: String, output: [ServerSentEvent], file: StaticString = #filePath, line: UInt = #line)
async throws
{
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)).asDecodedServerSentEvents()
func _test(
input: String,
output: [ServerSentEvent],
file: StaticString = #filePath,
line: UInt = #line,
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
) async throws {
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)).asDecodedServerSentEvents(while: predicate)
let events = try await [ServerSentEvent](collecting: sequence)
XCTAssertEqual(events.count, output.count, file: file, line: line)
for (index, linePair) in zip(events, output).enumerated() {
let (actualEvent, expectedEvent) = linePair
XCTAssertEqual(actualEvent, expectedEvent, "Event: \(index)", file: file, line: line)
}
}

func test() async throws {
// Simple event.
try await _test(
Expand Down Expand Up @@ -83,22 +88,40 @@ final class Test_ServerSentEventsDecoding: Test_Runtime {
.init(id: "123", data: "This is a message with an ID."),
]
)

try await _test(
input: #"""
data: hello
data: world

data: [DONE]

data: hello2
data: world2


"""#,
output: [.init(data: "hello\nworld")],
while: { incomingData in incomingData != ArraySlice<UInt8>(Data("[DONE]".utf8)) }
)
}
func _testJSONData<JSONType: Decodable & Hashable & Sendable>(
input: String,
output: [ServerSentEventWithJSONData<JSONType>],
file: StaticString = #filePath,
line: UInt = #line
line: UInt = #line,
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
) async throws {
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8))
.asDecodedServerSentEventsWithJSONData(of: JSONType.self)
.asDecodedServerSentEventsWithJSONData(of: JSONType.self, while: predicate)
let events = try await [ServerSentEventWithJSONData<JSONType>](collecting: sequence)
XCTAssertEqual(events.count, output.count, file: file, line: line)
for (index, linePair) in zip(events, output).enumerated() {
let (actualEvent, expectedEvent) = linePair
XCTAssertEqual(actualEvent, expectedEvent, "Event: \(index)", file: file, line: line)
}
}

struct TestEvent: Decodable, Hashable, Sendable { var index: Int }
func testJSONData() async throws {
// Simple event.
Expand All @@ -121,6 +144,33 @@ final class Test_ServerSentEventsDecoding: Test_Runtime {
.init(event: "event2", data: TestEvent(index: 2), id: "2"),
]
)

try await _testJSONData(
input: #"""
event: event1
id: 1
data: {"index":1}

event: event2
id: 2
data: {
data: "index": 2
data: }

data: [DONE]

event: event3
id: 1
data: {"index":3}


"""#,
output: [
.init(event: "event1", data: TestEvent(index: 1), id: "1"),
.init(event: "event2", data: TestEvent(index: 2), id: "2"),
],
while: { incomingData in incomingData != ArraySlice<UInt8>(Data("[DONE]".utf8)) }
)
}
}

Expand Down