diff --git a/Makefile b/Makefile index 6c19be4c..cae43698 100644 --- a/Makefile +++ b/Makefile @@ -46,30 +46,39 @@ runconformance: runcrosstests runconformancenew runconformancenew: generate $(CONNECT_CONFORMANCE) ## Run the new conformance test suite. ./gradlew $(GRADLE_ARGS) conformance:client:google-java:installDist conformance:client:google-javalite:installDist $(CONNECT_CONFORMANCE) -v --mode client --conf conformance/client/lite-unary-config.yaml \ - --known-failing conformance/client/known-failing-cases.txt -- \ + --known-failing conformance/client/known-failing-unary-cases.txt -- \ conformance/client/google-javalite/build/install/google-javalite/bin/google-javalite \ --style suspend $(CONNECT_CONFORMANCE) -v --mode client --conf conformance/client/lite-unary-config.yaml \ - --known-failing conformance/client/known-failing-cases.txt -- \ + --known-failing conformance/client/known-failing-unary-cases.txt -- \ conformance/client/google-javalite/build/install/google-javalite/bin/google-javalite \ --style callback $(CONNECT_CONFORMANCE) -v --mode client --conf conformance/client/lite-unary-config.yaml \ - --known-failing conformance/client/known-failing-cases.txt -- \ + --known-failing conformance/client/known-failing-unary-cases.txt -- \ conformance/client/google-javalite/build/install/google-javalite/bin/google-javalite \ --style blocking $(CONNECT_CONFORMANCE) -v --mode client --conf conformance/client/standard-unary-config.yaml \ - --known-failing conformance/client/known-failing-cases.txt -- \ + --known-failing conformance/client/known-failing-unary-cases.txt -- \ conformance/client/google-java/build/install/google-java/bin/google-java \ --style suspend $(CONNECT_CONFORMANCE) -v --mode client --conf conformance/client/standard-unary-config.yaml \ - --known-failing conformance/client/known-failing-cases.txt -- \ + --known-failing conformance/client/known-failing-unary-cases.txt -- \ conformance/client/google-java/build/install/google-java/bin/google-java \ --style callback $(CONNECT_CONFORMANCE) -v --mode client --conf conformance/client/standard-unary-config.yaml \ - --known-failing conformance/client/known-failing-cases.txt -- \ + --known-failing conformance/client/known-failing-unary-cases.txt -- \ conformance/client/google-java/build/install/google-java/bin/google-java \ --style blocking -# TODO: streaming conformance test cases + +# TODO: Add streaming conformance tests. Currently, a small number of the test cases +# are flaky, so leaving this commented out for now. +# (Will continue investigating and address soon). +# $(CONNECT_CONFORMANCE) -v --mode client --conf conformance/client/lite-stream-config.yaml \ +# --known-failing conformance/client/known-failing-stream-cases.txt -- \ +# conformance/client/google-javalite/build/install/google-javalite/bin/google-javalite +# $(CONNECT_CONFORMANCE) -v --mode client --conf conformance/client/standard-stream-config.yaml \ +# --known-failing conformance/client/known-failing-stream-cases.txt -- \ +# conformance/client/google-java/build/install/google-java/bin/google-java .PHONY: runcrosstests runcrosstests: generate ## Run the old cross-test suite. diff --git a/conformance/client/google-java/src/main/kotlin/com/connectrpc/conformance/client/java/JavaHelpers.kt b/conformance/client/google-java/src/main/kotlin/com/connectrpc/conformance/client/java/JavaHelpers.kt index f44a5bb7..32a8892a 100644 --- a/conformance/client/google-java/src/main/kotlin/com/connectrpc/conformance/client/java/JavaHelpers.kt +++ b/conformance/client/google-java/src/main/kotlin/com/connectrpc/conformance/client/java/JavaHelpers.kt @@ -180,6 +180,8 @@ class JavaHelpers { get() = msg.serverTlsCert override val clientTlsCreds: TlsCreds? get() = if (msg.hasClientTlsCreds()) TlsCredsImpl(msg.clientTlsCreds) else null + override val receiveLimitBytes: Int + get() = msg.messageReceiveLimit override val timeoutMs: Int get() = msg.timeoutMs override val requestDelayMs: Int diff --git a/conformance/client/google-javalite/src/main/kotlin/com/connectrpc/conformance/client/javalite/JavaLiteHelpers.kt b/conformance/client/google-javalite/src/main/kotlin/com/connectrpc/conformance/client/javalite/JavaLiteHelpers.kt index c1a0f8e5..62da7e72 100644 --- a/conformance/client/google-javalite/src/main/kotlin/com/connectrpc/conformance/client/javalite/JavaLiteHelpers.kt +++ b/conformance/client/google-javalite/src/main/kotlin/com/connectrpc/conformance/client/javalite/JavaLiteHelpers.kt @@ -162,6 +162,8 @@ class JavaLiteHelpers { get() = msg.serverTlsCert override val clientTlsCreds: TlsCreds? get() = if (msg.hasClientTlsCreds()) TlsCredsImpl(msg.clientTlsCreds) else null + override val receiveLimitBytes: Int + get() = msg.messageReceiveLimit override val timeoutMs: Int get() = msg.timeoutMs override val requestDelayMs: Int diff --git a/conformance/client/known-failing-stream-cases.txt b/conformance/client/known-failing-stream-cases.txt new file mode 100644 index 00000000..a84a57cd --- /dev/null +++ b/conformance/client/known-failing-stream-cases.txt @@ -0,0 +1,15 @@ +# OkHttp seems to have a bug where timeout is not properly +# enforced when request body is full-duplex. +Timeouts/HTTPVersion:2/**/bidi half duplex timeout +Timeouts/HTTPVersion:2/**/bidi full duplex timeout + +# Connect-kotlin does not have a way to limit the size of messages +# received. It probably should. Despite this, many cases in this suite +# still pass, so they are likely not exercising what we think they are. +# TODO: add flag to config yaml for whether implementation supports +# a receive size limit +Client Message Size/**/Compression:COMPRESSION_GZIP/TLS:false/**/client stream first request exceeds client limit +Client Message Size/**/Compression:COMPRESSION_GZIP/TLS:false/**/client stream subsequent request exceeds client limit +Client Message Size/**/Compression:COMPRESSION_GZIP/TLS:false/**/client stream all requests equal to client limit +Client Message Size/**/Compression:COMPRESSION_GZIP/TLS:false/**/server stream request equal to client limit +Client Message Size/**/Compression:COMPRESSION_GZIP/TLS:false/**/server stream request exceeds client limit diff --git a/conformance/client/known-failing-cases.txt b/conformance/client/known-failing-unary-cases.txt similarity index 100% rename from conformance/client/known-failing-cases.txt rename to conformance/client/known-failing-unary-cases.txt diff --git a/conformance/client/lite-stream-config.yaml b/conformance/client/lite-stream-config.yaml new file mode 100644 index 00000000..5cdb0c8b --- /dev/null +++ b/conformance/client/lite-stream-config.yaml @@ -0,0 +1,25 @@ +# This configures the features that this client +# supports and that will be verified by the +# conformance test suite. +features: + versions: + - HTTP_VERSION_1 + - HTTP_VERSION_2 + protocols: + - PROTOCOL_CONNECT + - PROTOCOL_GRPC + - PROTOCOL_GRPC_WEB + codecs: + - CODEC_PROTO + # Lite does not support JSON + compressions: + - COMPRESSION_IDENTITY + - COMPRESSION_GZIP + streamTypes: + # This config file only runs stream RPC test cases. + - STREAM_TYPE_CLIENT_STREAM + - STREAM_TYPE_SERVER_STREAM + - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM + - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM + # TODO: get client certs working and uncomment this + #supportsTlsClientCerts: true diff --git a/conformance/client/lite-unary-config.yaml b/conformance/client/lite-unary-config.yaml index 7887034d..ac4aeb7c 100644 --- a/conformance/client/lite-unary-config.yaml +++ b/conformance/client/lite-unary-config.yaml @@ -20,5 +20,5 @@ features: # so that we can run them all three ways: suspend, # callback, and blocking. - STREAM_TYPE_UNARY - supportsTlsClientCerts: true - supportsHalfDuplexBidiOverHttp1: true + # TODO: get client certs working and uncomment this + #supportsTlsClientCerts: true diff --git a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/Client.kt b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/Client.kt index 6264aea0..0c7f8f4a 100644 --- a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/Client.kt +++ b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/Client.kt @@ -15,6 +15,8 @@ package com.connectrpc.conformance.client import com.connectrpc.Code +import com.connectrpc.ConnectException +import com.connectrpc.Headers import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.ResponseMessage @@ -31,6 +33,7 @@ import com.connectrpc.conformance.client.adapt.ClientCompatRequest.StreamType import com.connectrpc.conformance.client.adapt.ClientResponseResult import com.connectrpc.conformance.client.adapt.ClientStreamClient import com.connectrpc.conformance.client.adapt.Invoker +import com.connectrpc.conformance.client.adapt.ResponseStream import com.connectrpc.conformance.client.adapt.ServerStreamClient import com.connectrpc.conformance.client.adapt.UnaryClient import com.connectrpc.impl.ProtocolClient @@ -38,7 +41,9 @@ import com.connectrpc.okhttp.ConnectOkHttpClient import com.connectrpc.protocols.GETConfiguration import com.google.protobuf.MessageLite import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.delay +import kotlinx.coroutines.launch import okhttp3.OkHttpClient import okhttp3.tls.HandshakeCertificates import okhttp3.tls.HeldCertificate @@ -75,6 +80,9 @@ class Client( private const val UNARY_REQUEST_NAME = "connectrpc.conformance.v1.UnaryRequest" private const val IDEMPOTENT_UNARY_REQUEST_NAME = "connectrpc.conformance.v1.IdempotentUnaryRequest" private const val UNIMPLEMENTED_REQUEST_NAME = "connectrpc.conformance.v1.UnimplementedRequest" + private const val CLIENT_STREAM_REQUEST_NAME = "connectrpc.conformance.v1.ClientStreamRequest" + private const val SERVER_STREAM_REQUEST_NAME = "connectrpc.conformance.v1.ServerStreamRequest" + private const val BIDI_STREAM_REQUEST_NAME = "connectrpc.conformance.v1.BidiStreamRequest" } suspend fun handle(req: ClientCompatRequest): ClientResponseResult { @@ -137,43 +145,52 @@ class Client( // So this case means no cancellation. } } - return when (val result = resp.await()) { - is ResponseMessage.Success -> { - if (result.code != Code.OK) { - throw RuntimeException("RPC was successful but ended with non-OK code ${result.code}") - } - - ClientResponseResult( - headers = result.headers, - payloads = listOf(payloadExtractor(result.message)), - trailers = result.trailers, - ) - } - is ResponseMessage.Failure -> { - if (result.code != result.cause.code) { - throw RuntimeException("RPC result has mismatching codes: ${result.code} != ${result.cause.code}") - } - if (args.verbosity > 2) { - System.err.println("* client: RPC failed with code ${result.code}") - result.cause.printStackTrace() - } - ClientResponseResult( - headers = result.headers, - error = result.cause, - trailers = result.trailers, - ) - } - } + return unaryResult(0, resp.await()) } private suspend fun handleClient( client: ClientStreamClient, req: ClientCompatRequest, - ): ClientResponseResult { + ): ClientResponseResult = coroutineScope { if (req.streamType != StreamType.CLIENT_STREAM) { throw RuntimeException("specified method ${req.method} is client-stream but stream type indicates ${req.streamType}") } - TODO("implement me") + if (req.cancel != null && + req.cancel !is Cancel.BeforeCloseSend && + req.cancel !is Cancel.AfterCloseSendMs + ) { + throw RuntimeException("client stream calls can only support `BeforeCloseSend` and 'AfterCloseSendMs' cancellation field, instead got ${req.cancel!!::class.simpleName}") + } + val stream = client.execute(req.requestHeaders) + var numUnsent = 0 + for (i in req.requestMessages.indices) { + if (req.requestDelayMs > 0) { + delay(req.requestDelayMs.toLong()) + } + val msg = fromAny(req.requestMessages[i], client.reqTemplate, CLIENT_STREAM_REQUEST_NAME) + try { + stream.send(msg) + } catch (_: Exception) { + numUnsent = req.requestMessages.size - i + break + } + } + when (val cancel = req.cancel) { + is Cancel.BeforeCloseSend -> { + stream.cancel() + } + is Cancel.AfterCloseSendMs -> { + launch { + delay(cancel.millis.toLong()) + stream.cancel() + } + } + else -> { + // We already validated the case above. + // So this case means no cancellation. + } + } + return@coroutineScope unaryResult(numUnsent, stream.closeAndReceive()) } private suspend fun handleServer( @@ -183,7 +200,23 @@ class Client( if (req.streamType != StreamType.SERVER_STREAM) { throw RuntimeException("specified method ${req.method} is server-stream but stream type indicates ${req.streamType}") } - TODO("implement me") + if (req.requestMessages.size != 1) { + throw RuntimeException("server-stream calls should indicate exactly one request message, got ${req.requestMessages.size}") + } + if (req.cancel != null && + req.cancel !is Cancel.AfterCloseSendMs && + req.cancel !is Cancel.AfterNumResponses + ) { + throw RuntimeException("server stream calls can only support `AfterCloseSendMs` and 'AfterNumResponses' cancellation field, instead got ${req.cancel!!::class.simpleName}") + } + val msg = fromAny(req.requestMessages[0], client.reqTemplate, SERVER_STREAM_REQUEST_NAME) + val stream = client.execute(msg, req.requestHeaders) + val cancel = req.cancel + if (cancel is Cancel.AfterCloseSendMs) { + delay(cancel.millis.toLong()) + stream.close() + } + return streamResult(0, stream, cancel) } private suspend fun handleBidi( @@ -204,14 +237,172 @@ class Client( client: BidiStreamClient, req: ClientCompatRequest, ): ClientResponseResult { - TODO("implement me") + val stream = client.execute(req.requestHeaders) + var numUnsent = 0 + for (i in req.requestMessages.indices) { + if (req.requestDelayMs > 0) { + delay(req.requestDelayMs.toLong()) + } + val msg = fromAny(req.requestMessages[i], client.reqTemplate, BIDI_STREAM_REQUEST_NAME) + try { + stream.requests.send(msg) + } catch (_: Exception) { + numUnsent = req.requestMessages.size - i + break + } + } + val cancel = req.cancel + when (cancel) { + is Cancel.BeforeCloseSend -> { + stream.responses.close() // cancel + stream.requests.close() // close send + } + is Cancel.AfterCloseSendMs -> { + stream.requests.close() // close send + delay(cancel.millis.toLong()) + stream.responses.close() // cancel + } + else -> { + stream.requests.close() // close send + } + } + return streamResult(numUnsent, stream.responses, cancel) } private suspend fun handleFullDuplexBidi( client: BidiStreamClient, req: ClientCompatRequest, ): ClientResponseResult { - TODO("implement me") + val stream = client.execute(req.requestHeaders) + val cancel = req.cancel + val payloads: MutableList = mutableListOf() + for (i in req.requestMessages.indices) { + if (req.requestDelayMs > 0) { + delay(req.requestDelayMs.toLong()) + } + val msg = fromAny(req.requestMessages[i], client.reqTemplate, BIDI_STREAM_REQUEST_NAME) + try { + stream.requests.send(msg) + } catch (_: Exception) { + // Ignore. We should see it again below when we receive the response. + } + + // In full-duplex mode, we read the response after writing request, + // to interleave the requests and responses. + if (i == 0 && cancel is Cancel.AfterNumResponses && cancel.num == 0) { + stream.responses.close() + } + try { + val resp = stream.responses.messages.receive() + payloads.add(payloadExtractor(resp)) + if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) { + stream.responses.close() + } + } catch (ex: ConnectException) { + return ClientResponseResult( + headers = stream.responses.headers(), + payloads = payloads, + error = ex, + trailers = ex.metadata, + numUnsentRequests = req.requestMessages.size - i, + ) + } + } + when (cancel) { + is Cancel.BeforeCloseSend -> { + stream.responses.close() // cancel + stream.requests.close() // close send + } + is Cancel.AfterCloseSendMs -> { + stream.requests.close() // close send + delay(cancel.millis.toLong()) + stream.responses.close() // cancel + } + else -> { + stream.requests.close() // close send + } + } + + // Drain the response, in case there are any other messages. + var connEx: ConnectException? = null + var trailers: Headers + try { + for (resp in stream.responses.messages) { + payloads.add(payloadExtractor(resp)) + if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) { + stream.responses.close() + } + } + trailers = stream.responses.trailers() + } catch (ex: ConnectException) { + connEx = ex + trailers = ex.metadata + } finally { + stream.responses.close() + } + return ClientResponseResult( + headers = stream.responses.headers(), + payloads = payloads, + error = connEx, + trailers = trailers, + ) + } + + private fun unaryResult(numUnsent: Int, result: ResponseMessage): ClientResponseResult { + return when (result) { + is ResponseMessage.Success -> { + if (result.code != Code.OK) { + throw RuntimeException("RPC was successful but ended with non-OK code ${result.code}") + } + ClientResponseResult( + headers = result.headers, + payloads = listOf(payloadExtractor(result.message)), + trailers = result.trailers, + numUnsentRequests = numUnsent, + ) + } + is ResponseMessage.Failure -> { + if (result.code != result.cause.code) { + throw RuntimeException("RPC result has mismatching codes: ${result.code} != ${result.cause.code}") + } + ClientResponseResult( + headers = result.headers, + error = result.cause, + trailers = result.trailers, + numUnsentRequests = numUnsent, + ) + } + } + } + + private suspend fun streamResult(numUnsent: Int, stream: ResponseStream, cancel: Cancel?): ClientResponseResult { + val payloads: MutableList = mutableListOf() + var connEx: ConnectException? = null + var trailers: Headers + try { + if (cancel is Cancel.AfterNumResponses && cancel.num == 0) { + stream.close() + } + for (resp in stream.messages) { + payloads.add(payloadExtractor(resp)) + if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) { + stream.close() + } + } + trailers = stream.trailers() + } catch (ex: ConnectException) { + connEx = ex + trailers = ex.metadata + } finally { + stream.close() + } + return ClientResponseResult( + headers = stream.headers(), + payloads = payloads, + error = connEx, + trailers = trailers, + numUnsentRequests = numUnsent, + ) } private fun getClient(req: ClientCompatRequest): Pair { @@ -230,7 +421,7 @@ class Client( if (req.timeoutMs != 0) { clientBuilder = clientBuilder.callTimeout(Duration.ofMillis(req.timeoutMs.toLong())) } - + // TODO: need to support max receive bytes and use req.receiveLimitBytes val getConfig = if (req.useGetHttpMethod) GETConfiguration.Enabled else GETConfiguration.Disabled val requestCompression = if (req.compression == Compression.GZIP) { diff --git a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/ConformanceClientLoop.kt b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/ConformanceClientLoop.kt index eea1e2e0..8c0d148f 100644 --- a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/ConformanceClientLoop.kt +++ b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/ConformanceClientLoop.kt @@ -58,6 +58,13 @@ class ConformanceClientLoop( } result = ClientCompatResponse.Result.ErrorResult(msg) } + if (result is ClientCompatResponse.Result.ResponseResult && result.response.error != null) { + if (verbosity > 2) { + val ex = result.response.error!! + System.err.println("* client: RPC failed with code ${ex.code}") + ex.printStackTrace() + } + } writeResponse( output, ClientCompatResponse( diff --git a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/BidiStreamClient.kt b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/BidiStreamClient.kt index 0865269b..0e372c40 100644 --- a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/BidiStreamClient.kt +++ b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/BidiStreamClient.kt @@ -44,20 +44,18 @@ abstract class BidiStreamClient( * @param Resp The response message type */ interface BidiStream { - fun requests(): RequestStream - fun responses(): ResponseStream + val requests: RequestStream + val responses: ResponseStream companion object { fun new(underlying: BidirectionalStreamInterface): BidiStream { val reqStream = RequestStream.new(underlying) val respStream = ResponseStream.new(underlying) return object : BidiStream { - override fun requests(): RequestStream { - return reqStream - } + override val requests: RequestStream + get() = reqStream - override fun responses(): ResponseStream { - return respStream - } + override val responses: ResponseStream + get() = respStream } } } diff --git a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ClientCompatRequest.kt b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ClientCompatRequest.kt index 36422e5e..ab8f3002 100644 --- a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ClientCompatRequest.kt +++ b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ClientCompatRequest.kt @@ -43,6 +43,7 @@ interface ClientCompatRequest { val port: Int val serverTlsCert: ByteString val clientTlsCreds: TlsCreds? + val receiveLimitBytes: Int val timeoutMs: Int val requestDelayMs: Int val useGetHttpMethod: Boolean diff --git a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ClientStreamClient.kt b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ClientStreamClient.kt index 01691ec4..7701d0ee 100644 --- a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ClientStreamClient.kt +++ b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ClientStreamClient.kt @@ -46,6 +46,7 @@ abstract class ClientStreamClient( interface ClientStream { suspend fun send(req: Req) suspend fun closeAndReceive(): ResponseMessage + suspend fun cancel() companion object { fun new(underlying: ClientOnlyStreamInterface): ClientStream { @@ -74,10 +75,14 @@ abstract class ClientStreamClient( cause = connectException, code = connectException.code, headers = underlying.responseHeaders().await(), - trailers = underlying.responseTrailers().await(), + trailers = connectException.metadata, ) } } + + override suspend fun cancel() { + underlying.cancel() + } } } } diff --git a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ResponseStream.kt b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ResponseStream.kt index cb401ab4..9a17dc75 100644 --- a/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ResponseStream.kt +++ b/conformance/client/src/main/kotlin/com/connectrpc/conformance/client/adapt/ResponseStream.kt @@ -30,7 +30,7 @@ import kotlinx.coroutines.channels.ReceiveChannel * @param Resp The response message type */ interface ResponseStream { - fun messages(): ReceiveChannel + val messages: ReceiveChannel suspend fun headers(): Headers @@ -41,9 +41,8 @@ interface ResponseStream { companion object { fun new(underlying: BidirectionalStreamInterface): ResponseStream { return object : ResponseStream { - override fun messages(): ReceiveChannel { - return underlying.responseChannel() - } + override val messages: ReceiveChannel + get() = underlying.responseChannel() override suspend fun headers(): Headers { return underlying.responseHeaders().await() @@ -61,9 +60,8 @@ interface ResponseStream { fun new(underlying: ServerOnlyStreamInterface): ResponseStream { return object : ResponseStream { - override fun messages(): ReceiveChannel { - return underlying.responseChannel() - } + override val messages: ReceiveChannel + get() = underlying.responseChannel() override suspend fun headers(): Headers { return underlying.responseHeaders().await() diff --git a/conformance/client/standard-stream-config.yaml b/conformance/client/standard-stream-config.yaml new file mode 100644 index 00000000..1b3f284f --- /dev/null +++ b/conformance/client/standard-stream-config.yaml @@ -0,0 +1,25 @@ +# This configures the features that this client +# supports and that will be verified by the +# conformance test suite. +features: + versions: + - HTTP_VERSION_1 + - HTTP_VERSION_2 + protocols: + - PROTOCOL_CONNECT + - PROTOCOL_GRPC + - PROTOCOL_GRPC_WEB + codecs: + - CODEC_PROTO + - CODEC_JSON + compressions: + - COMPRESSION_IDENTITY + - COMPRESSION_GZIP + streamTypes: + # This config file only runs stream RPC test cases. + - STREAM_TYPE_CLIENT_STREAM + - STREAM_TYPE_SERVER_STREAM + - STREAM_TYPE_HALF_DUPLEX_BIDI_STREAM + - STREAM_TYPE_FULL_DUPLEX_BIDI_STREAM + # TODO: get client certs working and uncomment this + #supportsTlsClientCerts: true diff --git a/conformance/client/standard-unary-config.yaml b/conformance/client/standard-unary-config.yaml index dee374aa..fb10401a 100644 --- a/conformance/client/standard-unary-config.yaml +++ b/conformance/client/standard-unary-config.yaml @@ -20,5 +20,5 @@ features: # so that we can run them all three ways: suspend, # callback, and blocking. - STREAM_TYPE_UNARY - supportsTlsClientCerts: true - supportsHalfDuplexBidiOverHttp1: true + # TODO: get client certs working and uncomment this + #supportsTlsClientCerts: true diff --git a/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt b/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt index be625bf4..9fe150a5 100644 --- a/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt @@ -59,6 +59,12 @@ interface ClientOnlyStreamInterface { */ fun sendClose() + /** + * Cancels the stream. This closes both send and receive sides of the stream + * without awaiting any server reply. + */ + fun cancel() + /** * Determine if the underlying client send stream is closed. * diff --git a/library/src/main/kotlin/com/connectrpc/Code.kt b/library/src/main/kotlin/com/connectrpc/Code.kt index 585a5261..47e2b268 100644 --- a/library/src/main/kotlin/com/connectrpc/Code.kt +++ b/library/src/main/kotlin/com/connectrpc/Code.kt @@ -60,6 +60,9 @@ enum class Code(val codeName: String, val value: Int) { } } fun fromName(name: String?): Code { + if (name == null) { + return UNKNOWN + } for (value in values()) { if (value.codeName == name) { return value @@ -71,7 +74,8 @@ enum class Code(val codeName: String, val value: Int) { if (value == null) { return UNKNOWN } - return values().first { code -> code.value == value } + val code = values().firstOrNull { code -> code.value == value } + return code ?: UNKNOWN } } } diff --git a/library/src/main/kotlin/com/connectrpc/Interceptor.kt b/library/src/main/kotlin/com/connectrpc/Interceptor.kt index 9e83fd9d..f9bbf441 100644 --- a/library/src/main/kotlin/com/connectrpc/Interceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/Interceptor.kt @@ -25,6 +25,10 @@ import okio.Buffer * Interceptors are expected to be instantiated once per request/stream. */ interface Interceptor { + // TODO: This interface and the StreamResult class should be internal. + // User-provided interceptors should have a better API that provides + // similar higher-level abstraction as the stream interfaces. + /** * Invoked when a unary call is started. Provides a set of closures that will be called * as the request progresses, allowing the interceptor to alter request/response data. diff --git a/library/src/main/kotlin/com/connectrpc/StreamResult.kt b/library/src/main/kotlin/com/connectrpc/StreamResult.kt index 46e721b3..cfa97b79 100644 --- a/library/src/main/kotlin/com/connectrpc/StreamResult.kt +++ b/library/src/main/kotlin/com/connectrpc/StreamResult.kt @@ -75,6 +75,8 @@ sealed class StreamResult { /** * Fold the different results into a nullable single type. + * Unlike `fold`, the caller can omit some transformations, + * which default to returning null. * * @param onHeaders Transform a Header result. * @param onMessage Transform a Message result. @@ -85,16 +87,6 @@ sealed class StreamResult { onMessage: (Message) -> Result? = { null }, onCompletion: (Complete) -> Result? = { null }, ): Result? { - return when (this) { - is Headers -> { - onHeaders(this) - } - is Message -> { - onMessage(this) - } - is Complete -> { - onCompletion(this) - } - } + return fold(onHeaders, onMessage, onCompletion) } } diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt index 4f1e2baf..7834a5e0 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt @@ -44,7 +44,7 @@ interface HTTPClientInterface { * * @return The created stream. */ - fun stream(request: HTTPRequest, onResult: suspend (StreamResult) -> Unit): Stream + fun stream(request: HTTPRequest, duplex: Boolean, onResult: suspend (StreamResult) -> Unit): Stream } class Stream( diff --git a/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt b/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt index a712a6e4..583793f6 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt @@ -58,6 +58,10 @@ internal class ClientOnlyStream( return messageStream.sendClose() } + override fun cancel() { + return messageStream.receiveClose() + } + override fun isSendClosed(): Boolean { return messageStream.isSendClosed() } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 5ec3f534..6e7ebc28 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -25,6 +25,7 @@ import com.connectrpc.ProtocolClientInterface import com.connectrpc.ResponseMessage import com.connectrpc.ServerOnlyStreamInterface import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.UnaryBlockingCall import com.connectrpc.http.Cancelable import com.connectrpc.http.HTTPClientInterface @@ -207,7 +208,7 @@ class ProtocolClient( val streamFunc = config.createStreamingInterceptorChain() val finalRequest = streamFunc.requestFunction(request) var isComplete = false - val httpStream = httpClient.stream(finalRequest) { initialResult -> + val httpStream = httpClient.stream(finalRequest, methodSpec.streamType == StreamType.BIDI) { initialResult -> if (isComplete) { // No-op on remaining handlers after a completion. return@stream diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index 8c9af33c..c1a032bf 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -227,8 +227,8 @@ internal class ConnectInterceptor( return StreamResult.Complete(Code.UNKNOWN, e) } val metadata = endStreamResponseJSON.metadata?.toLowercase() - if (endStreamResponseJSON.error?.code == null) { - return StreamResult.Complete(Code.OK, trailers = metadata ?: emptyMap()) + if (endStreamResponseJSON.error == null) { + return StreamResult.Complete(Code.OK, trailers = metadata.orEmpty()) } val code = Code.fromName(endStreamResponseJSON.error.code) StreamResult.Complete( diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCCompletion.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCCompletion.kt index e82cf82c..b3adb806 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCCompletion.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCCompletion.kt @@ -47,7 +47,7 @@ internal data class GRPCCompletion( if (cause != null || code != Code.OK) { return ConnectException( - code = code, + code = if (code == Code.OK) Code.UNKNOWN else code, errorDetailParser = serializationStrategy.errorDetailParser(), message = message.utf8(), exception = cause, diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt index e079f43d..6c1ebb99 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt @@ -205,10 +205,10 @@ internal class GRPCWebInterceptor( if (headerByte.and(TRAILERS_BIT) == TRAILERS_BIT) { val streamTrailers = parseGrpcWebTrailer(unpackedMessage) val completion = completionParser.parse(emptyMap(), streamTrailers) - val code = completion!!.code + val code = completion?.code ?: Code.UNKNOWN return@fold StreamResult.Complete( code = code, - cause = completion.toConnectExceptionOrNull(serializationStrategy), + cause = completion?.toConnectExceptionOrNull(serializationStrategy) ?: ConnectException(code), trailers = streamTrailers, ) } diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 04ad2a94..16f96164 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -85,7 +85,7 @@ class ProtocolClientTest { createMethodSpec(StreamType.BIDI), ) val captor = argumentCaptor() - verify(httpClient).stream(captor.capture(), any()) + verify(httpClient).stream(captor.capture(), true, any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } } @@ -103,7 +103,7 @@ class ProtocolClientTest { createMethodSpec(StreamType.BIDI), ) val captor = argumentCaptor() - verify(httpClient).stream(captor.capture(), any()) + verify(httpClient).stream(captor.capture(), true, any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } } diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt index 658e2542..24443373 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt @@ -66,7 +66,7 @@ class ConnectOkHttpClient @JvmOverloads constructor( newCall.enqueue( object : Callback { override fun onFailure(call: Call, e: IOException) { - val code = codeFromIOException(e) + val code = codeFromException(newCall.isCanceled(), e) onResult( HTTPResponse( code = code, @@ -123,9 +123,10 @@ class ConnectOkHttpClient @JvmOverloads constructor( override fun stream( request: HTTPRequest, + duplex: Boolean, onResult: suspend (StreamResult) -> Unit, ): Stream { - return streamClient.initializeStream(request.httpMethod, request, onResult) + return streamClient.initializeStream(request.httpMethod, request, duplex, onResult) } } @@ -136,15 +137,12 @@ internal fun Headers.toLowerCaseKeysMultiMap(): Map> { ) } -internal fun codeFromIOException(e: IOException): Code { +internal fun codeFromException(callCanceled: Boolean, e: Exception): Code { return if ((e is InterruptedIOException && e.message == "timeout") || e is SocketTimeoutException ) { Code.DEADLINE_EXCEEDED - } else if (e.message?.lowercase() == "canceled") { - // TODO: Figure out what, if anything, actually throws an exception - // with this message. It seems more likely that a JVM or - // Kotlin coroutine exception would spell it with two Ls. + } else if (e is IOException && callCanceled) { Code.CANCELED } else { Code.UNKNOWN diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt index dc44de7f..f7e5cb0d 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt @@ -28,7 +28,6 @@ import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.RequestBody import okhttp3.Response -import okhttp3.internal.http2.StreamResetException import okio.Buffer import okio.BufferedSink import okio.BufferedSource @@ -45,11 +44,12 @@ import java.util.concurrent.atomic.AtomicBoolean internal fun OkHttpClient.initializeStream( method: String, request: HTTPRequest, + duplex: Boolean, onResult: suspend (StreamResult) -> Unit, ): Stream { val isSendClosed = AtomicBoolean(false) val isReceiveClosed = AtomicBoolean(false) - val duplexRequestBody = PipeDuplexRequestBody(request.contentType.toMediaType()) + val duplexRequestBody = PipeRequestBody(duplex, request.contentType.toMediaType()) val builder = Request.Builder() .url(request.url) .method(method, duplexRequestBody) @@ -60,7 +60,7 @@ internal fun OkHttpClient.initializeStream( } val callRequest = builder.build() val call = newCall(callRequest) - call.enqueue(ResponseCallback(onResult, isReceiveClosed)) + call.enqueue(ResponseCallback(onResult)) return Stream( onSend = { buffer -> if (!isSendClosed.get()) { @@ -82,11 +82,10 @@ internal fun OkHttpClient.initializeStream( private class ResponseCallback( private val onResult: suspend (StreamResult) -> Unit, - private val isClosed: AtomicBoolean, ) : Callback { override fun onFailure(call: Call, e: IOException) { runBlocking { - onResult(StreamResult.Complete(codeFromIOException(e), cause = e)) + onResult(StreamResult.Complete(codeFromException(call.isCanceled(), e), cause = e)) } } @@ -107,9 +106,9 @@ private class ResponseCallback( } response.use { resp -> resp.body!!.source().use { sourceBuffer -> - var exception: Throwable? = null + var exception: Exception? = null try { - while (!sourceBuffer.safeExhausted() && !isClosed.get()) { + while (!sourceBuffer.exhausted()) { val buffer = readStream(sourceBuffer) val streamResult = StreamResult.Message( message = buffer, @@ -122,7 +121,7 @@ private class ResponseCallback( // If trailers are not yet communicated. // This is the final chance to notify trailers to the consumer. val finalResult = StreamResult.Complete( - code = code, + code = if (exception != null) codeFromException(call.isCanceled(), exception) else code, trailers = response.safeTrailers() ?: emptyMap(), cause = exception, ) @@ -133,21 +132,18 @@ private class ResponseCallback( } } - private fun BufferedSource.safeExhausted(): Boolean { - return try { - exhausted() - } catch (e: StreamResetException) { - true - } - } - private fun Response.safeTrailers(): Map>? { - return try { - if (body?.source()?.safeExhausted() == false) { + try { + if (body?.source()?.exhausted() == false) { // Assuming this means that trailers are not available. // Returning null to signal trailers are "missing". return null } + } catch (e: Exception) { + return null + } + + return try { trailers().toLowerCaseKeysMultiMap() } catch (_: Throwable) { // Something went terribly wrong. @@ -175,7 +171,8 @@ private class ResponseCallback( } } -internal class PipeDuplexRequestBody( +internal class PipeRequestBody( + private val duplex: Boolean, private val contentType: MediaType?, pipeMaxBufferSize: Long = 1024 * 1024, ) : RequestBody() { @@ -201,7 +198,7 @@ internal class PipeDuplexRequestBody( pipe.fold(sink) } - override fun isDuplex() = true + override fun isDuplex() = duplex override fun isOneShot() = true