From 235028c65380b803955c9e4e2133220044764349 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Fri, 17 May 2024 12:21:12 -0400 Subject: [PATCH] Fix all conformance failures other than timeouts/deadlines (#274) * The first fix is for how trailers-only responses are classified. This was previously just looking for a "grpc-status" key in the headers. If it was present, it was treating it as a trailers-only response, even if there was a body and/or trailers. * The second fix is so that the client reports errors in the face of unexpected response content types. With a little code reorganization, we can improve this logic in the future, to increase code sharing (especially between gRPC and gRPC-Web). --- .../client/known-failing-stream-cases.txt | 3 - .../client/known-failing-unary-cases.txt | 11 ---- .../main/kotlin/com/connectrpc/AnyError.kt | 3 +- .../src/main/kotlin/com/connectrpc/Code.kt | 2 +- .../src/main/kotlin/com/connectrpc/Codec.kt | 1 + .../com/connectrpc/ErrorDetailParser.kt | 6 +- .../com/connectrpc/SerializationStrategy.kt | 2 +- .../kotlin/com/connectrpc/StreamResult.kt | 3 + .../connectrpc/protocols/ConnectConstants.kt | 1 + .../protocols/ConnectInterceptor.kt | 62 ++++++++++++++----- .../connectrpc/protocols/ErrorJSONModels.kt | 8 +++ .../protocols/GRPCCompletionParser.kt | 26 ++++---- .../connectrpc/protocols/GRPCInterceptor.kt | 61 ++++++++++++++---- .../protocols/GRPCWebInterceptor.kt | 62 +++++++++++++++---- .../com/connectrpc/InterceptorChainTest.kt | 9 ++- .../protocols/ConnectInterceptorTest.kt | 5 +- .../protocols/GRPCErrorDetailParserTest.kt | 51 ++++++++++++++- .../protocols/GRPCInterceptorTest.kt | 13 ++-- .../protocols/GRPCWebInterceptorTest.kt | 17 +++-- .../com/connectrpc/okhttp/OkHttpStream.kt | 1 + 20 files changed, 263 insertions(+), 84 deletions(-) diff --git a/conformance/client/known-failing-stream-cases.txt b/conformance/client/known-failing-stream-cases.txt index 9c667b86..0283080c 100644 --- a/conformance/client/known-failing-stream-cases.txt +++ b/conformance/client/known-failing-stream-cases.txt @@ -6,6 +6,3 @@ Timeouts/HTTPVersion:2/**/bidi-stream/** # Deadline headers are not currently set. Deadline Propagation/** - -# Bug: incorrect code attribution for these failures (UNKNOWN instead of INTERNAL) -Connect Unexpected Responses/**/unexpected-stream-codec diff --git a/conformance/client/known-failing-unary-cases.txt b/conformance/client/known-failing-unary-cases.txt index bcdd5e3d..1765a1af 100644 --- a/conformance/client/known-failing-unary-cases.txt +++ b/conformance/client/known-failing-unary-cases.txt @@ -1,13 +1,2 @@ # Deadline headers are not currently set. Deadline Propagation/** - -# Bug: response content-type is not correctly checked -**/unexpected-content-type - -# Bug: "trailers-only" responses are not correctly identified. -# If headers contain "grpc-status", this client assumes it is a -# trailers-only response. However, a trailers-only response should -# instead be identified by lack of body or HTTP trailers. -gRPC Unexpected Responses/**/trailers-only/* -gRPC-Web Unexpected Responses/**/trailers-only/ignore-header-if-body-present - diff --git a/library/src/main/kotlin/com/connectrpc/AnyError.kt b/library/src/main/kotlin/com/connectrpc/AnyError.kt index 8b520f31..046f927b 100644 --- a/library/src/main/kotlin/com/connectrpc/AnyError.kt +++ b/library/src/main/kotlin/com/connectrpc/AnyError.kt @@ -17,7 +17,8 @@ package com.connectrpc import okio.ByteString /** - * This is an Any adapter for various base data types. + * This is a protobuf-runtime-agnostic representation of google.protobuf.Any + * messages, which are used to represent error details in gRPC. */ class AnyError( val typeUrl: String, diff --git a/library/src/main/kotlin/com/connectrpc/Code.kt b/library/src/main/kotlin/com/connectrpc/Code.kt index 3b540ec6..11bcafcd 100644 --- a/library/src/main/kotlin/com/connectrpc/Code.kt +++ b/library/src/main/kotlin/com/connectrpc/Code.kt @@ -36,7 +36,7 @@ enum class Code(val codeName: String, val value: Int) { ABORTED("aborted", 10), OUT_OF_RANGE("out_of_range", 11), UNIMPLEMENTED("unimplemented", 12), - INTERNAL_ERROR("internal", 13), + INTERNAL_ERROR("internal", 13), // TODO: rename enum value to INTERNAL UNAVAILABLE("unavailable", 14), DATA_LOSS("data_loss", 15), UNAUTHENTICATED("unauthenticated", 16), diff --git a/library/src/main/kotlin/com/connectrpc/Codec.kt b/library/src/main/kotlin/com/connectrpc/Codec.kt index 659f0ecd..2c5cab70 100644 --- a/library/src/main/kotlin/com/connectrpc/Codec.kt +++ b/library/src/main/kotlin/com/connectrpc/Codec.kt @@ -32,6 +32,7 @@ const val codecNameJSON = CODEC_NAME_JSON * Defines a type that is capable of encoding and decoding messages using a specific format. */ interface Codec { + // TODO: remove this method or unify somehow with SerializationStrategy.serializationName? /** * @return The name of the codec's format (e.g., "json", "proto"). Usually consumed * in the form of adding the `content-type` header via "application/{name}". diff --git a/library/src/main/kotlin/com/connectrpc/ErrorDetailParser.kt b/library/src/main/kotlin/com/connectrpc/ErrorDetailParser.kt index ea238848..7be806ea 100644 --- a/library/src/main/kotlin/com/connectrpc/ErrorDetailParser.kt +++ b/library/src/main/kotlin/com/connectrpc/ErrorDetailParser.kt @@ -24,12 +24,14 @@ import kotlin.reflect.KClass */ interface ErrorDetailParser { /** - * Unpack the underlying payload into the input class type. + * Unpack the given Any payload into the input class type. */ fun unpack(any: AnyError, clazz: KClass): E? /** - * Parse payload for a list of error details. + * Parse the given bytes for a list of error details. The given + * bytes will be the serialized form of a google.rpc.Status + * Protobuf message. */ fun parseDetails(bytes: ByteArray): List } diff --git a/library/src/main/kotlin/com/connectrpc/SerializationStrategy.kt b/library/src/main/kotlin/com/connectrpc/SerializationStrategy.kt index 83620454..8bf4027c 100644 --- a/library/src/main/kotlin/com/connectrpc/SerializationStrategy.kt +++ b/library/src/main/kotlin/com/connectrpc/SerializationStrategy.kt @@ -24,7 +24,7 @@ import kotlin.reflect.KClass interface SerializationStrategy { /** - * The name of the serialization. Used in the content-encoding + * The name of the serialization. Used in the content-type * header. */ fun serializationName(): String diff --git a/library/src/main/kotlin/com/connectrpc/StreamResult.kt b/library/src/main/kotlin/com/connectrpc/StreamResult.kt index 7252c239..a30b418b 100644 --- a/library/src/main/kotlin/com/connectrpc/StreamResult.kt +++ b/library/src/main/kotlin/com/connectrpc/StreamResult.kt @@ -22,6 +22,9 @@ package com.connectrpc sealed class StreamResult { // Headers have been received over the stream. class Headers(val headers: com.connectrpc.Headers) : StreamResult() { + // TODO: This should include an HTTP status code, too. Computing an RPC code + // from the HTTP status code should be part of the protocol impl, not + // pushed down to the HTTPClientInterface impl. override fun toString(): String { return "Headers{headers=$headers}" } diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectConstants.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectConstants.kt index 87fbb90c..4b0da2e4 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectConstants.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectConstants.kt @@ -16,6 +16,7 @@ package com.connectrpc.protocols const val ACCEPT_ENCODING = "accept-encoding" const val CONTENT_ENCODING = "content-encoding" +const val CONTENT_TYPE = "content-type" const val CONNECT_STREAMING_CONTENT_ENCODING = "connect-content-encoding" const val CONNECT_STREAMING_ACCEPT_ENCODING = "connect-accept-encoding" const val CONNECT_PROTOCOL_VERSION_KEY = "connect-protocol-version" diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index 8f6af925..5a2dde7a 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -94,37 +94,53 @@ internal class ConnectInterceptor( } val trailers = mutableMapOf>() trailers.putAll(response.headers.toTrailers()) - trailers.putAll(response.trailers) - val responseHeaders = + val headers = response.headers.filter { entry -> !entry.key.startsWith("trailer-") } - val compressionPool = clientConfig.compressionPool(responseHeaders[CONTENT_ENCODING]?.first()) + val compressionPool = clientConfig.compressionPool(headers[CONTENT_ENCODING]?.first()) val responseBody = try { compressionPool?.decompress(response.message.buffer) ?: response.message.buffer } catch (e: Exception) { return@UnaryFunction response.clone( message = Buffer(), - headers = responseHeaders, + headers = headers, trailers = trailers, cause = ConnectException( code = Code.INTERNAL_ERROR, message = e.message, exception = e, + metadata = headers.plus(trailers), ), ) } + val contentType = headers[CONTENT_TYPE]?.first() ?: "" val exception: ConnectException? val message: Buffer if (response.status != 200) { - exception = parseConnectUnaryException(response.status, responseHeaders.plus(trailers), responseBody) + exception = parseConnectUnaryException(response.status, contentType, headers.plus(trailers), responseBody) // We've already read the response body to parse an error - don't read again. message = Buffer() } else { - exception = null message = responseBody + val isValidContentType = + (serializationStrategy.serializationName() == "json" && contentTypeIsJSON(contentType)) || + contentType == "application/" + serializationStrategy.serializationName() + if (isValidContentType) { + exception = null + } else { + // If content-type looks like it could be an RPC server's response, consider + // this an internal error. Otherwise, we infer a code from the HTTP status, + // which means a code of UNKNOWN since HTTP status is 200. + val code = if (contentType.startsWith("application/")) Code.INTERNAL_ERROR else Code.UNKNOWN + exception = ConnectException( + code = code, + message = "unexpected content-type: $contentType", + metadata = headers.plus(trailers), + ) + } } response.clone( message = message, - headers = responseHeaders, + headers = headers, trailers = trailers, cause = exception, ) @@ -161,9 +177,25 @@ internal class ConnectInterceptor( val streamResult: StreamResult = res.fold( onHeaders = { result -> responseHeaders = result.headers - responseCompressionPool = - clientConfig.compressionPool(responseHeaders[CONNECT_STREAMING_CONTENT_ENCODING]?.first()) - StreamResult.Headers(responseHeaders) + val contentType = responseHeaders[CONTENT_TYPE]?.first() ?: "" + val isValidContentType = contentType == "application/connect+" + serializationStrategy.serializationName() + if (!isValidContentType) { + // If content-type looks like it could be an RPC server's response, consider + // this an internal error. Otherwise, we infer a code from the HTTP status, + // which means a code of UNKNOWN since HTTP status is 200. + val code = if (contentType.startsWith("application/connect+")) Code.INTERNAL_ERROR else Code.UNKNOWN + StreamResult.Complete( + ConnectException( + code = code, + message = "unexpected content-type: $contentType", + metadata = responseHeaders, + ), + ) + } else { + responseCompressionPool = + clientConfig.compressionPool(responseHeaders[CONNECT_STREAMING_CONTENT_ENCODING]?.first()) + StreamResult.Headers(responseHeaders) + } }, onMessage = { result -> val (headerByte, unpackedMessage) = Envelope.unpackWithHeaderByte( @@ -196,7 +228,7 @@ internal class ConnectInterceptor( ): UnaryHTTPRequest { val serializationStrategy = clientConfig.serializationStrategy val requestCodec = serializationStrategy.codec(request.methodSpec.requestClass) - val url = getUrlFromMethodSpec( + val url = constructURLForGETRequest( request, requestCodec, finalRequestBody, @@ -204,7 +236,7 @@ internal class ConnectInterceptor( ) return request.clone( url = url, - contentType = "application/${requestCodec.encodingName()}", + contentType = "", headers = request.headers, methodSpec = request.methodSpec, httpMethod = HTTPMethod.GET, @@ -244,9 +276,9 @@ internal class ConnectInterceptor( } } - private fun parseConnectUnaryException(httpStatus: Int?, metadata: Headers, source: Buffer?): ConnectException { + private fun parseConnectUnaryException(httpStatus: Int?, contentType: String, metadata: Headers, source: Buffer?): ConnectException { val code = Code.fromHTTPStatus(httpStatus) - if (source == null) { + if (source == null || !contentTypeIsJSON(contentType)) { return ConnectException(code, "unexpected status code: $httpStatus") } return source.use { bufferedSource -> @@ -298,7 +330,7 @@ private fun Headers.toTrailers(): Trailers { return trailers } -private fun getUrlFromMethodSpec( +private fun constructURLForGETRequest( httpRequest: HTTPRequest, codec: Codec<*>, payload: Buffer, diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ErrorJSONModels.kt b/library/src/main/kotlin/com/connectrpc/protocols/ErrorJSONModels.kt index c90293b7..35a8f3c1 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ErrorJSONModels.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ErrorJSONModels.kt @@ -36,3 +36,11 @@ internal class EndStreamResponseJSON( @Json(name = "error") val error: ErrorPayloadJSON?, @Json(name = "metadata") val metadata: Headers?, ) + +internal fun contentTypeIsJSON(contentType: String): Boolean { + // TODO: This could be more robust, like actually parsing the content-type. + // There exists a good helper for that, but it's in okhttp, which we intentionally + // don't have as a dep for this module, which aims to be agnostic of the actual + // HTTP client implementation to use. + return contentType == "application/json" || contentType == "application/json; charset=utf-8" +} diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCCompletionParser.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCCompletionParser.kt index b86e09d8..f156c818 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCCompletionParser.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCCompletionParser.kt @@ -33,26 +33,24 @@ internal class GRPCCompletionParser( * * Returns an "absent" completion if unable to be parsed. */ - internal fun parse(headers: Headers, trailers: Trailers): GRPCCompletion { + internal fun parse(headers: Headers, hasBody: Boolean, trailers: Trailers): GRPCCompletion { val statusCode: Int val statusMetadata: Map> - val statusFromHeaders = parseStatus(headers) val trailersOnly: Boolean - if (statusFromHeaders == null) { - statusCode = parseStatus(trailers) - ?: return GRPCCompletion( - present = false, - code = Code.INTERNAL_ERROR, - message = "protocol error: status is missing from trailers", - metadata = trailers, - ) - statusMetadata = trailers - trailersOnly = false - } else { - statusCode = statusFromHeaders + if (!hasBody && trailers.isEmpty()) { statusMetadata = headers trailersOnly = true + } else { + statusMetadata = trailers + trailersOnly = false } + statusCode = parseStatus(statusMetadata) + ?: return GRPCCompletion( + present = false, + code = Code.UNKNOWN, + message = "protocol error: status is missing from trailers", + metadata = statusMetadata, + ) // Note: we report combined headers and trailers as exception meta, so // caller doesn't have to check both, which is particularly important // since server could actually serialize them together in a single bucket diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt index 7d69f0b3..d66ff6dd 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt @@ -37,6 +37,7 @@ internal class GRPCInterceptor( private val completionParser = GRPCCompletionParser(serializationStrategy.errorDetailParser()) private var responseCompressionPool: CompressionPool? = null private var responseHeaders: Headers = emptyMap() + private var streamEmpty: Boolean = true override fun unaryFunction(): UnaryFunction { return UnaryFunction( @@ -66,30 +67,47 @@ internal class GRPCInterceptor( if (response.cause != null) { return@UnaryFunction response.clone(message = Buffer()) } + val headers = response.headers if (response.status != 200) { return@UnaryFunction response.clone( message = Buffer(), cause = ConnectException( code = Code.fromHTTPStatus(response.status), message = "unexpected status code: ${response.status}", + metadata = headers, + ), + ) + } + val contentType = headers[CONTENT_TYPE]?.first() ?: "" + if (!contentTypeIsExpectedGRPC(contentType, serializationStrategy.serializationName())) { + // If content-type looks like it could be a gRPC server's response, consider + // this an internal error. Otherwise, we infer a code from the HTTP status, + // which means a code of UNKNOWN since HTTP status is 200. + val code = if (contentTypeIsGRPC(contentType)) Code.INTERNAL_ERROR else Code.UNKNOWN + return@UnaryFunction response.clone( + message = Buffer(), + cause = ConnectException( + code = code, + message = "unexpected content-type: $contentType", + metadata = headers, ), ) } - val headers = response.headers var trailers = response.trailers - val completion = completionParser - .parse(headers, trailers) + val hasBody = !response.message.buffer.exhausted() + val completion = completionParser.parse(headers, hasBody, trailers) if (completion.trailersOnly) { trailers = headers // report the headers also as trailers } val exception = completion.toConnectExceptionOrNull(serializationStrategy) val message = if (exception == null) { - if (response.message.buffer.exhausted()) { + if (!hasBody) { return@UnaryFunction response.clone( message = Buffer(), cause = ConnectException( code = Code.UNIMPLEMENTED, message = "unary stream has no messages", + metadata = headers.plus(trailers), ), ) } @@ -105,6 +123,7 @@ internal class GRPCInterceptor( cause = ConnectException( code = Code.UNIMPLEMENTED, message = "unary stream has multiple messages", + metadata = headers.plus(trailers), ), ) } @@ -137,21 +156,28 @@ internal class GRPCInterceptor( streamResultFunction = { res -> res.fold( onHeaders = { result -> - val headers = result.headers - val completion = completionParser.parse(headers, emptyMap()) - if (completion.present) { + responseHeaders = result.headers + val contentType = responseHeaders[CONTENT_TYPE]?.first() ?: "" + if (!contentTypeIsExpectedGRPC(contentType, serializationStrategy.serializationName())) { + // If content-type looks like it could be a gRPC server's response, consider + // this an internal error. Otherwise, we infer a code from the HTTP status, + // which means a code of UNKNOWN since HTTP status is 200. + val code = if (contentTypeIsGRPC(contentType)) Code.INTERNAL_ERROR else Code.UNKNOWN StreamResult.Complete( - cause = completion.toConnectExceptionOrNull(serializationStrategy), - trailers = headers, + cause = ConnectException( + code = code, + message = "unexpected content-type: $contentType", + metadata = responseHeaders, + ), ) } else { - responseHeaders = headers responseCompressionPool = clientConfig - .compressionPool(headers[GRPC_ENCODING]?.first()) - StreamResult.Headers(headers) + .compressionPool(responseHeaders[GRPC_ENCODING]?.first()) + StreamResult.Headers(responseHeaders) } }, onMessage = { result -> + streamEmpty = false val (_, unpackedMessage) = Envelope.unpackWithHeaderByte( result.message, responseCompressionPool, @@ -164,7 +190,7 @@ internal class GRPCInterceptor( } val trailers = result.trailers val exception = completionParser - .parse(responseHeaders, trailers) + .parse(responseHeaders, !streamEmpty, trailers) .toConnectExceptionOrNull(serializationStrategy) StreamResult.Complete( cause = exception, @@ -190,3 +216,12 @@ internal class GRPCInterceptor( return headers } } + +internal fun contentTypeIsGRPC(contentType: String): Boolean { + return contentType == "application/grpc" || contentType.startsWith("application/grpc+") +} + +internal fun contentTypeIsExpectedGRPC(contentType: String, expectCodec: String): Boolean { + return (expectCodec == "proto" && contentType == "application/grpc") || + contentType == "application/grpc+$expectCodec" +} diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt index 72115d91..69236084 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt @@ -42,6 +42,7 @@ internal class GRPCWebInterceptor( private val completionParser = GRPCCompletionParser(serializationStrategy.errorDetailParser()) private var responseCompressionPool: CompressionPool? = null private var responseHeaders: Headers = emptyMap() + private var streamEmpty: Boolean = true override fun unaryFunction(): UnaryFunction { return UnaryFunction( @@ -72,16 +73,32 @@ internal class GRPCWebInterceptor( if (response.cause != null) { return@UnaryFunction response.clone(message = Buffer()) } + val headers = response.headers if (response.status != 200) { return@UnaryFunction response.clone( message = Buffer(), cause = ConnectException( code = Code.fromHTTPStatus(response.status), message = "unexpected status code: ${response.status}", + metadata = headers, + ), + ) + } + val contentType = headers[CONTENT_TYPE]?.first() ?: "" + if (!contentTypeIsExpectedGRPCWeb(contentType, serializationStrategy.serializationName())) { + // If content-type looks like it could be a gRPC server's response, consider + // this an internal error. Otherwise, we infer a code from the HTTP status, + // which means a code of UNKNOWN since HTTP status is 200. + val code = if (contentTypeIsGRPCWeb(contentType)) Code.INTERNAL_ERROR else Code.UNKNOWN + return@UnaryFunction response.clone( + message = Buffer(), + cause = ConnectException( + code = code, + message = "unexpected content-type: $contentType", + metadata = headers, ), ) } - val headers = response.headers val compressionPool = clientConfig.compressionPool(headers[GRPC_ENCODING]?.first()) // gRPC Web returns data in 2 chunks (either/both of which may be compressed): @@ -93,13 +110,14 @@ internal class GRPCWebInterceptor( if (response.message.exhausted()) { // There was no response body. Read status within the headers. var exception = completionParser - .parse(headers, emptyMap()) + .parse(headers, false, emptyMap()) .toConnectExceptionOrNull(serializationStrategy) if (exception == null) { // No response data and no error code? exception = ConnectException( code = Code.UNIMPLEMENTED, message = "unary stream has no messages", + metadata = headers, ) } response.clone( @@ -123,6 +141,7 @@ internal class GRPCWebInterceptor( cause = ConnectException( code = Code.INTERNAL_ERROR, message = "response did not include an end of stream message", + metadata = headers, ), ) } else { @@ -138,6 +157,7 @@ internal class GRPCWebInterceptor( cause = ConnectException( code = Code.UNIMPLEMENTED, message = "unary stream has multiple messages", + metadata = headers, ), ) } @@ -150,18 +170,20 @@ internal class GRPCWebInterceptor( cause = ConnectException( code = Code.INTERNAL_ERROR, message = "response stream contains data after end-of-stream message", + metadata = headers, ), ) } val finalTrailers = parseGrpcWebTrailer(trailerBuffer) var exception = completionParser - .parse(headers, finalTrailers) + .parse(headers, true, finalTrailers) .toConnectExceptionOrNull(serializationStrategy) if (exception == null && currentMessage == null) { // No response message, and trailers indicated no error? exception = ConnectException( code = Code.UNIMPLEMENTED, message = "unary stream has multiple messages", + metadata = headers, ) } response.clone( @@ -190,21 +212,28 @@ internal class GRPCWebInterceptor( streamResultFunction = { res -> val streamResult = res.fold( onHeaders = { result -> - val headers = result.headers - val completion = completionParser.parse(headers, emptyMap()) - if (completion.present) { + responseHeaders = result.headers + val contentType = responseHeaders[CONTENT_TYPE]?.first() ?: "" + if (!contentTypeIsExpectedGRPCWeb(contentType, serializationStrategy.serializationName())) { + // If content-type looks like it could be a gRPC server's response, consider + // this an internal error. Otherwise, we infer a code from the HTTP status, + // which means a code of UNKNOWN since HTTP status is 200. + val code = if (contentTypeIsGRPCWeb(contentType)) Code.INTERNAL_ERROR else Code.UNKNOWN StreamResult.Complete( - cause = completion.toConnectExceptionOrNull(serializationStrategy), - trailers = headers, + cause = ConnectException( + code = code, + message = "unexpected content-type: $contentType", + metadata = responseHeaders, + ), ) } else { - responseHeaders = headers responseCompressionPool = clientConfig - .compressionPool(headers[GRPC_ENCODING]?.first()) - StreamResult.Headers(headers) + .compressionPool(responseHeaders[GRPC_ENCODING]?.first()) + StreamResult.Headers(responseHeaders) } }, onMessage = { result -> + streamEmpty = false val (headerByte, unpackedMessage) = Envelope.unpackWithHeaderByte( result.message, responseCompressionPool, @@ -212,7 +241,7 @@ internal class GRPCWebInterceptor( if (headerByte.and(TRAILERS_BIT) == TRAILERS_BIT) { val streamTrailers = parseGrpcWebTrailer(unpackedMessage) val exception = completionParser - .parse(responseHeaders, streamTrailers) + .parse(responseHeaders, !streamEmpty, streamTrailers) .toConnectExceptionOrNull(serializationStrategy) StreamResult.Complete( cause = exception, @@ -265,3 +294,12 @@ internal class GRPCWebInterceptor( return trailers } } + +internal fun contentTypeIsGRPCWeb(contentType: String): Boolean { + return contentType == "application/grpc-web" || contentType.startsWith("application/grpc-web+") +} + +internal fun contentTypeIsExpectedGRPCWeb(contentType: String, expectCodec: String): Boolean { + return (expectCodec == "proto" && contentType == "application/grpc-web") || + contentType == "application/grpc-web+$expectCodec" +} diff --git a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt index 5e42fe9a..9cd88988 100644 --- a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt +++ b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt @@ -18,6 +18,7 @@ import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.UnaryHTTPRequest import com.connectrpc.http.clone +import com.connectrpc.protocols.CONTENT_TYPE import com.connectrpc.protocols.Envelope import com.connectrpc.protocols.NetworkProtocol import okio.Buffer @@ -25,6 +26,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever import java.net.URL private val UNARY_METHOD_SPEC = MethodSpec( @@ -69,6 +71,7 @@ class InterceptorChainTest { ) unaryChain = protocolClientConfig.createInterceptorChain() streamingChain = protocolClientConfig.createStreamingInterceptorChain() + whenever(protocolClientConfig.serializationStrategy.serializationName()).thenReturn("encoding_type") } @Test @@ -98,7 +101,11 @@ class InterceptorChainTest { @Test fun lifo_stream_result() { - val streamResult = streamingChain.streamResultFunction(StreamResult.Headers(emptyMap())) as StreamResult.Headers + val streamResult = streamingChain.streamResultFunction( + StreamResult.Headers( + mapOf(CONTENT_TYPE to listOf("application/connect+encoding_type")), + ), + ) as StreamResult.Headers assertThat(streamResult.headers["id"]).containsExactly("4", "3", "2", "1") } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt index c291925b..597d542b 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt @@ -292,7 +292,7 @@ class ConnectInterceptorTest { HTTPResponse( // body contents override status code status = 503, - headers = emptyMap(), + headers = mapOf(CONTENT_TYPE to listOf("application/json; charset=utf-8")), message = Buffer().write(json.encodeUtf8()), trailers = emptyMap(), ), @@ -483,6 +483,7 @@ class ConnectInterceptorTest { StreamResult.Headers( headers = mapOf( "trailer-x-some-key" to listOf("some_value"), + CONTENT_TYPE to listOf("application/connect+encoding_type"), CONNECT_STREAMING_CONTENT_ENCODING to listOf("gzip"), ), ), @@ -493,6 +494,7 @@ class ConnectInterceptorTest { assertThat(headerResult.headers).isEqualTo( mapOf( "trailer-x-some-key" to listOf("some_value"), + CONTENT_TYPE to listOf("application/connect+encoding_type"), CONNECT_STREAMING_CONTENT_ENCODING to listOf("gzip"), ), ) @@ -536,6 +538,7 @@ class ConnectInterceptorTest { streamFunction.streamResultFunction( StreamResult.Headers( headers = mapOf( + CONTENT_TYPE to listOf("application/connect+encoding_type"), CONNECT_STREAMING_CONTENT_ENCODING to listOf("gzip"), ), ), diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCErrorDetailParserTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCErrorDetailParserTest.kt index 3b070155..ad6a81f5 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCErrorDetailParserTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCErrorDetailParserTest.kt @@ -22,6 +22,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.Test import org.mockito.kotlin.mock import org.mockito.kotlin.verify +import org.mockito.kotlin.verifyNoInteractions class GRPCErrorDetailParserTest { @@ -32,6 +33,7 @@ class GRPCErrorDetailParserTest { val parser = GRPCCompletionParser(errorDetailParser) val completion = parser.parse( headers = emptyMap(), + hasBody = false, trailers = mapOf( GRPC_STATUS_TRAILER to listOf("${Code.UNAUTHENTICATED.value}"), GRPC_MESSAGE_TRAILER to listOf("str"), @@ -49,26 +51,70 @@ class GRPCErrorDetailParserTest { val parser = GRPCCompletionParser(errorDetailParser) val completion = parser.parse( headers = emptyMap(), + hasBody = false, trailers = mapOf( GRPC_MESSAGE_TRAILER to listOf("str"), GRPC_STATUS_DETAILS_TRAILERS to listOf("data".encodeUtf8().base64()), ), ) assertThat(completion.present).isFalse() + assertThat(completion.code).isEqualTo(Code.UNKNOWN) + assertThat(completion.errorDetails).isEmpty() + verifyNoInteractions(errorDetailParser) } @Test - fun trailersWithoutMessage() { + fun trailersOnly() { val parser = GRPCCompletionParser(errorDetailParser) val completion = parser.parse( headers = mapOf( GRPC_STATUS_TRAILER to listOf("${Code.UNAUTHENTICATED.value}"), + GRPC_MESSAGE_TRAILER to listOf("str"), GRPC_STATUS_DETAILS_TRAILERS to listOf("data".encodeUtf8().base64()), ), + hasBody = false, trailers = emptyMap(), ) assertThat(completion.present).isTrue() assertThat(completion.code).isEqualTo(Code.UNAUTHENTICATED) + assertThat(completion.message).isEqualTo("str") + verify(errorDetailParser).parseDetails("data".commonAsUtf8ToByteArray()) + } + + @Test + fun trailersWithoutStatusIncorrectStatusInHeaders() { + val parser = GRPCCompletionParser(errorDetailParser) + val completion = parser.parse( + headers = mapOf( + GRPC_STATUS_TRAILER to listOf("${Code.UNAUTHENTICATED.value}"), + GRPC_MESSAGE_TRAILER to listOf("str"), + GRPC_STATUS_DETAILS_TRAILERS to listOf("data".encodeUtf8().base64()), + ), + // since there is a body, we don't look for a status in the headers + // because it a trailers-only response has no body and no trailers + hasBody = true, + trailers = emptyMap(), + ) + assertThat(completion.present).isFalse() + assertThat(completion.code).isEqualTo(Code.UNKNOWN) + assertThat(completion.errorDetails).isEmpty() + verifyNoInteractions(errorDetailParser) + } + + @Test + fun trailersWithoutMessage() { + val parser = GRPCCompletionParser(errorDetailParser) + val completion = parser.parse( + headers = emptyMap(), + hasBody = false, + trailers = mapOf( + GRPC_STATUS_TRAILER to listOf("${Code.UNAUTHENTICATED.value}"), + GRPC_STATUS_DETAILS_TRAILERS to listOf("data".encodeUtf8().base64()), + ), + ) + assertThat(completion.present).isTrue() + assertThat(completion.code).isEqualTo(Code.UNAUTHENTICATED) + verify(errorDetailParser).parseDetails("data".commonAsUtf8ToByteArray()) } @Test @@ -76,12 +122,15 @@ class GRPCErrorDetailParserTest { val parser = GRPCCompletionParser(errorDetailParser) val completion = parser.parse( headers = emptyMap(), + hasBody = false, trailers = mapOf( GRPC_STATUS_TRAILER to listOf("${Code.UNAUTHENTICATED.value}"), GRPC_MESSAGE_TRAILER to listOf("str"), ), ) assertThat(completion.present).isTrue() + assertThat(completion.message).isEqualTo("str") assertThat(completion.errorDetails).isEmpty() + verifyNoInteractions(errorDetailParser) } } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt index 98890a04..e820e5a2 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt @@ -185,7 +185,7 @@ class GRPCInterceptorTest { val response = unaryFunction.responseFunction( HTTPResponse( status = 200, - headers = emptyMap(), + headers = mapOf(CONTENT_TYPE to listOf("application/grpc+encoding_type")), message = envelopedMessage, trailers = mapOf( GRPC_STATUS_TRAILER to listOf("0"), @@ -209,7 +209,10 @@ class GRPCInterceptorTest { val response = unaryFunction.responseFunction( HTTPResponse( status = 200, - headers = mapOf(GRPC_ENCODING to listOf(GzipCompressionPool.name())), + headers = mapOf( + CONTENT_TYPE to listOf("application/grpc+encoding_type"), + GRPC_ENCODING to listOf(GzipCompressionPool.name()), + ), message = envelopedMessage, trailers = mapOf( GRPC_STATUS_TRAILER to listOf("0"), @@ -253,7 +256,7 @@ class GRPCInterceptorTest { val response = unaryFunction.responseFunction( HTTPResponse( status = 200, - headers = emptyMap(), + headers = mapOf(CONTENT_TYPE to listOf("application/grpc+encoding_type")), message = Buffer().write(json.encodeUtf8()), trailers = mapOf( GRPC_STATUS_TRAILER to listOf("${Code.RESOURCE_EXHAUSTED.value}"), @@ -423,7 +426,8 @@ class GRPCInterceptorTest { val result = streamFunction.streamResultFunction( StreamResult.Headers( - mapOf( + headers = mapOf( + CONTENT_TYPE to listOf("application/grpc+encoding_type"), "key" to listOf("value"), ), ), @@ -471,6 +475,7 @@ class GRPCInterceptorTest { streamFunction.streamResultFunction( StreamResult.Headers( headers = mapOf( + CONTENT_TYPE to listOf("application/grpc+encoding_type"), GRPC_ENCODING to listOf(GzipCompressionPool.name()), ), ), diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt index ecb8df0b..60a29ddc 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt @@ -183,7 +183,10 @@ class GRPCWebInterceptorTest { val response = unaryFunction.responseFunction( HTTPResponse( status = 200, - headers = mapOf(GRPC_ENCODING to listOf("gzip")), + headers = mapOf( + CONTENT_TYPE to listOf("application/grpc-web+encoding_type"), + GRPC_ENCODING to listOf("gzip"), + ), message = responseBody, trailers = emptyMap(), ), @@ -210,7 +213,10 @@ class GRPCWebInterceptorTest { val response = unaryFunction.responseFunction( HTTPResponse( status = 200, - headers = mapOf(GRPC_ENCODING to listOf(GzipCompressionPool.name())), + headers = mapOf( + CONTENT_TYPE to listOf("application/grpc-web+encoding_type"), + GRPC_ENCODING to listOf(GzipCompressionPool.name()), + ), message = responseBody, trailers = emptyMap(), ), @@ -232,6 +238,7 @@ class GRPCWebInterceptorTest { HTTPResponse( status = 200, headers = mapOf( + CONTENT_TYPE to listOf("application/grpc-web+encoding_type"), GRPC_STATUS_TRAILER to listOf("${Code.RESOURCE_EXHAUSTED.value}"), ), message = Buffer(), @@ -259,7 +266,7 @@ class GRPCWebInterceptorTest { val response = unaryFunction.responseFunction( HTTPResponse( status = 200, - headers = emptyMap(), + headers = mapOf(CONTENT_TYPE to listOf("application/grpc-web+encoding_type")), message = trailers, trailers = emptyMap(), ), @@ -290,7 +297,7 @@ class GRPCWebInterceptorTest { val response = unaryFunction.responseFunction( HTTPResponse( status = 200, - headers = emptyMap(), + headers = mapOf(CONTENT_TYPE to listOf("application/grpc-web+encoding_type")), message = responseBody, trailers = emptyMap(), ), @@ -452,6 +459,7 @@ class GRPCWebInterceptorTest { headers = mapOf( // Doesn't get passed as headers. "trailer-x-some-key" to listOf("some_value"), + CONTENT_TYPE to listOf("application/grpc-web+encoding_type"), GRPC_ENCODING to listOf("gzip"), ), ), @@ -498,6 +506,7 @@ class GRPCWebInterceptorTest { streamFunction.streamResultFunction( StreamResult.Headers( headers = mapOf( + CONTENT_TYPE to listOf("application/grpc-web+encoding_type"), GRPC_ENCODING to listOf("gzip"), ), ), diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt index bf6ab1ad..2010ce86 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt @@ -112,6 +112,7 @@ private class ResponseCallback( cause = ConnectException( code = Code.fromHTTPStatus(httpStatus), message = "unexpected HTTP status: $httpStatus ${response.originalMessage()}", + metadata = headers, ), ) onResult(finalResult)