diff --git a/okhttp/src/main/kotlin/okhttp3/internal/ws/MessageInflater.kt b/okhttp/src/main/kotlin/okhttp3/internal/ws/MessageInflater.kt index 382a2a4d9e51..1bcf28fce93b 100644 --- a/okhttp/src/main/kotlin/okhttp3/internal/ws/MessageInflater.kt +++ b/okhttp/src/main/kotlin/okhttp3/internal/ws/MessageInflater.kt @@ -28,19 +28,22 @@ class MessageInflater( ) : Closeable { private val deflatedBytes = Buffer() - private val inflater = - Inflater( - // nowrap (omits zlib header): - true, - ) - - private val inflaterSource = InflaterSource(deflatedBytes, inflater) + // Lazily-created. + private var inflater: Inflater? = null + private var inflaterSource: InflaterSource? = null /** Inflates [buffer] in place as described in RFC 7692 section 7.2.2. */ @Throws(IOException::class) fun inflate(buffer: Buffer) { require(deflatedBytes.size == 0L) + val inflater = + this.inflater + ?: Inflater(true).also { this.inflater = it } + val inflaterSource = + this.inflaterSource + ?: InflaterSource(deflatedBytes, inflater).also { this.inflaterSource = it } + if (noContextTakeover) { inflater.reset() } @@ -55,8 +58,21 @@ class MessageInflater( do { inflaterSource.readOrInflate(buffer, Long.MAX_VALUE) } while (inflater.bytesRead < totalBytesToRead && !inflater.finished()) + + // The inflater data was self-terminated and there's unexpected trailing data. Tear it all down + // so we don't leak that data into the input of the next message. + if (inflater.bytesRead < totalBytesToRead) { + deflatedBytes.clear() + inflaterSource.close() + this.inflaterSource = null + this.inflater = null + } } @Throws(IOException::class) - override fun close() = inflaterSource.close() + override fun close() { + inflaterSource?.close() + inflaterSource = null + inflater = null + } } diff --git a/okhttp/src/test/java/okhttp3/internal/ws/MessageDeflaterInflaterTest.kt b/okhttp/src/test/java/okhttp3/internal/ws/MessageDeflaterInflaterTest.kt index b9d444268d1b..3c9cf2cdd974 100644 --- a/okhttp/src/test/java/okhttp3/internal/ws/MessageDeflaterInflaterTest.kt +++ b/okhttp/src/test/java/okhttp3/internal/ws/MessageDeflaterInflaterTest.kt @@ -19,12 +19,15 @@ import assertk.assertThat import assertk.assertions.isEqualTo import assertk.assertions.isLessThan import java.io.EOFException +import java.util.zip.Deflater import kotlin.test.assertFailsWith import okhttp3.TestUtil.fragmentBuffer import okio.Buffer import okio.ByteString import okio.ByteString.Companion.decodeHex import okio.ByteString.Companion.encodeUtf8 +import okio.DeflaterSink +import okio.use import org.junit.jupiter.api.Test internal class MessageDeflaterInflaterTest { @@ -136,6 +139,41 @@ internal class MessageDeflaterInflaterTest { assertThat(buffer.readUtf8()).isEqualTo("Hello inflation!") } + /** + * It's possible a self-terminating deflated message will contain trailing data that won't be + * processed during inflation. If this happens, we need to either reject the message or discard + * the unreachable data. We choose to discard it! + * + * In practice this could happen if the encoder doesn't strip the [0x00, 0x00, 0xff, 0xff] suffix + * and that ends up repeated. + * + * https://github.com/square/okhttp/issues/8551 + */ + @Test + fun `deflated data has too many bytes`() { + val inflater = MessageInflater(true) + val buffer = Buffer() + + val message1 = "hello".encodeUtf8() + val message2 = "hello 2".encodeUtf8() + + DeflaterSink(buffer, Deflater(Deflater.DEFAULT_COMPRESSION, true)).use { sink -> + sink.write(Buffer().write(message1), message1.size.toLong()) + } + buffer.writeByte(0x00) + // Trailing data. We use the Okio segment size to make sure it's still in the input buffer. + buffer.write(ByteArray(8192)) + inflater.inflate(buffer) + assertThat(buffer.readByteString()).isEqualTo(message1) + + DeflaterSink(buffer, Deflater(Deflater.DEFAULT_COMPRESSION, true)).use { sink -> + sink.write(Buffer().write(message2), message2.size.toLong()) + } + buffer.writeByte(0x00) + inflater.inflate(buffer) + assertThat(buffer.readByteString()).isEqualTo(message2) + } + private fun MessageDeflater.deflate(byteString: ByteString): ByteString { val buffer = Buffer() buffer.write(byteString)