diff --git a/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt b/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt index e70cfd99..fa579cce 100644 --- a/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt +++ b/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt @@ -37,15 +37,15 @@ interface CompressionPool { /** * Compress an outbound request message. - * @param buffer: The uncompressed request message. + * @param input: The uncompressed request message. * @return The compressed request message. */ - fun compress(buffer: Buffer): Buffer + fun compress(input: Buffer): Buffer /** * Decompress an inbound response message. - * @param buffer: The compressed response message. + * @param input: The compressed response message. * @return The uncompressed response message. */ - fun decompress(buffer: Buffer): Buffer + fun decompress(input: Buffer): Buffer } diff --git a/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt b/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt index 74fee5af..86a5ee55 100644 --- a/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt +++ b/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt @@ -28,20 +28,23 @@ object GzipCompressionPool : CompressionPool { return "gzip" } - override fun compress(buffer: Buffer): Buffer { - val gzippedSink = Buffer() - GzipSink(gzippedSink).use { source -> - source.write(buffer, buffer.size) + override fun compress(input: Buffer): Buffer { + val result = Buffer() + GzipSink(result).use { gzippedSink -> + gzippedSink.write(input, input.size) } - return gzippedSink + return result } - override fun decompress(buffer: Buffer): Buffer { + override fun decompress(input: Buffer): Buffer { val result = Buffer() - if (buffer.size == 0L) return result + // We're lenient and will allow an empty payload to be + // interpreted as a compressed empty payload (even though + // it's missing the gzip format preamble/metadata). + if (input.size == 0L) return result - GzipSource(buffer).use { - while (it.read(result, Int.MAX_VALUE.toLong()) != -1L) { + GzipSource(input).use { gzippedSource -> + while (gzippedSource.read(result, Int.MAX_VALUE.toLong()) != -1L) { // continue reading. } } diff --git a/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt b/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt index 34485af7..479c21f5 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt @@ -30,26 +30,23 @@ class Envelope { * @param compressionMinBytes The minimum bytes the source needs to be in order to be compressed. */ fun pack(source: Buffer, compressionPool: CompressionPool? = null, compressionMinBytes: Int? = null): Buffer { + val flags: Int + val payload: Buffer if (compressionMinBytes == null || source.size < compressionMinBytes || compressionPool == null ) { - return source.use { - val result = Buffer() - result.writeByte(0) - result.writeInt(source.buffer.size.toInt()) - result.writeAll(source) - result - } - } - return source.use { buffer -> - val result = Buffer() - result.writeByte(1) - val compressedBuffer = compressionPool.compress(buffer) - result.writeInt(compressedBuffer.size.toInt()) - result.writeAll(compressedBuffer) - result + flags = 0 + payload = source + } else { + flags = 1 + payload = compressionPool.compress(source) } + val result = Buffer() + result.writeByte(flags) + result.writeInt(payload.buffer.size.toInt()) + result.writeAll(payload) + return result } /** diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt index 6efc4623..68f8f8ac 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt @@ -157,19 +157,29 @@ internal class GRPCInterceptor( onCompletion = { result -> val trailers = result.trailers val completion = completionParser.parse(emptyMap(), trailers) + if (completion == null && result.cause != null) { + // let error result propagate + return@fold result + } + val exception: ConnectException? if (completion != null) { - val exception = completion.toConnectExceptionOrNull( + exception = completion.toConnectExceptionOrNull( serializationStrategy, result.cause, ) - StreamResult.Complete( - code = exception?.code ?: Code.OK, - cause = exception, - trailers = trailers, - ) } else { - result + exception = ConnectException( + code = Code.INTERNAL_ERROR, + errorDetailParser = serializationStrategy.errorDetailParser(), + message = "protocol error: status is missing from trailers", + metadata = trailers, + ) } + StreamResult.Complete( + code = exception?.code ?: Code.OK, + cause = exception, + trailers = trailers, + ) }, ) },