diff --git a/Libraries/Connect/Internal/Streaming/ClientOnlyAsyncStream.swift b/Libraries/Connect/Internal/Streaming/ClientOnlyAsyncStream.swift index 28311d1..07e18da 100644 --- a/Libraries/Connect/Internal/Streaming/ClientOnlyAsyncStream.swift +++ b/Libraries/Connect/Internal/Streaming/ClientOnlyAsyncStream.swift @@ -24,18 +24,21 @@ import Foundation final class ClientOnlyAsyncStream< Input: ProtobufMessage, Output: ProtobufMessage >: BidirectionalAsyncStream { - private let receivedMessageCount = Locked(0) + private let receivedResults = Locked([StreamResult]()) override func handleResultFromServer(_ result: StreamResult) { - let receivedMessageCount = self.receivedMessageCount.perform { value in - if case .message = result { - value += 1 + let (isComplete, results) = self.receivedResults.perform { results in + results.append(result) + if case .complete = result { + return (true, ClientOnlyStreamValidation.validatedFinalClientStreamResults(results)) + } else { + return (false, []) } - return value } - super.handleResultFromServer( - result.validatedForClientStream(receivedMessageCount: receivedMessageCount) - ) + guard isComplete else { + return + } + results.forEach(super.handleResultFromServer) } } diff --git a/Libraries/Connect/Internal/Streaming/ClientOnlyStream.swift b/Libraries/Connect/Internal/Streaming/ClientOnlyStream.swift index 3a14bce..85dec6a 100644 --- a/Libraries/Connect/Internal/Streaming/ClientOnlyStream.swift +++ b/Libraries/Connect/Internal/Streaming/ClientOnlyStream.swift @@ -20,7 +20,7 @@ import SwiftProtobuf /// supporting both callbacks and async/await. This is internal to the package, and not public. final class ClientOnlyStream: @unchecked Sendable { private let onResult: @Sendable (StreamResult) -> Void - private let receivedMessageCount = Locked(0) + private let receivedResults = Locked([StreamResult]()) /// Callbacks used to send outbound data and close the stream. /// Optional because these callbacks are not available until the stream is initialized. private var requestCallbacks: RequestCallbacks? @@ -49,13 +49,18 @@ final class ClientOnlyStream: @ /// /// - parameter result: The new result that was received. func handleResultFromServer(_ result: StreamResult) { - let receivedMessageCount = self.receivedMessageCount.perform { value in - if case .message = result { - value += 1 + let (isComplete, results) = self.receivedResults.perform { results in + results.append(result) + if case .complete = result { + return (true, ClientOnlyStreamValidation.validatedFinalClientStreamResults(results)) + } else { + return (false, []) } - return value } - self.onResult(result.validatedForClientStream(receivedMessageCount: receivedMessageCount)) + guard isComplete else { + return + } + results.forEach(self.onResult) } } diff --git a/Libraries/Connect/Internal/Streaming/ClientOnlyStreamValidation.swift b/Libraries/Connect/Internal/Streaming/ClientOnlyStreamValidation.swift new file mode 100644 index 0000000..101ea08 --- /dev/null +++ b/Libraries/Connect/Internal/Streaming/ClientOnlyStreamValidation.swift @@ -0,0 +1,65 @@ +// Copyright 2022-2024 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// Namespace for performing client-only stream validation. +enum ClientOnlyStreamValidation { + /// Applies some validations which are only relevant for client-only streams. + /// + /// Should be called after all values have been received over a client stream. Since client + /// streams only expect 1 result, all values returned from the server should be buffered before + /// being validated here and returned to the caller. + /// + /// - parameter results: The buffered list of results to validate. + /// + /// - returns: The list of stream results which should be returned to the caller. + static func validatedFinalClientStreamResults( + _ results: [StreamResult] + ) -> [StreamResult] { + var messageCount = 0 + for result in results { + switch result { + case .headers: + continue + case .message: + messageCount += 1 + case .complete(let code, _, _): + if code != .ok { + return results + } + } + } + + if messageCount < 1 { + return [ + .complete( + code: .internalError, error: ConnectError( + code: .unimplemented, message: "unary stream has no messages" + ), trailers: nil + ), + ] + } else if messageCount > 1 { + return [ + .complete( + code: .internalError, error: ConnectError( + code: .unimplemented, message: "unary stream has multiple messages" + ), trailers: nil + ), + ] + } else { + return results + } + } +} diff --git a/Libraries/Connect/Internal/Streaming/StreamResult+ClientOnlyStream.swift b/Libraries/Connect/Internal/Streaming/StreamResult+ClientOnlyStream.swift deleted file mode 100644 index c0f7d2a..0000000 --- a/Libraries/Connect/Internal/Streaming/StreamResult+ClientOnlyStream.swift +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2022-2024 The Connect Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import Foundation - -extension StreamResult { - /// Applies some validations which are only relevant for client-only streams. - /// - /// - parameter receivedMessageCount: The number of response messages that have been received so - /// far, including the one being validated. - /// - /// - returns: The validated stream result, which may have been transformed into an error. - func validatedForClientStream(receivedMessageCount: Int) -> Self { - switch self { - case .headers: - return self - case .message: - if receivedMessageCount > 1 { - return .complete( - code: .internalError, - error: ConnectError( - code: .unimplemented, message: "unary stream has multiple messages" - ), - trailers: nil - ) - } else { - return self - } - case .complete(let code, _, _): - if code == .ok && receivedMessageCount < 1 { - return .complete( - code: .internalError, - error: ConnectError( - code: .unimplemented, message: "unary stream has no messages" - ), - trailers: nil - ) - } else { - return self - } - } - } -}