Skip to content

Commit

Permalink
Add support for accessing headers and trailers to streaming calls (#171)
Browse files Browse the repository at this point in the history
This is done via a `CompletableDeferred` which is completed when the
headers/trailers are received.
  • Loading branch information
jhump authored Dec 8, 2023
1 parent de27a21 commit 92cc2a7
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,14 @@ import java.util.concurrent.TimeUnit

@RunWith(Parameterized::class)
class ConformanceTest(
protocol: NetworkProtocol,
private val protocol: NetworkProtocol,
serverType: ServerType,
) : BaseConformanceTest(protocol, serverType) {
companion object {
private val responseHeaders = mapOf(Pair("x-grpc-test-echo-initial", listOf("test_initial_metadata_value")))
private val responseTrailers = mapOf(Pair("x-grpc-test-echo-trailing-bin", listOf("CgsKCwoL"))) // base64-encoded 0x0a0b0a0b0a0b
private val requestHeaders = responseHeaders + responseTrailers
}

private lateinit var unimplementedServiceClient: UnimplementedServiceClient
private lateinit var testServiceConnectClient: TestServiceClient
Expand All @@ -68,20 +73,22 @@ class ConformanceTest(
@Test
fun serverStreaming(): Unit = runBlocking {
val sizes = listOf(512_000, 16, 2_028, 65_536)
val stream = testServiceConnectClient.streamingOutputCall()
val stream = testServiceConnectClient.streamingOutputCall(requestHeaders)
val params = sizes.map { responseParameters { size = it } }.toList()
stream.sendAndClose(
streamingOutputCallRequest {
responseType = PayloadType.COMPRESSABLE
responseParameters += params
},
).getOrThrow()
assertThat(stream.responseHeaders().await()).containsAllEntriesOf(responseHeaders)
val responses = mutableListOf<StreamingOutputCallResponse>()
for (response in stream.responseChannel()) {
responses.add(response)
}
assertThat(responses.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE))
assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
assertThat(stream.responseTrailers().await()).containsAllEntriesOf(responseTrailers)
}

@Test
Expand Down Expand Up @@ -673,7 +680,7 @@ class ConformanceTest(

@Test
fun clientStreaming(): Unit = runBlocking {
val stream = testServiceConnectClient.streamingInputCall(emptyMap())
val stream = testServiceConnectClient.streamingInputCall()
var sum = 0
listOf(256000, 8, 1024, 32768).forEach { size ->
stream.send(
Expand All @@ -692,6 +699,13 @@ class ConformanceTest(
try {
val response = stream.receiveAndClose()
assertThat(response.aggregatedPayloadSize).isEqualTo(sum)
assertThat(stream.responseHeaders().isCompleted).isTrue()
assertThat(stream.responseHeaders().await()).isNotEmpty()
if (protocol != NetworkProtocol.CONNECT) {
// gRPC and gRPC-web communicate RPC status in trailers, so
// they should always have something.
assertThat(stream.responseTrailers().await()).isNotEmpty()
}
} finally {
countDownLatch.countDown()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,14 @@ import java.util.concurrent.TimeUnit

@RunWith(Parameterized::class)
class ConformanceTest(
protocol: NetworkProtocol,
private val protocol: NetworkProtocol,
serverType: ServerType,
) : BaseConformanceTest(protocol, serverType) {
companion object {
private val responseHeaders = mapOf(Pair("x-grpc-test-echo-initial", listOf("test_initial_metadata_value")))
private val responseTrailers = mapOf(Pair("x-grpc-test-echo-trailing-bin", listOf("CgsKCwoL"))) // base64-encoded 0x0a0b0a0b0a0b
private val requestHeaders = responseHeaders + responseTrailers
}

private lateinit var unimplementedServiceClient: UnimplementedServiceClient
private lateinit var testServiceConnectClient: TestServiceClient
Expand All @@ -68,20 +73,22 @@ class ConformanceTest(
@Test
fun serverStreaming(): Unit = runBlocking {
val sizes = listOf(512_000, 16, 2_028, 65_536)
val stream = testServiceConnectClient.streamingOutputCall()
val stream = testServiceConnectClient.streamingOutputCall(requestHeaders)
val params = sizes.map { responseParameters { size = it } }.toList()
stream.sendAndClose(
streamingOutputCallRequest {
responseType = PayloadType.COMPRESSABLE
responseParameters += params
},
).getOrThrow()
assertThat(stream.responseHeaders().await()).containsAllEntriesOf(responseHeaders)
val responses = mutableListOf<StreamingOutputCallResponse>()
for (response in stream.responseChannel()) {
responses.add(response)
}
assertThat(responses.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE))
assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
assertThat(stream.responseTrailers().await()).containsAllEntriesOf(responseTrailers)
}

@Test
Expand Down Expand Up @@ -673,7 +680,7 @@ class ConformanceTest(

@Test
fun clientStreaming(): Unit = runBlocking {
val stream = testServiceConnectClient.streamingInputCall(emptyMap())
val stream = testServiceConnectClient.streamingInputCall()
var sum = 0
listOf(256000, 8, 1024, 32768).forEach { size ->
stream.send(
Expand All @@ -692,6 +699,13 @@ class ConformanceTest(
try {
val response = stream.receiveAndClose()
assertThat(response.aggregatedPayloadSize).isEqualTo(sum)
assertThat(stream.responseHeaders().isCompleted).isTrue()
assertThat(stream.responseHeaders().await()).isNotEmpty()
if (protocol != NetworkProtocol.CONNECT) {
// gRPC and gRPC-web communicate RPC status in trailers, so
// they should always have something.
assertThat(stream.responseTrailers().await()).isNotEmpty()
}
} finally {
countDownLatch.countDown()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.connectrpc

import kotlinx.coroutines.Deferred
import kotlinx.coroutines.channels.ReceiveChannel

/**
Expand All @@ -27,6 +28,25 @@ interface BidirectionalStreamInterface<Input, Output> {
*/
fun responseChannel(): ReceiveChannel<Output>

/**
* The response headers. This value will become available before any output
* messages become available from the [responseChannel] and before trailers
* are available from [responseTrailers]. If the stream fails before headers
* are ever received, this will complete with an empty value. The
* [ReceiveChannel.receive] method of [responseChannel] can be used to
* recover the exception that caused such a failure.
*/
fun responseHeaders(): Deferred<Headers>

/**
* The response trailers. This value will not become available until the entire
* RPC operation is complete. If the stream fails before trailers are ever
* received, this will complete with an empty value. The [ReceiveChannel.receive]
* method of [responseChannel] can be used to recover the exception that caused
* such a failure.
*/
fun responseTrailers(): Deferred<Headers>

/**
* Send a request to the server over the stream.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package com.connectrpc

import kotlinx.coroutines.Deferred

/**
* Represents a client-only stream (a stream where the client streams data to the server and
* eventually receives a response) that can send request messages and initiate closes.
Expand All @@ -34,6 +36,24 @@ interface ClientOnlyStreamInterface<Input, Output> {
*/
suspend fun receiveAndClose(): Output

/**
* The response headers. This value will become available before any call to
* [receiveAndClose] completes and before trailers are available from
* [responseTrailers] (though these may occur nearly simultaneously). If the
* stream fails before headers are ever received, this will complete with an
* empty value. The [receiveAndClose] method can be used to recover the
* exception that caused such a failure.
*/
fun responseHeaders(): Deferred<Headers>

/**
* The response trailers. This value will not become available until the entire
* RPC operation is complete. If the stream fails before trailers are ever
* received, this will complete with an empty value. The [receiveAndClose]
* method can be used to recover the exception that caused such a failure.
*/
fun responseTrailers(): Deferred<Headers>

/**
* Close the stream. No calls to [send] are valid after calling [sendClose].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.connectrpc

import kotlinx.coroutines.Deferred
import kotlinx.coroutines.channels.ReceiveChannel

/**
Expand All @@ -28,6 +29,25 @@ interface ServerOnlyStreamInterface<Input, Output> {
*/
fun responseChannel(): ReceiveChannel<Output>

/**
* The response headers. This value will become available before any output
* messages become available from the [responseChannel] and before trailers
* are available from [responseTrailers]. If the stream fails before headers
* are ever received, this will complete with an empty value. The
* [ReceiveChannel.receive] method of [responseChannel] can be used to
* recover the exception that caused such a failure.
*/
fun responseHeaders(): Deferred<Headers>

/**
* The response trailers. This value will not become available until the entire
* RPC operation is complete. If the stream fails before trailers are ever
* received, this will complete with an empty value. The [ReceiveChannel.receive]
* method of [responseChannel] can be used to recover the exception that caused
* such a failure.
*/
fun responseTrailers(): Deferred<Headers>

/**
* Send a request to the server over the stream and closes the request.
*
Expand Down
17 changes: 14 additions & 3 deletions library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Stream(
private val isReceiveClosed = AtomicReference(false)

fun send(buffer: Buffer): Result<Unit> {
if (isClosed()) {
if (isSendClosed()) {
return Result.failure(IllegalStateException("cannot send. underlying stream is closed"))
}
return try {
Expand All @@ -75,12 +75,23 @@ class Stream(

fun receiveClose() {
if (!isReceiveClosed.getAndSet(true)) {
onReceiveClose()
try {
onReceiveClose()
} finally {
// When receive side is closed, the send side is
// implicitly closed as well.
// We don't use sendClose() because we don't want to
// invoke onSendClose() since that will try to actually
// half-close the HTTP stream, which will fail since the
// closing the receive side cancels the entire thing.
isSendClosed.set(true)
}
}
}

// TODO: remove this method as it is redundant with receive closed
fun isClosed(): Boolean {
return isSendClosed() && isReceiveClosed()
return isReceiveClosed()
}

fun isSendClosed(): Boolean {
Expand Down
16 changes: 14 additions & 2 deletions library/src/main/kotlin/com/connectrpc/impl/BidirectionalStream.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ package com.connectrpc.impl

import com.connectrpc.BidirectionalStreamInterface
import com.connectrpc.Codec
import com.connectrpc.Headers
import com.connectrpc.http.Stream
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import java.lang.Exception
Expand All @@ -27,7 +29,9 @@ import java.lang.Exception
internal class BidirectionalStream<Input, Output>(
val stream: Stream,
private val requestCodec: Codec<Input>,
private val receiveChannel: Channel<Output>,
private val responseChannel: Channel<Output>,
private val responseHeaders: Deferred<Headers>,
private val responseTrailers: Deferred<Headers>,
) : BidirectionalStreamInterface<Input, Output> {

override suspend fun send(input: Input): Result<Unit> {
Expand All @@ -40,7 +44,15 @@ internal class BidirectionalStream<Input, Output>(
}

override fun responseChannel(): ReceiveChannel<Output> {
return receiveChannel
return responseChannel
}

override fun responseHeaders(): Deferred<Headers> {
return responseHeaders
}

override fun responseTrailers(): Deferred<Headers> {
return responseTrailers
}

override fun isClosed(): Boolean {
Expand Down
10 changes: 10 additions & 0 deletions library/src/main/kotlin/com/connectrpc/impl/ClientOnlyStream.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import com.connectrpc.BidirectionalStreamInterface
import com.connectrpc.ClientOnlyStreamInterface
import com.connectrpc.Code
import com.connectrpc.ConnectException
import com.connectrpc.Headers
import kotlinx.coroutines.Deferred

/**
* Concrete implementation of [ClientOnlyStreamInterface].
Expand All @@ -44,6 +46,14 @@ internal class ClientOnlyStream<Input, Output>(
}
}

override fun responseHeaders(): Deferred<Headers> {
return messageStream.responseHeaders()
}

override fun responseTrailers(): Deferred<Headers> {
return messageStream.responseTrailers()
}

override fun sendClose() {
return messageStream.sendClose()
}
Expand Down
Loading

0 comments on commit 92cc2a7

Please sign in to comment.