From 82f8988a8d75c9b34cef85c79b2f10c769a437f1 Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Mon, 1 Apr 2024 18:40:23 -0400 Subject: [PATCH] some fixes noticed when testing against 'main' of conformance --- .../protocols/ConnectInterceptor.kt | 11 +++--- .../connectrpc/okhttp/ConnectOkHttpClient.kt | 29 +++++++++++---- .../com/connectrpc/okhttp/OkHttpStream.kt | 35 ++++--------------- 3 files changed, 35 insertions(+), 40 deletions(-) diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index 9af725fb..596aff67 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -53,6 +53,7 @@ internal class ConnectInterceptor( private val moshi = Moshi.Builder().build() private val serializationStrategy = clientConfig.serializationStrategy private var responseCompressionPool: CompressionPool? = null + private var responseHeaders: Headers = emptyMap() override fun unaryFunction(): UnaryFunction { return UnaryFunction( @@ -159,8 +160,7 @@ internal class ConnectInterceptor( streamResultFunction = { res -> val streamResult: StreamResult = res.fold( onHeaders = { result -> - val responseHeaders = - result.headers.filter { entry -> !entry.key.startsWith("trailer-") } + responseHeaders = result.headers responseCompressionPool = clientConfig.compressionPool(responseHeaders[CONNECT_STREAMING_CONTENT_ENCODING]?.first()) StreamResult.Headers(responseHeaders) @@ -171,7 +171,7 @@ internal class ConnectInterceptor( responseCompressionPool, ) if (headerByte.and(TRAILERS_BIT) == TRAILERS_BIT) { - parseConnectEndStream(unpackedMessage) + parseConnectEndStream(responseHeaders, unpackedMessage) } else { StreamResult.Message(unpackedMessage) } @@ -211,7 +211,7 @@ internal class ConnectInterceptor( ) } - private fun parseConnectEndStream(source: Buffer): StreamResult.Complete { + private fun parseConnectEndStream(headers: Headers, source: Buffer): StreamResult.Complete { val adapter = moshi.adapter(EndStreamResponseJSON::class.java).nonNull() return source.use { bufferedSource -> val errorJSON = bufferedSource.readUtf8() @@ -234,11 +234,12 @@ internal class ConnectInterceptor( cause = ConnectException( code = code, message = endStreamResponseJSON.error.message, - metadata = metadata.orEmpty(), + metadata = headers.plus(metadata.orEmpty()), ).withErrorDetails( serializationStrategy.errorDetailParser(), parseErrorDetails(endStreamResponseJSON.error), ), + trailers = metadata.orEmpty(), ) } } diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt index abbf45e1..2e66f3ec 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt @@ -17,6 +17,7 @@ package com.connectrpc.okhttp import com.connectrpc.Code import com.connectrpc.ConnectException import com.connectrpc.StreamResult +import com.connectrpc.asConnectException import com.connectrpc.http.Cancelable import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest @@ -147,17 +148,24 @@ class ConnectOkHttpClient @JvmOverloads constructor( override fun onResponse(call: Call, response: Response) { // Unary requests will need to read the entire body to access trailers. - val responseBuffer = response.body?.source()?.use { bufferedSource -> - val buffer = Buffer() - buffer.writeAll(bufferedSource) - buffer + var responseBuffer: Buffer? = null + var connEx: ConnectException? = null + try { + responseBuffer = response.body?.source()?.use { bufferedSource -> + val buffer = Buffer() + buffer.writeAll(bufferedSource) + buffer + } + } catch (ex: Throwable) { + connEx = asConnectException(ex, codeFromException(call.isCanceled(), ex)) } onResult( HTTPResponse( status = response.originalCode(), headers = response.headers.toLowerCaseKeysMultiMap(), message = responseBuffer ?: Buffer(), - trailers = response.trailers().toLowerCaseKeysMultiMap(), + trailers = response.safeTrailers(), + cause = connEx, ), ) } @@ -197,7 +205,7 @@ internal fun Headers.toLowerCaseKeysMultiMap(): Map> { ) } -internal fun codeFromException(callCanceled: Boolean, e: Exception): Code { +internal fun codeFromException(callCanceled: Boolean, e: Throwable): Code { return if ((e is InterruptedIOException && e.message == "timeout") || e is SocketTimeoutException ) { @@ -232,3 +240,12 @@ fun Response.originalMessage(): String { message } } + +internal fun Response.safeTrailers(): Map> { + return try { + trailers().toLowerCaseKeysMultiMap() + } catch (_: Throwable) { + // Trailers not available or something else went wrong... + emptyMap() + } +} diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt index 67299311..bf6ab1ad 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt @@ -108,7 +108,7 @@ private class ResponseCallback( if (httpStatus != 200) { // TODO: This is not quite exercised yet. Validate if this is exercised in another test case. val finalResult = StreamResult.Complete( - trailers = response.safeTrailers() ?: emptyMap(), + trailers = response.safeTrailers(), cause = ConnectException( code = Code.fromHTTPStatus(httpStatus), message = "unexpected HTTP status: $httpStatus ${response.originalMessage()}", @@ -119,7 +119,7 @@ private class ResponseCallback( } response.use { resp -> resp.body!!.source().use { sourceBuffer -> - var exception: Exception? = null + var connEx: ConnectException? = null try { while (!sourceBuffer.exhausted()) { val buffer = readStreamElement(sourceBuffer) @@ -128,18 +128,14 @@ private class ResponseCallback( ) onResult(streamResult) } - } catch (e: Exception) { - exception = e + } catch (ex: Exception) { + connEx = asConnectException(ex, codeFromException(call.isCanceled(), ex)) } finally { // If trailers are not yet communicated. // This is the final chance to notify trailers to the consumer. - val connectEx = when (exception) { - null -> null - else -> asConnectException(exception, codeFromException(call.isCanceled(), exception)) - } val finalResult = StreamResult.Complete( - trailers = response.safeTrailers() ?: emptyMap(), - cause = connectEx, + trailers = response.safeTrailers(), + cause = connEx, ) onResult(finalResult) } @@ -148,25 +144,6 @@ private class ResponseCallback( } } - private fun Response.safeTrailers(): Map>? { - try { - if (body?.source()?.exhausted() == false) { - // Assuming this means that trailers are not available. - // Returning null to signal trailers are "missing". - return null - } - } catch (e: Exception) { - return null - } - - return try { - trailers().toLowerCaseKeysMultiMap() - } catch (_: Throwable) { - // Something went terribly wrong. - emptyMap() - } - } - /** * Helps with reading and framing OkHttp responses into Buffers. *