From 77ae245c0adc526684db3d79ed4cfddcb2ab99c9 Mon Sep 17 00:00:00 2001 From: Josh Humphries <2035234+jhump@users.noreply.github.com> Date: Mon, 29 Jan 2024 20:39:29 -0500 Subject: [PATCH] cleanly separate HTTPRequest (base type, applies to streams and unary) and UnaryHTTPRequest (only for unary RPCs, includes request message contents) --- .../main/kotlin/com/connectrpc/Interceptor.kt | 3 +- .../connectrpc/http/HTTPClientInterface.kt | 2 +- .../kotlin/com/connectrpc/http/HTTPRequest.kt | 108 ++++++++++++------ .../com/connectrpc/impl/ProtocolClient.kt | 5 +- .../protocols/ConnectInterceptor.kt | 19 ++- .../connectrpc/protocols/GRPCInterceptor.kt | 12 +- .../protocols/GRPCWebInterceptor.kt | 12 +- .../com/connectrpc/InterceptorChainTest.kt | 22 +--- .../com/connectrpc/impl/ProtocolClientTest.kt | 15 +-- .../protocols/ConnectInterceptorTest.kt | 43 +++---- .../protocols/GRPCInterceptorTest.kt | 20 ++-- .../protocols/GRPCWebInterceptorTest.kt | 20 ++-- .../connectrpc/okhttp/ConnectOkHttpClient.kt | 25 +++- 13 files changed, 173 insertions(+), 133 deletions(-) diff --git a/library/src/main/kotlin/com/connectrpc/Interceptor.kt b/library/src/main/kotlin/com/connectrpc/Interceptor.kt index f9bbf441..d47bddac 100644 --- a/library/src/main/kotlin/com/connectrpc/Interceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/Interceptor.kt @@ -16,6 +16,7 @@ package com.connectrpc import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.UnaryHTTPRequest import okio.Buffer /** @@ -52,7 +53,7 @@ interface Interceptor { } class UnaryFunction( - val requestFunction: (HTTPRequest) -> HTTPRequest = { it }, + val requestFunction: (UnaryHTTPRequest) -> UnaryHTTPRequest = { it }, val responseFunction: (HTTPResponse) -> HTTPResponse = { it }, ) diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt index bbc5a77f..0ae96877 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt @@ -33,7 +33,7 @@ interface HTTPClientInterface { * * @return A function to cancel the underlying network call. */ - fun unary(request: HTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable + fun unary(request: UnaryHTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable /** * Initialize a new HTTP stream. diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt index 60002b3d..f81572c2 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt @@ -16,6 +16,7 @@ package com.connectrpc.http import com.connectrpc.Headers import com.connectrpc.MethodSpec +import okio.Buffer import java.net.URL internal object HTTPMethod { @@ -24,50 +25,89 @@ internal object HTTPMethod { } /** - * HTTP request used for sending primitive data to the server. + * HTTP request used to initiate RPCs. */ -class HTTPRequest internal constructor( +open class HTTPRequest internal constructor( // The URL for the request. val url: URL, // Value to assign to the `content-type` header. val contentType: String, // Additional outbound headers for the request. val headers: Headers, - // Body data to send with the request. - val message: ByteArray? = null, // The method spec associated with the request. val methodSpec: MethodSpec<*, *>, // HTTP method to use with the request. // Almost always POST, but side effect free unary RPCs may be made with GET. val httpMethod: String = HTTPMethod.POST, -) { - /** - * Clones the [HTTPRequest] with override values. - * - * Intended to make mutations for [HTTPRequest] safe for - * [com.connectrpc.Interceptor] implementation. - */ - fun clone( - // The URL for the request. - url: URL = this.url, - // Value to assign to the `content-type` header. - contentType: String = this.contentType, - // Additional outbound headers for the request. - headers: Headers = this.headers, - // Body data to send with the request. - message: ByteArray? = this.message, - // The method spec associated with the request. - methodSpec: MethodSpec<*, *> = this.methodSpec, - // The HTTP method to use with the request. - httpMethod: String = this.httpMethod, - ): HTTPRequest { - return HTTPRequest( - url, - contentType, - headers, - message, - methodSpec, - httpMethod, - ) - } +) + +/** + * Clones the [HTTPRequest] with override values. + * + * Intended to make mutations for [HTTPRequest] safe for + * [com.connectrpc.Interceptor] implementation. + */ +fun HTTPRequest.clone( + // The URL for the request. + url: URL = this.url, + // Value to assign to the `content-type` header. + contentType: String = this.contentType, + // Additional outbound headers for the request. + headers: Headers = this.headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *> = this.methodSpec, + // The HTTP method to use with the request. + httpMethod: String = this.httpMethod, +): HTTPRequest { + return HTTPRequest( + url, + contentType, + headers, + methodSpec, + httpMethod, + ) +} + +/** + * HTTP request used to initiate unary RPCs. In addition + * to RPC metadata, this also includes the request data. + */ +class UnaryHTTPRequest( + // The URL for the request. + url: URL, + // Value to assign to the `content-type` header. + contentType: String, + // Additional outbound headers for the request. + headers: Headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *>, + // Body data for the request. + val message: Buffer, + // HTTP method to use with the request. + // Almost always POST, but side effect free unary RPCs may be made with GET. + httpMethod: String = HTTPMethod.POST, +) : HTTPRequest(url, contentType, headers, methodSpec, httpMethod) + +fun UnaryHTTPRequest.clone( + // The URL for the request. + url: URL = this.url, + // Value to assign to the `content-type` header. + contentType: String = this.contentType, + // Additional outbound headers for the request. + headers: Headers = this.headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *> = this.methodSpec, + // Body data for the request. + message: Buffer = this.message, + // The HTTP method to use with the request. + httpMethod: String = this.httpMethod, +): UnaryHTTPRequest { + return UnaryHTTPRequest( + url, + contentType, + headers, + methodSpec, + message, + httpMethod, + ) } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 45eb43d5..07807730 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -30,6 +30,7 @@ import com.connectrpc.UnaryBlockingCall import com.connectrpc.http.Cancelable import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest +import com.connectrpc.http.UnaryHTTPRequest import com.connectrpc.http.transform import com.connectrpc.protocols.GETConfiguration import kotlinx.coroutines.CompletableDeferred @@ -79,12 +80,12 @@ class ProtocolClient( } else { requestCodec.serialize(request) } - val unaryRequest = HTTPRequest( + val unaryRequest = UnaryHTTPRequest( url = urlFromMethodSpec(methodSpec), contentType = "application/${requestCodec.encodingName()}", headers = headers, - message = requestMessage.readByteArray(), methodSpec = methodSpec, + message = requestMessage, ) val unaryFunc = config.createInterceptorChain() val finalRequest = unaryFunc.requestFunction(unaryRequest) diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index c1a032bf..b32927e3 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -32,6 +32,8 @@ import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.UnaryHTTPRequest +import com.connectrpc.http.clone import com.connectrpc.toLowercase import com.squareup.moshi.Moshi import okio.Buffer @@ -67,15 +69,11 @@ internal class ConnectInterceptor( requestHeaders[USER_AGENT] = listOf("connect-kotlin/${ConnectConstants.VERSION}") } val requestCompression = clientConfig.requestCompression - val requestMessage = Buffer() - if (request.message != null) { - requestMessage.write(request.message) - } - val finalRequestBody = if (requestCompression?.shouldCompress(requestMessage) == true) { + val finalRequestBody = if (requestCompression?.shouldCompress(request.message) == true) { requestHeaders.put(CONTENT_ENCODING, listOf(requestCompression.compressionPool.name())) - requestCompression.compressionPool.compress(requestMessage) + requestCompression.compressionPool.compress(request.message) } else { - requestMessage + request.message } if (shouldUseGETRequest(request, finalRequestBody)) { constructGETRequest(request, finalRequestBody, requestCompression) @@ -84,8 +82,8 @@ internal class ConnectInterceptor( url = request.url, contentType = request.contentType, headers = requestHeaders, - message = finalRequestBody.readByteArray(), methodSpec = request.methodSpec, + message = finalRequestBody, ) } }, @@ -153,7 +151,6 @@ internal class ConnectInterceptor( url = request.url, contentType = request.contentType, headers = requestHeaders, - message = request.message, methodSpec = request.methodSpec, ) }, @@ -196,10 +193,10 @@ internal class ConnectInterceptor( } private fun constructGETRequest( - request: HTTPRequest, + request: UnaryHTTPRequest, finalRequestBody: Buffer, requestCompression: RequestCompression?, - ): HTTPRequest { + ): UnaryHTTPRequest { val serializationStrategy = clientConfig.serializationStrategy val requestCodec = serializationStrategy.codec(request.methodSpec.requestClass) val url = getUrlFromMethodSpec( diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt index 5ac12fb5..6efc4623 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt @@ -24,6 +24,7 @@ import com.connectrpc.StreamResult import com.connectrpc.UnaryFunction import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.clone import okio.Buffer /** @@ -46,16 +47,10 @@ internal class GRPCInterceptor( requestHeaders[GRPC_ACCEPT_ENCODING] = clientConfig.compressionPools() .map { compressionPool -> compressionPool.name() } } - val requestMessage = Buffer().use { buffer -> - if (request.message != null) { - buffer.write(request.message) - } - buffer - } val requestCompression = clientConfig.requestCompression // GRPC unary payloads are enveloped. val envelopedMessage = Envelope.pack( - requestMessage, + request.message, requestCompression?.compressionPool, requestCompression?.minBytes, ) @@ -64,7 +59,7 @@ internal class GRPCInterceptor( // The underlying content type is overridden here. contentType = "application/grpc+${serializationStrategy.serializationName()}", headers = requestHeaders.withGRPCRequestHeaders(), - message = envelopedMessage.readByteArray(), + message = envelopedMessage, ) }, responseFunction = { response -> @@ -128,7 +123,6 @@ internal class GRPCInterceptor( url = request.url, contentType = "application/grpc+${serializationStrategy.serializationName()}", headers = request.headers.withGRPCRequestHeaders(), - message = request.message, ) }, requestBodyFunction = { buffer -> diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt index 6c1ebb99..72cba93c 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt @@ -25,6 +25,7 @@ import com.connectrpc.Trailers import com.connectrpc.UnaryFunction import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.clone import okio.Buffer internal const val TRAILERS_BIT = 0b10000000 @@ -50,15 +51,9 @@ internal class GRPCWebInterceptor( .map { compressionPool -> compressionPool.name() } } val requestCompressionPool = clientConfig.requestCompression - val requestMessage = Buffer().use { buffer -> - if (request.message != null) { - buffer.write(request.message) - } - buffer - } // GRPC unary payloads are enveloped. val envelopedMessage = Envelope.pack( - requestMessage, + request.message, requestCompressionPool?.compressionPool, requestCompressionPool?.minBytes, ) @@ -68,7 +63,7 @@ internal class GRPCWebInterceptor( // The underlying content type is overridden here. contentType = "application/grpc-web+${serializationStrategy.serializationName()}", headers = requestHeaders.withGRPCRequestHeaders(), - message = envelopedMessage.readByteArray(), + message = envelopedMessage, ) }, responseFunction = { response -> @@ -175,7 +170,6 @@ internal class GRPCWebInterceptor( url = request.url, contentType = "application/grpc-web+${serializationStrategy.serializationName()}", headers = request.headers.withGRPCRequestHeaders(), - message = request.message, ) }, requestBodyFunction = { buffer -> diff --git a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt index b9c47cb5..7d3671ef 100644 --- a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt +++ b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt @@ -17,6 +17,8 @@ package com.connectrpc import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest +import com.connectrpc.http.clone import com.connectrpc.protocols.Envelope import com.connectrpc.protocols.NetworkProtocol import okio.Buffer @@ -72,7 +74,7 @@ class InterceptorChainTest { @Test fun fifo_request_unary() { - val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, UNARY_METHOD_SPEC)) + val response = unaryChain.requestFunction(UnaryHTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), UNARY_METHOD_SPEC, Buffer())) assertThat(response.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -84,7 +86,7 @@ class InterceptorChainTest { @Test fun fifo_request_stream() { - val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, STREAM_METHOD_SPEC)) + val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), STREAM_METHOD_SPEC)) assertThat(request.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -115,13 +117,7 @@ class InterceptorChainTest { val sequence = headers.get("id")?.toMutableList() ?: mutableListOf() sequence.add(id) headers.put("id", sequence) - HTTPRequest( - it.url, - it.contentType, - headers, - it.message, - UNARY_METHOD_SPEC, - ) + it.clone(headers = headers) }, responseFunction = { val headers = it.headers.toMutableMap() @@ -147,13 +143,7 @@ class InterceptorChainTest { val sequence = headers.get("id")?.toMutableList() ?: mutableListOf() sequence.add(id) headers.put("id", sequence) - HTTPRequest( - it.url, - it.contentType, - headers, - it.message, - STREAM_METHOD_SPEC, - ) + it.clone(headers = headers) }, requestBodyFunction = { it.writeString(id, Charsets.UTF_8) diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 16f96164..96c505e5 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -21,6 +21,7 @@ import com.connectrpc.SerializationStrategy import com.connectrpc.StreamType import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest +import com.connectrpc.http.UnaryHTTPRequest import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch @@ -50,7 +51,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) { _ -> } - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -67,7 +68,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) { _ -> } - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -84,7 +85,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.BIDI), ) - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).stream(captor.capture(), true, any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -119,7 +120,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -135,7 +136,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -151,7 +152,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") } @@ -167,7 +168,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt index 97f99385..f917524a 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt @@ -30,11 +30,10 @@ import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.squareup.moshi.Moshi import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray -import okio.internal.commonToUtf8String import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -71,10 +70,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -102,10 +102,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("User-Agent" to listOf("custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -129,11 +130,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -142,7 +143,7 @@ class ConnectInterceptorTest { ), ), ) - assertThat(request.message!!.commonToUtf8String()).isEqualTo("message") + assertThat(request.message.readUtf8()).isEqualTo("message") } @Test @@ -157,11 +158,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -170,7 +171,7 @@ class ConnectInterceptorTest { ), ), ) - val decompressed = GzipCompressionPool.decompress(Buffer().write(request.message!!)) + val decompressed = GzipCompressionPool.decompress(request.message) assertThat(decompressed.readUtf8()).isEqualTo("message") } @@ -186,11 +187,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "".commonAsUtf8ToByteArray(), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -199,7 +200,7 @@ class ConnectInterceptorTest { ), ), ) - val decompressed = GzipCompressionPool.decompress(Buffer().write(request.message!!)) + val decompressed = GzipCompressionPool.decompress(request.message) assertThat(decompressed.readUtf8()).isEqualTo("") } @@ -679,11 +680,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -716,11 +717,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -745,11 +746,11 @@ class ConnectInterceptorTest { val connectInterceptor = ConnectInterceptor(config) val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -779,11 +780,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt index 54d34239..a4f121fb 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt @@ -28,10 +28,10 @@ import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.squareup.moshi.Moshi import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -65,10 +65,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -93,10 +94,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value"), "User-Agent" to listOf("my-custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -120,11 +122,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -133,7 +135,7 @@ class GRPCInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) assertThat(message.readUtf8()).isEqualTo("message") } @@ -149,11 +151,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -162,7 +164,7 @@ class GRPCInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) val decompressed = GzipCompressionPool.decompress(message) assertThat(decompressed.readUtf8()).isEqualTo("message") } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt index 0d8e78ba..1f9da1d9 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt @@ -27,9 +27,9 @@ import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -61,10 +61,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -90,10 +91,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = mapOf("X-User-Agent" to listOf("custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -117,11 +119,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -130,7 +132,7 @@ class GRPCWebInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) assertThat(message.readUtf8()).isEqualTo("message") } @@ -146,11 +148,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -159,7 +161,7 @@ class GRPCWebInterceptorTest { ), ), ) - val (_, decompressed) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!), GzipCompressionPool) + val (_, decompressed) = Envelope.unpackWithHeaderByte(request.message, GzipCompressionPool) assertThat(decompressed.readUtf8()).isEqualTo("message") } diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt index 214dc24f..5d566c66 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt @@ -23,6 +23,7 @@ import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.Stream import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.connectrpc.protocols.CONNECT_PROTOCOL_VERSION_KEY import com.connectrpc.protocols.CONNECT_PROTOCOL_VERSION_VALUE import com.connectrpc.protocols.GETConstants @@ -33,10 +34,11 @@ import okhttp3.Interceptor import okhttp3.MediaType.Companion.toMediaType import okhttp3.OkHttpClient import okhttp3.Request -import okhttp3.RequestBody.Companion.toRequestBody +import okhttp3.RequestBody import okhttp3.Response import okhttp3.internal.http.HttpMethod import okio.Buffer +import okio.BufferedSink import java.io.IOException import java.io.InterruptedIOException import java.net.SocketTimeoutException @@ -86,16 +88,31 @@ class ConnectOkHttpClient @JvmOverloads constructor( } } - override fun unary(request: HTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable { + override fun unary(request: UnaryHTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable { val builder = Request.Builder() for (entry in request.headers) { for (values in entry.value) { builder.addHeader(entry.key, values) } } - val content = request.message ?: ByteArray(0) + val content = request.message val method = request.httpMethod - val requestBody = if (HttpMethod.requiresRequestBody(method)) content.toRequestBody(request.contentType.toMediaType()) else null + val requestBody = if (HttpMethod.requiresRequestBody(method)) { + object : RequestBody() { + override fun contentType() = request.contentType.toMediaType() + override fun contentLength() = content.size + override fun writeTo(sink: BufferedSink) { + // We make a copy so that this body is not "one shot", + // meaning that the okhttp library may automatically + // retry the request under certain conditions. If we + // didn't copy it, then reading it here would consume + // it and then a retry would only see an empty body. + content.copy().readAll(sink) + } + } + } else { + null + } val callRequest = builder .url(request.url) .method(method, requestBody)