diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4a20e1d6..a124ce22 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -32,6 +32,7 @@ dokka-plugin = { module = "org.jetbrains.dokka:dokka-gradle-plugin", version.ref junit = { module = "junit:junit", version.ref = "junit" } kotlin-coroutines-android = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-android", version.ref = "coroutines" } kotlin-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "coroutines" } +kotlin-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "coroutines" } kotlin-jsr223 = { module = "org.jetbrains.kotlin:kotlin-scripting-jsr223", version.ref = "kotlin" } kotlin-plugin = { module = "org.jetbrains.kotlin:kotlin-gradle-plugin", version.ref = "kotlin" } kotlin-reflect = { module = "org.jetbrains.kotlin:kotlin-reflect", version.ref = "kotlin" } @@ -44,6 +45,7 @@ moshiKotlin = { module = "com.squareup.moshi:moshi-kotlin", version.ref = "moshi moshiKotlinCodegen = { module = "com.squareup.moshi:moshi-kotlin-codegen", version.ref = "moshi" } okhttp-core = { module = "com.squareup.okhttp3:okhttp", version.ref = "okhttp" } okhttp-tls = { module = "com.squareup.okhttp3:okhttp-tls", version.ref = "okhttp" } +okhttp-mockwebserver = { module = "com.squareup.okhttp3:mockwebserver", version.ref = "okhttp" } okio-core = { module = "com.squareup.okio:okio", version.ref = "okio" } protobuf-java = { module = "com.google.protobuf:protobuf-java", version.ref = "protobuf" } protobuf-java-util = { module = "com.google.protobuf:protobuf-java-util", version.ref = "protobuf" } diff --git a/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt b/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt index affcb68e..74fee5af 100644 --- a/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt +++ b/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt @@ -38,6 +38,8 @@ object GzipCompressionPool : CompressionPool { override fun decompress(buffer: Buffer): Buffer { val result = Buffer() + if (buffer.size == 0L) return result + GzipSource(buffer).use { while (it.read(result, Int.MAX_VALUE.toLong()) != -1L) { // continue reading. diff --git a/library/src/test/kotlin/com/connectrpc/compression/GzipCompressionPoolTest.kt b/library/src/test/kotlin/com/connectrpc/compression/GzipCompressionPoolTest.kt index 99badacc..45ea3a17 100644 --- a/library/src/test/kotlin/com/connectrpc/compression/GzipCompressionPoolTest.kt +++ b/library/src/test/kotlin/com/connectrpc/compression/GzipCompressionPoolTest.kt @@ -43,4 +43,11 @@ class GzipCompressionPoolTest { val resultString = compressionPool.decompress(result).readUtf8() assertThat(resultString).isEqualTo("some_string") } + + @Test + fun emptyBufferGzipDecompression() { + val compressionPool = GzipCompressionPool + val resultString = compressionPool.decompress(Buffer()).readUtf8() + assertThat(resultString).isEqualTo("") + } } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt index 6f1982fa..212a3cd7 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt @@ -170,6 +170,34 @@ class ConnectInterceptorTest { assertThat(decompressed.readUtf8()).isEqualTo("message") } + @Test + fun compressedEmptyRequestMessage() { + val config = ProtocolClientConfig( + host = "https://connectrpc.com", + serializationStrategy = serializationStrategy, + requestCompression = RequestCompression(1, GzipCompressionPool), + compressionPools = listOf(GzipCompressionPool), + ) + val connectInterceptor = ConnectInterceptor(config) + val unaryFunction = connectInterceptor.unaryFunction() + + val request = unaryFunction.requestFunction( + HTTPRequest( + url = URL(config.host), + contentType = "content_type", + headers = emptyMap(), + message = "".commonAsUtf8ToByteArray(), + methodSpec = MethodSpec( + path = "", + requestClass = Any::class, + responseClass = Any::class, + ), + ), + ) + val decompressed = GzipCompressionPool.decompress(Buffer().write(request.message!!)) + assertThat(decompressed.readUtf8()).isEqualTo("") + } + @Test fun uncompressedResponseMessage() { val config = ProtocolClientConfig( @@ -214,6 +242,28 @@ class ConnectInterceptorTest { assertThat(response.message.readUtf8()).isEqualTo("message") } + @Test + fun compressedEmptyResponseMessage() { + val config = ProtocolClientConfig( + host = "https://connectrpc.com", + serializationStrategy = serializationStrategy, + compressionPools = listOf(GzipCompressionPool), + ) + val connectInterceptor = ConnectInterceptor(config) + val unaryFunction = connectInterceptor.unaryFunction() + + val response = unaryFunction.responseFunction( + HTTPResponse( + code = Code.OK, + headers = mapOf(CONTENT_ENCODING to listOf(GzipCompressionPool.name())), + message = Buffer(), + trailers = emptyMap(), + tracingInfo = null, + ), + ) + assertThat(response.message.readUtf8()).isEqualTo("") + } + @Test fun responseError() { val config = ProtocolClientConfig( diff --git a/okhttp/build.gradle.kts b/okhttp/build.gradle.kts index 4b96f1fc..e39dcd76 100644 --- a/okhttp/build.gradle.kts +++ b/okhttp/build.gradle.kts @@ -16,6 +16,12 @@ dependencies { implementation(libs.kotlin.coroutines.core) api(project(":library")) + + testImplementation(libs.assertj) + testImplementation(libs.okhttp.mockwebserver) + testImplementation(libs.kotlin.coroutines.test) + testImplementation(project(":extensions:google-java")) + testImplementation(project(":examples:generated-google-java")) } mavenPublishing { diff --git a/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerRule.kt b/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerRule.kt new file mode 100644 index 00000000..4efb8e71 --- /dev/null +++ b/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerRule.kt @@ -0,0 +1,38 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.connectrpc.okhttp + +import okhttp3.mockwebserver.MockWebServer +import org.junit.rules.TestWatcher +import org.junit.runner.Description + +class MockWebServerRule( + private val port: Int = 0, +) : TestWatcher() { + + lateinit var server: MockWebServer + private set + + override fun starting(description: Description) { + super.starting(description) + server = MockWebServer() + server.start(port) + } + + override fun finished(description: Description) { + super.finished(description) + server.shutdown() + } +} diff --git a/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerTests.kt b/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerTests.kt new file mode 100644 index 00000000..5651df42 --- /dev/null +++ b/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerTests.kt @@ -0,0 +1,74 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.connectrpc.okhttp + +import com.connectrpc.Code +import com.connectrpc.ProtocolClientConfig +import com.connectrpc.RequestCompression +import com.connectrpc.compression.GzipCompressionPool +import com.connectrpc.eliza.v1.ElizaServiceClient +import com.connectrpc.eliza.v1.sayRequest +import com.connectrpc.extensions.GoogleJavaProtobufStrategy +import com.connectrpc.impl.ProtocolClient +import com.connectrpc.protocols.NetworkProtocol +import kotlinx.coroutines.test.runTest +import okhttp3.OkHttpClient +import okhttp3.Protocol +import okhttp3.mockwebserver.MockResponse +import org.assertj.core.api.Assertions.assertThat +import org.junit.Rule +import org.junit.Test + +class MockWebServerTests { + + @get:Rule val mockWebServerRule = MockWebServerRule() + + @Test + fun `compressed empty failure response is parsed correctly`() = runTest { + mockWebServerRule.server.enqueue( + MockResponse().apply { + addHeader("accept-encoding", "gzip") + addHeader("content-encoding", "gzip") + setBody("{}") + setResponseCode(401) + }, + ) + + val host = mockWebServerRule.server.url("/") + + val protocolClient = ProtocolClient( + ConnectOkHttpClient( + OkHttpClient.Builder() + .protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1)) + .build(), + ), + ProtocolClientConfig( + host = host.toString(), + serializationStrategy = GoogleJavaProtobufStrategy(), + networkProtocol = NetworkProtocol.CONNECT, + requestCompression = RequestCompression(0, GzipCompressionPool), + compressionPools = listOf(GzipCompressionPool), + ), + ) + + val response = ElizaServiceClient(protocolClient).say(sayRequest { sentence = "hello" }) + + mockWebServerRule.server.takeRequest().apply { + assertThat(path).isEqualTo("/connectrpc.eliza.v1.ElizaService/Say") + } + + assertThat(response.code).isEqualTo(Code.UNKNOWN) + } +}