Skip to content

Commit

Permalink
Improve API for bidi and server streaming calls
Browse files Browse the repository at this point in the history
Instead of requiring callers to handle oneOf(Headers,Message,Trailers)
objects in each bidi or server streaming call, instead just change the
response channel to return the response message type. If an error occurs
at the end of the call (due to non-zero grpc-status), then cancel the
channel with an exception.
  • Loading branch information
pkwarren committed Oct 18, 2023
1 parent 34cf5b1 commit 2678d89
Show file tree
Hide file tree
Showing 15 changed files with 103 additions and 248 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@ package com.connectrpc.conformance

import com.connectrpc.Code
import com.connectrpc.ConnectException
import com.connectrpc.Headers
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.RequestCompression
import com.connectrpc.StreamResult
import com.connectrpc.Trailers
import com.connectrpc.compression.GzipCompressionPool
import com.connectrpc.conformance.ssl.sslContext
import com.connectrpc.conformance.v1.ErrorDetail
import com.connectrpc.conformance.v1.PayloadType
import com.connectrpc.conformance.v1.StreamingOutputCallResponse
import com.connectrpc.conformance.v1.TestServiceClient
import com.connectrpc.conformance.v1.UnimplementedServiceClient
import com.connectrpc.conformance.v1.echoStatus
Expand All @@ -43,7 +41,6 @@ import com.google.protobuf.ByteString
import com.google.protobuf.empty
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
Expand All @@ -63,7 +60,6 @@ import java.time.Duration
import java.util.Base64
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean

@RunWith(Parameterized::class)
class Conformance(
Expand Down Expand Up @@ -177,17 +173,18 @@ class Conformance(
responseParameters += params
},
).getOrThrow()
val results = streamResults(stream.resultChannel())
assertThat(results.cause).isNull()
assertThat(results.code).isEqualTo(Code.OK)
assertThat(results.messages.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE))
assertThat(results.messages.map { it.payload.body.size() }).isEqualTo(sizes)
val responses = mutableListOf<StreamingOutputCallResponse>()
for (response in stream.responseChannel()) {
responses.add(response)
}
assertThat(responses.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE))
assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
}

@Test
fun pingPong(): Unit = runBlocking {
val stream = testServiceConnectClient.fullDuplexCall()
var readHeaders = false
val responseChannel = stream.responseChannel()
listOf(512_000, 16, 2_028, 65_536).forEach {
val param = responseParameters { size = it }
stream.send(
Expand All @@ -196,25 +193,14 @@ class Conformance(
responseParameters += param
},
).getOrThrow()
if (!readHeaders) {
val headersResult = stream.resultChannel().receive()
assertThat(headersResult).isInstanceOf(StreamResult.Headers::class.java)
readHeaders = true
}
val result = stream.resultChannel().receive()
assertThat(result).isInstanceOf(StreamResult.Message::class.java)
val messageResult = result as StreamResult.Message
val payload = messageResult.message.payload
val response = responseChannel.receive()
val payload = response.payload
assertThat(payload.type).isEqualTo(PayloadType.COMPRESSABLE)
assertThat(payload.body).hasSize(it)
}
stream.sendClose()
val results = streamResults(stream.resultChannel())
// We've already read all the messages
assertThat(results.messages).isEmpty()
assertThat(results.cause).isNull()
assertThat(results.code).isEqualTo(Code.OK)
stream.receiveClose()
assertThat(responseChannel.receiveCatching().isClosed).isTrue()
}

@Test
Expand Down Expand Up @@ -244,15 +230,17 @@ class Conformance(
val countDownLatch = CountDownLatch(1)
withContext(Dispatchers.IO) {
val job = async {
val responses = mutableListOf<StreamingOutputCallResponse>()
try {
val result = streamResults(stream.resultChannel())
assertThat(result.messages.map { it.payload.body.size() }).isEqualTo(sizes)
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
val connectException = result.cause as ConnectException
assertThat(connectException.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(connectException.message).isEqualTo("soirée 🎉")
assertThat(connectException.unpackedDetails(ErrorDetail::class)).containsExactly(
for (response in stream.responseChannel()) {
responses.add(response)
}
fail("expected call to fail with ConnectException")
} catch (e: ConnectException) {
assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
assertThat(e.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(e.message).isEqualTo("soirée 🎉")
assertThat(e.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail,
)
} finally {
Expand Down Expand Up @@ -363,10 +351,11 @@ class Conformance(
withContext(Dispatchers.IO) {
val job = launch {
try {
val result = streamResults(stream.resultChannel())
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
assertThat(result.code)
.withFailMessage { "Expected Code.DEADLINE_EXCEEDED but got ${result.code}" }
stream.responseChannel().receive()
fail("unexpected ConnectException to be thrown")
} catch (e: ConnectException) {
assertThat(e.code)
.withFailMessage { "Expected Code.DEADLINE_EXCEEDED but got ${e.code}" }
.isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
Expand Down Expand Up @@ -437,11 +426,10 @@ class Conformance(
withContext(Dispatchers.IO) {
val job = async {
try {
val result = streamResults(stream.resultChannel())
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
val exception = result.cause as ConnectException
assertThat(exception.code).isEqualTo(Code.UNIMPLEMENTED)
stream.responseChannel().receive()
fail("expected call to fail with a ConnectException")
} catch (e: ConnectException) {
assertThat(e.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
Expand Down Expand Up @@ -801,8 +789,8 @@ class Conformance(
withContext(Dispatchers.IO) {
val job = async {
try {
val result = stream.receiveAndClose().getOrThrow()
assertThat(result.aggregatedPayloadSize).isEqualTo(sum)
val response = stream.receiveAndClose()
assertThat(response.aggregatedPayloadSize).isEqualTo(sum)
} finally {
countDownLatch.countDown()
}
Expand All @@ -813,56 +801,6 @@ class Conformance(
}
}

private data class ServerStreamingResult<Output>(
val headers: Headers,
val messages: List<Output>,
val code: Code,
val trailers: Trailers,
val cause: Throwable?,
)

/*
* Convenience method to return all results (with sanity checking) for calls which stream results from the server
* (bidi and server streaming).
*
* This allows us to easily verify headers, messages, trailers, and errors without having to use fold/maybeFold
* manually in each location.
*/
private suspend fun <Output> streamResults(channel: ReceiveChannel<StreamResult<Output>>): ServerStreamingResult<Output> {
val seenHeaders = AtomicBoolean(false)
var headers: Headers = emptyMap()
val messages: MutableList<Output> = mutableListOf()
val seenCompletion = AtomicBoolean(false)
var code: Code = Code.UNKNOWN
var trailers: Headers = emptyMap()
var error: Throwable? = null
for (response in channel) {
response.maybeFold(
onHeaders = {
if (!seenHeaders.compareAndSet(false, true)) {
throw IllegalStateException("multiple onHeaders callbacks")
}
headers = it.headers
},
onMessage = {
messages.add(it.message)
},
onCompletion = {
if (!seenCompletion.compareAndSet(false, true)) {
throw IllegalStateException("multiple onCompletion callbacks")
}
code = it.code
trailers = it.trailers
error = it.cause
},
)
}
if (!seenCompletion.get()) {
throw IllegalStateException("didn't get completion message")
}
return ServerStreamingResult(headers, messages, code, trailers, error)
}

private fun b64Encode(trailingValue: ByteArray): String {
return String(Base64.getEncoder().encode(trailingValue))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import androidx.lifecycle.lifecycleScope
import androidx.recyclerview.widget.RecyclerView
import com.connectrpc.ConnectException
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.eliza.v1.ConverseRequest
import com.connectrpc.eliza.v1.ElizaServiceClient
Expand Down Expand Up @@ -135,29 +136,26 @@ class ElizaChatActivity : AppCompatActivity() {
lifecycleScope.launch(Dispatchers.IO) {
// Initialize a bidi stream with Eliza.
val stream = elizaServiceClient.converse()

for (streamResult in stream.resultChannel()) {
streamResult.maybeFold(
onMessage = { result ->
// A stream message is received: Eliza has said something to us.
val elizaResponse = result.message.sentence
if (elizaResponse?.isNotBlank() == true) {
adapter.add(MessageData(elizaResponse, true))
} else {
// Something odd occurred.
adapter.add(MessageData("...No response from Eliza...", true))
}
},
onCompletion = {
// This should only be called once.
adapter.add(
MessageData(
"Session has ended.",
true,
),
)
},
try {
for (message in stream.responseChannel()) {
// A stream message is received: Eliza has said something to us.
val elizaResponse = message.sentence
if (elizaResponse?.isNotBlank() == true) {
adapter.add(MessageData(elizaResponse, true))
} else {
// Something odd occurred.
adapter.add(MessageData("...No response from Eliza...", true))
}
}
// This should only be called once.
adapter.add(
MessageData(
"Session has ended.",
true,
),
)
} catch (e: ConnectException) {
adapter.add(MessageData("Session failed with code ${e.code}", true))
}
lifecycleScope.launch(Dispatchers.Main) {
buttonView.setOnClickListener {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

package com.connectrpc.examples.kotlin

import com.connectrpc.Code
import com.connectrpc.ConnectException
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.eliza.v1.ElizaServiceClient
Expand Down Expand Up @@ -63,23 +62,13 @@ class Main {
// Add the message the user is sending to the views.
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (streamResult in stream.resultChannel()) {
streamResult.maybeFold(
onMessage = { result ->
// Update the view with the response.
val elizaResponse = result.message
println(elizaResponse.sentence)
},
onCompletion = { result ->
if (result.code != Code.OK) {
val exception = result.connectException()
if (exception != null) {
throw exception
}
throw ConnectException(code = result.code, metadata = result.trailers)
}
},
)
try {
for (streamResult in stream.responseChannel()) {
// Update the view with the response.
val elizaResponse = streamResult
println(elizaResponse.sentence)
}
} catch (e: ConnectException) {
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

package com.connectrpc.examples.kotlin

import com.connectrpc.Code
import com.connectrpc.ConnectException
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.eliza.v1.ElizaServiceClient
import com.connectrpc.eliza.v1.converseRequest
Expand All @@ -34,7 +32,7 @@ class Main {
@JvmStatic
fun main(args: Array<String>) {
runBlocking {
val host = "https://demo.connectrpc.com"
val host = "https://demo.connectrpc.com:444"
val okHttpClient = OkHttpClient()
.newBuilder()
.readTimeout(Duration.ofMinutes(10))
Expand Down Expand Up @@ -63,23 +61,8 @@ class Main {
// Add the message the user is sending to the views.
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (streamResult in stream.resultChannel()) {
streamResult.maybeFold(
onMessage = { result ->
// Update the view with the response.
val elizaResponse = result.message
println(elizaResponse.sentence)
},
onCompletion = { result ->
if (result.code != Code.OK) {
val exception = result.connectException()
if (exception != null) {
throw exception
}
throw ConnectException(code = result.code, metadata = result.trailers)
}
},
)
for (response in stream.responseChannel()) {
println(response.sentence)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ import kotlinx.coroutines.channels.ReceiveChannel
*/
interface BidirectionalStreamInterface<Input, Output> {
/**
* The Channel for received StreamResults.
* The Channel for responses.
*
* @return ReceiveChannel for iterating over the received results.
* @return ReceiveChannel for iterating over the responses.
*/
fun resultChannel(): ReceiveChannel<StreamResult<Output>>
fun responseChannel(): ReceiveChannel<Output>

/**
* Send a request to the server over the stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ interface ClientOnlyStreamInterface<Input, Output> {
/**
* Receive a single response and close the stream.
*
* @return the single response [ResponseMessage].
* @return the single response [Output].
*/
suspend fun receiveAndClose(): ResponseMessage<Output>
suspend fun receiveAndClose(): Output

/**
* Close the stream. No calls to [send] are valid after calling [sendClose].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.connectrpc

import kotlinx.coroutines.channels.ReceiveChannel

/**
* Represents a server-only stream (a stream where the server streams data to the client after
* receiving an initial request) that can send request messages.
Expand All @@ -25,7 +26,7 @@ interface ServerOnlyStreamInterface<Input, Output> {
*
* @return ReceiveChannel for iterating over the received results.
*/
fun resultChannel(): ReceiveChannel<StreamResult<Output>>
fun responseChannel(): ReceiveChannel<Output>

/**
* Send a request to the server over the stream and closes the request.
Expand Down
Loading

0 comments on commit 2678d89

Please sign in to comment.