Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request body iteration specialisation #611

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 45 additions & 51 deletions Sources/HummingbirdCore/Request/RequestBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public struct RequestBody: Sendable, AsyncSequence {
@usableFromInline
internal enum _Backing: Sendable {
case byteBuffer(ByteBuffer)
case stream(AnyAsyncSequence<ByteBuffer>)
case nioAsyncChannelRequestBody(NIOAsyncChannelRequestBody)
case anyAsyncSequence(AnyAsyncSequence<ByteBuffer>)
}

@usableFromInline
Expand All @@ -37,15 +38,23 @@ public struct RequestBody: Sendable, AsyncSequence {

/// Initialise ``RequestBody`` from ByteBuffer
/// - Parameter buffer: ByteBuffer
@inlinable
public init(buffer: ByteBuffer) {
self.init(.byteBuffer(buffer))
}

/// Initialise ``RequestBody`` from AsyncSequence of ByteBuffers
/// - Parameter asyncSequence: AsyncSequence
@inlinable
package init(nioAsyncChannelInbound: NIOAsyncChannelRequestBody) {
self.init(.nioAsyncChannelRequestBody(nioAsyncChannelInbound))
}

/// Initialise ``RequestBody`` from AsyncSequence of ByteBuffers
/// - Parameter asyncSequence: AsyncSequence
@inlinable
public init<AS: AsyncSequence & Sendable>(asyncSequence: AS) where AS.Element == ByteBuffer {
self.init(.stream(.init(asyncSequence)))
self.init(.anyAsyncSequence(.init(asyncSequence)))
}
}

Expand All @@ -55,26 +64,51 @@ extension RequestBody {

public struct AsyncIterator: AsyncIteratorProtocol {
@usableFromInline
var iterator: AnyAsyncSequence<ByteBuffer>.AsyncIterator
internal enum _Backing {
case byteBuffer(ByteBuffer)
case nioAsyncChannelRequestBody(NIOAsyncChannelRequestBody.AsyncIterator)
case anyAsyncSequence(AnyAsyncSequence<ByteBuffer>.AsyncIterator)
case done
}

@usableFromInline
init(_ iterator: AnyAsyncSequence<ByteBuffer>.AsyncIterator) {
self.iterator = iterator
var _backing: _Backing

@usableFromInline
init(_ backing: _Backing) {
self._backing = backing
}

@inlinable
public mutating func next() async throws -> ByteBuffer? {
try await self.iterator.next()
switch self._backing {
case .byteBuffer(let buffer):
self._backing = .done
return buffer

case .nioAsyncChannelRequestBody(var iterator):
let next = try await iterator.next()
self._backing = .nioAsyncChannelRequestBody(iterator)
return next

case .anyAsyncSequence(let iterator):
return try await iterator.next()

case .done:
return nil
}
}
}

@inlinable
public func makeAsyncIterator() -> AsyncIterator {
switch self._backing {
case .byteBuffer(let buffer):
return .init(AnyAsyncSequence<ByteBuffer>(ByteBufferRequestBody(byteBuffer: buffer)).makeAsyncIterator())
case .stream(let stream):
return .init(stream.makeAsyncIterator())
return .init(.byteBuffer(buffer))
case .nioAsyncChannelRequestBody(let requestBody):
return .init(.nioAsyncChannelRequestBody(requestBody.makeAsyncIterator()))
case .anyAsyncSequence(let stream):
return .init(.anyAsyncSequence(stream.makeAsyncIterator()))
}
}
}
Expand Down Expand Up @@ -195,7 +229,8 @@ extension RequestBody {
/// Request body that is a stream of ByteBuffers sourced from a NIOAsyncChannelInboundStream.
///
/// This is a unicast async sequence that allows a single iterator to be created.
public final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence {
@usableFromInline
package struct NIOAsyncChannelRequestBody: Sendable, AsyncSequence {
public typealias Element = ByteBuffer
public typealias InboundStream = NIOAsyncChannelInboundStream<HTTPRequestPart>

Expand Down Expand Up @@ -256,44 +291,3 @@ public final class NIOAsyncChannelRequestBody: Sendable, AsyncSequence {
return AsyncIterator(underlyingIterator: self.underlyingIterator.wrappedValue, done: done)
}
}

/// Request body stream that is a single ByteBuffer
///
/// This is used when converting a ByteBuffer back to a stream of ByteBuffers
@usableFromInline
struct ByteBufferRequestBody: Sendable, AsyncSequence {
@usableFromInline
typealias Element = ByteBuffer

@usableFromInline
init(byteBuffer: ByteBuffer) {
self.byteBuffer = byteBuffer
}

@usableFromInline
struct AsyncIterator: AsyncIteratorProtocol {
@usableFromInline
var byteBuffer: ByteBuffer
@usableFromInline
var iterated: Bool

init(byteBuffer: ByteBuffer) {
self.byteBuffer = byteBuffer
self.iterated = false
}

@inlinable
mutating func next() async throws -> ByteBuffer? {
guard self.iterated == false else { return nil }
self.iterated = true
return self.byteBuffer
}
}

@usableFromInline
func makeAsyncIterator() -> AsyncIterator {
.init(byteBuffer: self.byteBuffer)
}

let byteBuffer: ByteBuffer
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ extension HTTPChannelHandler {

while true {
let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(asyncSequence: bodyStream))
let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream))
let responseWriter = ResponseWriter(outbound: outbound)
do {
try await self.responder(request, responseWriter, asyncChannel.channel)
Expand Down
2 changes: 1 addition & 1 deletion Sources/HummingbirdHTTP2/HTTP2StreamChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct HTTP2StreamChannel: ServerChildChannel {
throw HTTPChannelError.unexpectedHTTPPart(part)
}
let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(asyncSequence: bodyStream))
let request = Request(head: head, body: .init(nioAsyncChannelInbound: bodyStream))
let responseWriter = ResponseWriter(outbound: outbound)
try await self.responder(request, responseWriter, asyncChannel.channel)
}
Expand Down
Loading