From 2678d892d615c7e6292ea278139ce4f97f3fb66b Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Wed, 18 Oct 2023 08:22:57 -0500 Subject: [PATCH] Improve API for bidi and server streaming calls Instead of requiring callers to handle oneOf(Headers,Message,Trailers) objects in each bidi or server streaming call, instead just change the response channel to return the response message type. If an error occurs at the end of the call (due to non-zero grpc-status), then cancel the channel with an exception. --- .../com/connectrpc/conformance/Conformance.kt | 126 +++++------------- .../examples/android/ElizaChatActivity.kt | 42 +++--- .../com/connectrpc/examples/kotlin/Main.kt | 25 +--- .../com/connectrpc/examples/kotlin/Main.kt | 23 +--- .../BidirectionalStreamInterface.kt | 6 +- .../connectrpc/ClientOnlyStreamInterface.kt | 4 +- .../connectrpc/ServerOnlyStreamInterface.kt | 3 +- .../com/connectrpc/UnaryBlockingCall.kt | 6 +- .../connectrpc/impl/BidirectionalStream.kt | 5 +- .../com/connectrpc/impl/ClientOnlyStream.kt | 59 +------- .../com/connectrpc/impl/ProtocolClient.kt | 29 ++-- .../com/connectrpc/impl/ServerOnlyStream.kt | 5 +- .../protocols/ConnectInterceptor.kt | 6 +- .../protocols/GRPCWebInterceptor.kt | 4 +- .../com/connectrpc/okhttp/OkHttpStream.kt | 8 +- 15 files changed, 103 insertions(+), 248 deletions(-) diff --git a/conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt b/conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt index 332cbcd6..de14af60 100644 --- a/conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt +++ b/conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt @@ -16,15 +16,13 @@ package com.connectrpc.conformance import com.connectrpc.Code import com.connectrpc.ConnectException -import com.connectrpc.Headers import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression -import com.connectrpc.StreamResult -import com.connectrpc.Trailers import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.conformance.ssl.sslContext import com.connectrpc.conformance.v1.ErrorDetail import com.connectrpc.conformance.v1.PayloadType +import com.connectrpc.conformance.v1.StreamingOutputCallResponse import com.connectrpc.conformance.v1.TestServiceClient import com.connectrpc.conformance.v1.UnimplementedServiceClient import com.connectrpc.conformance.v1.echoStatus @@ -43,7 +41,6 @@ import com.google.protobuf.ByteString import com.google.protobuf.empty import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async -import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext @@ -63,7 +60,6 @@ import java.time.Duration import java.util.Base64 import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicBoolean @RunWith(Parameterized::class) class Conformance( @@ -177,17 +173,18 @@ class Conformance( responseParameters += params }, ).getOrThrow() - val results = streamResults(stream.resultChannel()) - assertThat(results.cause).isNull() - assertThat(results.code).isEqualTo(Code.OK) - assertThat(results.messages.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE)) - assertThat(results.messages.map { it.payload.body.size() }).isEqualTo(sizes) + val responses = mutableListOf() + for (response in stream.responseChannel()) { + responses.add(response) + } + assertThat(responses.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE)) + assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes) } @Test fun pingPong(): Unit = runBlocking { val stream = testServiceConnectClient.fullDuplexCall() - var readHeaders = false + val responseChannel = stream.responseChannel() listOf(512_000, 16, 2_028, 65_536).forEach { val param = responseParameters { size = it } stream.send( @@ -196,25 +193,14 @@ class Conformance( responseParameters += param }, ).getOrThrow() - if (!readHeaders) { - val headersResult = stream.resultChannel().receive() - assertThat(headersResult).isInstanceOf(StreamResult.Headers::class.java) - readHeaders = true - } - val result = stream.resultChannel().receive() - assertThat(result).isInstanceOf(StreamResult.Message::class.java) - val messageResult = result as StreamResult.Message - val payload = messageResult.message.payload + val response = responseChannel.receive() + val payload = response.payload assertThat(payload.type).isEqualTo(PayloadType.COMPRESSABLE) assertThat(payload.body).hasSize(it) } stream.sendClose() - val results = streamResults(stream.resultChannel()) // We've already read all the messages - assertThat(results.messages).isEmpty() - assertThat(results.cause).isNull() - assertThat(results.code).isEqualTo(Code.OK) - stream.receiveClose() + assertThat(responseChannel.receiveCatching().isClosed).isTrue() } @Test @@ -244,15 +230,17 @@ class Conformance( val countDownLatch = CountDownLatch(1) withContext(Dispatchers.IO) { val job = async { + val responses = mutableListOf() try { - val result = streamResults(stream.resultChannel()) - assertThat(result.messages.map { it.payload.body.size() }).isEqualTo(sizes) - assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED) - assertThat(result.cause).isInstanceOf(ConnectException::class.java) - val connectException = result.cause as ConnectException - assertThat(connectException.code).isEqualTo(Code.RESOURCE_EXHAUSTED) - assertThat(connectException.message).isEqualTo("soirée 🎉") - assertThat(connectException.unpackedDetails(ErrorDetail::class)).containsExactly( + for (response in stream.responseChannel()) { + responses.add(response) + } + fail("expected call to fail with ConnectException") + } catch (e: ConnectException) { + assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes) + assertThat(e.code).isEqualTo(Code.RESOURCE_EXHAUSTED) + assertThat(e.message).isEqualTo("soirée 🎉") + assertThat(e.unpackedDetails(ErrorDetail::class)).containsExactly( expectedErrorDetail, ) } finally { @@ -363,10 +351,11 @@ class Conformance( withContext(Dispatchers.IO) { val job = launch { try { - val result = streamResults(stream.resultChannel()) - assertThat(result.cause).isInstanceOf(ConnectException::class.java) - assertThat(result.code) - .withFailMessage { "Expected Code.DEADLINE_EXCEEDED but got ${result.code}" } + stream.responseChannel().receive() + fail("unexpected ConnectException to be thrown") + } catch (e: ConnectException) { + assertThat(e.code) + .withFailMessage { "Expected Code.DEADLINE_EXCEEDED but got ${e.code}" } .isEqualTo(Code.DEADLINE_EXCEEDED) } finally { countDownLatch.countDown() @@ -437,11 +426,10 @@ class Conformance( withContext(Dispatchers.IO) { val job = async { try { - val result = streamResults(stream.resultChannel()) - assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED) - assertThat(result.cause).isInstanceOf(ConnectException::class.java) - val exception = result.cause as ConnectException - assertThat(exception.code).isEqualTo(Code.UNIMPLEMENTED) + stream.responseChannel().receive() + fail("expected call to fail with a ConnectException") + } catch (e: ConnectException) { + assertThat(e.code).isEqualTo(Code.UNIMPLEMENTED) } finally { countDownLatch.countDown() } @@ -801,8 +789,8 @@ class Conformance( withContext(Dispatchers.IO) { val job = async { try { - val result = stream.receiveAndClose().getOrThrow() - assertThat(result.aggregatedPayloadSize).isEqualTo(sum) + val response = stream.receiveAndClose() + assertThat(response.aggregatedPayloadSize).isEqualTo(sum) } finally { countDownLatch.countDown() } @@ -813,56 +801,6 @@ class Conformance( } } - private data class ServerStreamingResult( - val headers: Headers, - val messages: List, - val code: Code, - val trailers: Trailers, - val cause: Throwable?, - ) - - /* - * Convenience method to return all results (with sanity checking) for calls which stream results from the server - * (bidi and server streaming). - * - * This allows us to easily verify headers, messages, trailers, and errors without having to use fold/maybeFold - * manually in each location. - */ - private suspend fun streamResults(channel: ReceiveChannel>): ServerStreamingResult { - val seenHeaders = AtomicBoolean(false) - var headers: Headers = emptyMap() - val messages: MutableList = mutableListOf() - val seenCompletion = AtomicBoolean(false) - var code: Code = Code.UNKNOWN - var trailers: Headers = emptyMap() - var error: Throwable? = null - for (response in channel) { - response.maybeFold( - onHeaders = { - if (!seenHeaders.compareAndSet(false, true)) { - throw IllegalStateException("multiple onHeaders callbacks") - } - headers = it.headers - }, - onMessage = { - messages.add(it.message) - }, - onCompletion = { - if (!seenCompletion.compareAndSet(false, true)) { - throw IllegalStateException("multiple onCompletion callbacks") - } - code = it.code - trailers = it.trailers - error = it.cause - }, - ) - } - if (!seenCompletion.get()) { - throw IllegalStateException("didn't get completion message") - } - return ServerStreamingResult(headers, messages, code, trailers, error) - } - private fun b64Encode(trailingValue: ByteArray): String { return String(Base64.getEncoder().encode(trailingValue)) } diff --git a/examples/android/src/main/kotlin/com/connectrpc/examples/android/ElizaChatActivity.kt b/examples/android/src/main/kotlin/com/connectrpc/examples/android/ElizaChatActivity.kt index e205a290..63a87923 100644 --- a/examples/android/src/main/kotlin/com/connectrpc/examples/android/ElizaChatActivity.kt +++ b/examples/android/src/main/kotlin/com/connectrpc/examples/android/ElizaChatActivity.kt @@ -21,6 +21,7 @@ import android.widget.TextView import androidx.appcompat.app.AppCompatActivity import androidx.lifecycle.lifecycleScope import androidx.recyclerview.widget.RecyclerView +import com.connectrpc.ConnectException import com.connectrpc.ProtocolClientConfig import com.connectrpc.eliza.v1.ConverseRequest import com.connectrpc.eliza.v1.ElizaServiceClient @@ -135,29 +136,26 @@ class ElizaChatActivity : AppCompatActivity() { lifecycleScope.launch(Dispatchers.IO) { // Initialize a bidi stream with Eliza. val stream = elizaServiceClient.converse() - - for (streamResult in stream.resultChannel()) { - streamResult.maybeFold( - onMessage = { result -> - // A stream message is received: Eliza has said something to us. - val elizaResponse = result.message.sentence - if (elizaResponse?.isNotBlank() == true) { - adapter.add(MessageData(elizaResponse, true)) - } else { - // Something odd occurred. - adapter.add(MessageData("...No response from Eliza...", true)) - } - }, - onCompletion = { - // This should only be called once. - adapter.add( - MessageData( - "Session has ended.", - true, - ), - ) - }, + try { + for (message in stream.responseChannel()) { + // A stream message is received: Eliza has said something to us. + val elizaResponse = message.sentence + if (elizaResponse?.isNotBlank() == true) { + adapter.add(MessageData(elizaResponse, true)) + } else { + // Something odd occurred. + adapter.add(MessageData("...No response from Eliza...", true)) + } + } + // This should only be called once. + adapter.add( + MessageData( + "Session has ended.", + true, + ), ) + } catch (e: ConnectException) { + adapter.add(MessageData("Session failed with code ${e.code}", true)) } lifecycleScope.launch(Dispatchers.Main) { buttonView.setOnClickListener { diff --git a/examples/kotlin-google-java/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt b/examples/kotlin-google-java/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt index 26a994c0..a3986d67 100644 --- a/examples/kotlin-google-java/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt +++ b/examples/kotlin-google-java/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt @@ -14,7 +14,6 @@ package com.connectrpc.examples.kotlin -import com.connectrpc.Code import com.connectrpc.ConnectException import com.connectrpc.ProtocolClientConfig import com.connectrpc.eliza.v1.ElizaServiceClient @@ -63,23 +62,13 @@ class Main { // Add the message the user is sending to the views. stream.send(converseRequest { sentence = "hello" }) stream.sendClose() - for (streamResult in stream.resultChannel()) { - streamResult.maybeFold( - onMessage = { result -> - // Update the view with the response. - val elizaResponse = result.message - println(elizaResponse.sentence) - }, - onCompletion = { result -> - if (result.code != Code.OK) { - val exception = result.connectException() - if (exception != null) { - throw exception - } - throw ConnectException(code = result.code, metadata = result.trailers) - } - }, - ) + try { + for (streamResult in stream.responseChannel()) { + // Update the view with the response. + val elizaResponse = streamResult + println(elizaResponse.sentence) + } + } catch (e: ConnectException) { } } } diff --git a/examples/kotlin-google-javalite/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt b/examples/kotlin-google-javalite/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt index bd3b187b..1edb26be 100644 --- a/examples/kotlin-google-javalite/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt +++ b/examples/kotlin-google-javalite/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt @@ -14,8 +14,6 @@ package com.connectrpc.examples.kotlin -import com.connectrpc.Code -import com.connectrpc.ConnectException import com.connectrpc.ProtocolClientConfig import com.connectrpc.eliza.v1.ElizaServiceClient import com.connectrpc.eliza.v1.converseRequest @@ -34,7 +32,7 @@ class Main { @JvmStatic fun main(args: Array) { runBlocking { - val host = "https://demo.connectrpc.com" + val host = "https://demo.connectrpc.com:444" val okHttpClient = OkHttpClient() .newBuilder() .readTimeout(Duration.ofMinutes(10)) @@ -63,23 +61,8 @@ class Main { // Add the message the user is sending to the views. stream.send(converseRequest { sentence = "hello" }) stream.sendClose() - for (streamResult in stream.resultChannel()) { - streamResult.maybeFold( - onMessage = { result -> - // Update the view with the response. - val elizaResponse = result.message - println(elizaResponse.sentence) - }, - onCompletion = { result -> - if (result.code != Code.OK) { - val exception = result.connectException() - if (exception != null) { - throw exception - } - throw ConnectException(code = result.code, metadata = result.trailers) - } - }, - ) + for (response in stream.responseChannel()) { + println(response.sentence) } } } diff --git a/library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt b/library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt index 8b0a157a..803c4ecf 100644 --- a/library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt @@ -21,11 +21,11 @@ import kotlinx.coroutines.channels.ReceiveChannel */ interface BidirectionalStreamInterface { /** - * The Channel for received StreamResults. + * The Channel for responses. * - * @return ReceiveChannel for iterating over the received results. + * @return ReceiveChannel for iterating over the responses. */ - fun resultChannel(): ReceiveChannel> + fun responseChannel(): ReceiveChannel /** * Send a request to the server over the stream. diff --git a/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt b/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt index 4c85717d..1323f77a 100644 --- a/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt @@ -29,9 +29,9 @@ interface ClientOnlyStreamInterface { /** * Receive a single response and close the stream. * - * @return the single response [ResponseMessage]. + * @return the single response [Output]. */ - suspend fun receiveAndClose(): ResponseMessage + suspend fun receiveAndClose(): Output /** * Close the stream. No calls to [send] are valid after calling [sendClose]. diff --git a/library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt b/library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt index 9fbf586b..6783ee9a 100644 --- a/library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt @@ -15,6 +15,7 @@ package com.connectrpc import kotlinx.coroutines.channels.ReceiveChannel + /** * Represents a server-only stream (a stream where the server streams data to the client after * receiving an initial request) that can send request messages. @@ -25,7 +26,7 @@ interface ServerOnlyStreamInterface { * * @return ReceiveChannel for iterating over the received results. */ - fun resultChannel(): ReceiveChannel> + fun responseChannel(): ReceiveChannel /** * Send a request to the server over the stream and closes the request. diff --git a/library/src/main/kotlin/com/connectrpc/UnaryBlockingCall.kt b/library/src/main/kotlin/com/connectrpc/UnaryBlockingCall.kt index ea28a961..b27c55ce 100644 --- a/library/src/main/kotlin/com/connectrpc/UnaryBlockingCall.kt +++ b/library/src/main/kotlin/com/connectrpc/UnaryBlockingCall.kt @@ -22,7 +22,7 @@ import java.util.concurrent.atomic.AtomicReference */ class UnaryBlockingCall { private var executable: ((ResponseMessage) -> Unit) -> Unit = { } - private var cancel: () -> Unit = { } + private var cancelFn: () -> Unit = { } /** * Execute the underlying request. @@ -43,7 +43,7 @@ class UnaryBlockingCall { * Cancel the underlying request. */ fun cancel() { - cancel() + cancelFn() } /** @@ -54,7 +54,7 @@ class UnaryBlockingCall { * underlying request. */ internal fun setCancel(cancel: () -> Unit) { - this.cancel = cancel + this.cancelFn = cancel } /** diff --git a/library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt b/library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt index fb0133b7..5048b09b 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt @@ -16,7 +16,6 @@ package com.connectrpc.impl import com.connectrpc.BidirectionalStreamInterface import com.connectrpc.Codec -import com.connectrpc.StreamResult import com.connectrpc.http.Stream import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ReceiveChannel @@ -28,7 +27,7 @@ import java.lang.Exception internal class BidirectionalStream( val stream: Stream, private val requestCodec: Codec, - private val receiveChannel: Channel>, + private val receiveChannel: Channel, ) : BidirectionalStreamInterface { override suspend fun send(input: Input): Result { @@ -40,7 +39,7 @@ internal class BidirectionalStream( return stream.send(msg) } - override fun resultChannel(): ReceiveChannel> { + override fun responseChannel(): ReceiveChannel { return receiveChannel } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt b/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt index e449c7f0..6ab69ae4 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt @@ -18,8 +18,6 @@ import com.connectrpc.BidirectionalStreamInterface import com.connectrpc.ClientOnlyStreamInterface import com.connectrpc.Code import com.connectrpc.ConnectException -import com.connectrpc.Headers -import com.connectrpc.ResponseMessage /** * Concrete implementation of [ClientOnlyStreamInterface]. @@ -31,59 +29,16 @@ internal class ClientOnlyStream( return messageStream.send(input) } - override suspend fun receiveAndClose(): ResponseMessage { - val resultChannel = messageStream.resultChannel() + override suspend fun receiveAndClose(): Output { + val resultChannel = messageStream.responseChannel() try { messageStream.sendClose() - // TODO: Improve this API for consumers. - // We should aim to provide ease of use for callers so they don't need to individually examine each result - // in the channel (headers, 1* messages, completion) and have to resort to fold()/maybeFold() to interpret - // the overall results. - // Additionally, ResponseMessage.Success and ResponseMessage.Failure shouldn't be necessary for client use. - // We should throw ConnectException for failure and only have users have to deal with success messages. - var headers: Headers = emptyMap() - var message: Output? = null - var trailers: Headers = emptyMap() - var code: Code? = null - var error: ConnectException? = null - for (result in resultChannel) { - result.maybeFold( - onHeaders = { - headers = it.headers - }, - onMessage = { - message = it.message - }, - onCompletion = { - code = it.code - trailers = it.trailers - val resultErr = it.cause - if (resultErr != null) { - error = if (resultErr is ConnectException) { - resultErr - } else { - ConnectException(code ?: Code.UNKNOWN, message = error?.message, exception = error, metadata = trailers) - } - } - }, - ) + val message = resultChannel.receive() + val additionalMessage = resultChannel.tryReceive() + if (additionalMessage.isSuccess) { + throw ConnectException(code = Code.UNKNOWN, message = "unary stream has multiple messages") } - if (error != null) { - return ResponseMessage.Failure(error!!, code ?: Code.UNKNOWN, headers, trailers) - } - if (code == null) { - return ResponseMessage.Failure(ConnectException(Code.UNKNOWN, message = "unknown status code"), Code.UNKNOWN, headers, trailers) - } - if (message != null) { - return ResponseMessage.Success(message!!, code!!, headers, trailers) - } - // We didn't receive an error at any point, however we didn't get a response message either. - return ResponseMessage.Failure( - ConnectException(Code.UNKNOWN, message = "missing response message"), - code!!, - headers, - trailers, - ) + return message } finally { resultChannel.cancel() } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 56548a3b..976097d7 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -17,6 +17,7 @@ package com.connectrpc.impl import com.connectrpc.BidirectionalStreamInterface import com.connectrpc.ClientOnlyStreamInterface import com.connectrpc.Code +import com.connectrpc.ConnectException import com.connectrpc.Headers import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig @@ -149,7 +150,7 @@ class ProtocolClient( headers: Headers, methodSpec: MethodSpec, ): ServerOnlyStreamInterface { - val stream = stream(headers, methodSpec) + val stream = bidirectionalStream(methodSpec, headers) return ServerOnlyStream(stream) } @@ -164,8 +165,8 @@ class ProtocolClient( private suspend fun bidirectionalStream( methodSpec: MethodSpec, headers: Headers, - ): BidirectionalStreamInterface = suspendCancellableCoroutine { continuation -> - val channel = Channel>(1) + ): BidirectionalStream = suspendCancellableCoroutine { continuation -> + val channel = Channel(1) val requestCodec = config.serializationStrategy.codec(methodSpec.requestClass) val responseCodec = config.serializationStrategy.codec(methodSpec.responseClass) val request = HTTPRequest( @@ -183,10 +184,9 @@ class ProtocolClient( return@stream } // Pass through the interceptor chain. - val streamResult = streamFunc.streamResultFunction(initialResult) - val result: StreamResult = when (streamResult) { + when (val streamResult = streamFunc.streamResultFunction(initialResult)) { is StreamResult.Headers -> { - StreamResult.Headers(streamResult.headers) + // Not currently used except for interceptors. } is StreamResult.Message -> { @@ -194,26 +194,21 @@ class ProtocolClient( val message = responseCodec.deserialize( streamResult.message, ) - StreamResult.Message(message) + channel.send(message) } catch (e: Throwable) { + channel.close(ConnectException(Code.UNKNOWN, exception = e)) isComplete = true - StreamResult.Complete(Code.UNKNOWN, e) } } is StreamResult.Complete -> { + when (streamResult.code) { + Code.OK -> channel.close() + else -> channel.close(streamResult.connectException() ?: ConnectException(code = streamResult.code)) + } isComplete = true - StreamResult.Complete( - streamResult.connectException()?.code ?: Code.OK, - cause = streamResult.cause, - trailers = streamResult.trailers, - ) } } - channel.send(result) - if (isComplete) { - channel.close() - } } continuation.invokeOnCancellation { httpStream.sendClose() diff --git a/library/src/main/kotlin/com/connectrpc/impl/ServerOnlyStream.kt b/library/src/main/kotlin/com/connectrpc/impl/ServerOnlyStream.kt index c33a9607..b2d01bbd 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ServerOnlyStream.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ServerOnlyStream.kt @@ -16,7 +16,6 @@ package com.connectrpc.impl import com.connectrpc.BidirectionalStreamInterface import com.connectrpc.ServerOnlyStreamInterface -import com.connectrpc.StreamResult import kotlinx.coroutines.channels.ReceiveChannel /** @@ -25,8 +24,8 @@ import kotlinx.coroutines.channels.ReceiveChannel internal class ServerOnlyStream( private val messageStream: BidirectionalStreamInterface, ) : ServerOnlyStreamInterface { - override fun resultChannel(): ReceiveChannel> { - return messageStream.resultChannel() + override fun responseChannel(): ReceiveChannel { + return messageStream.responseChannel() } override suspend fun sendAndClose(input: Input): Result { diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index 215986bf..30299d2c 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -167,11 +167,7 @@ internal class ConnectInterceptor( StreamResult.Message(unpackedMessage) } }, - onCompletion = { result -> - val streamTrailers = result.trailers - val error = result.connectException() - StreamResult.Complete(error?.code ?: Code.OK, cause = error, streamTrailers) - }, + onCompletion = { result -> result }, ) streamResult }, diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt index 119cade7..a6999333 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt @@ -214,9 +214,7 @@ internal class GRPCWebInterceptor( } StreamResult.Message(unpackedMessage) }, - onCompletion = { result -> - result - }, + onCompletion = { result -> result }, ) streamResult }, diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt index 1a06abe1..a2dee5da 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt @@ -36,6 +36,7 @@ import okio.Pipe import okio.buffer import java.io.IOException import java.io.InterruptedIOException +import java.net.SocketTimeoutException import java.util.concurrent.atomic.AtomicBoolean /** @@ -87,11 +88,14 @@ private class ResponseCallback( runBlocking { if (e is InterruptedIOException) { if (e.message == "timeout") { - val error = ConnectException(code = Code.DEADLINE_EXCEEDED) - onResult(StreamResult.Complete(Code.DEADLINE_EXCEEDED, cause = error)) + onResult(StreamResult.Complete(Code.DEADLINE_EXCEEDED, cause = e)) return@runBlocking } } + if (e is SocketTimeoutException) { + onResult(StreamResult.Complete(Code.DEADLINE_EXCEEDED, cause = e)) + return@runBlocking + } onResult(StreamResult.Complete(Code.UNKNOWN, cause = e)) } }