Skip to content

Commit 2678d89

Browse files
committed
Improve API for bidi and server streaming calls
Instead of requiring callers to handle oneOf(Headers,Message,Trailers) objects in each bidi or server streaming call, instead just change the response channel to return the response message type. If an error occurs at the end of the call (due to non-zero grpc-status), then cancel the channel with an exception.
1 parent 34cf5b1 commit 2678d89

File tree

15 files changed

+103
-248
lines changed

15 files changed

+103
-248
lines changed

conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt

Lines changed: 32 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,13 @@ package com.connectrpc.conformance
1616

1717
import com.connectrpc.Code
1818
import com.connectrpc.ConnectException
19-
import com.connectrpc.Headers
2019
import com.connectrpc.ProtocolClientConfig
2120
import com.connectrpc.RequestCompression
22-
import com.connectrpc.StreamResult
23-
import com.connectrpc.Trailers
2421
import com.connectrpc.compression.GzipCompressionPool
2522
import com.connectrpc.conformance.ssl.sslContext
2623
import com.connectrpc.conformance.v1.ErrorDetail
2724
import com.connectrpc.conformance.v1.PayloadType
25+
import com.connectrpc.conformance.v1.StreamingOutputCallResponse
2826
import com.connectrpc.conformance.v1.TestServiceClient
2927
import com.connectrpc.conformance.v1.UnimplementedServiceClient
3028
import com.connectrpc.conformance.v1.echoStatus
@@ -43,7 +41,6 @@ import com.google.protobuf.ByteString
4341
import com.google.protobuf.empty
4442
import kotlinx.coroutines.Dispatchers
4543
import kotlinx.coroutines.async
46-
import kotlinx.coroutines.channels.ReceiveChannel
4744
import kotlinx.coroutines.launch
4845
import kotlinx.coroutines.runBlocking
4946
import kotlinx.coroutines.withContext
@@ -63,7 +60,6 @@ import java.time.Duration
6360
import java.util.Base64
6461
import java.util.concurrent.CountDownLatch
6562
import java.util.concurrent.TimeUnit
66-
import java.util.concurrent.atomic.AtomicBoolean
6763

6864
@RunWith(Parameterized::class)
6965
class Conformance(
@@ -177,17 +173,18 @@ class Conformance(
177173
responseParameters += params
178174
},
179175
).getOrThrow()
180-
val results = streamResults(stream.resultChannel())
181-
assertThat(results.cause).isNull()
182-
assertThat(results.code).isEqualTo(Code.OK)
183-
assertThat(results.messages.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE))
184-
assertThat(results.messages.map { it.payload.body.size() }).isEqualTo(sizes)
176+
val responses = mutableListOf<StreamingOutputCallResponse>()
177+
for (response in stream.responseChannel()) {
178+
responses.add(response)
179+
}
180+
assertThat(responses.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE))
181+
assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
185182
}
186183

187184
@Test
188185
fun pingPong(): Unit = runBlocking {
189186
val stream = testServiceConnectClient.fullDuplexCall()
190-
var readHeaders = false
187+
val responseChannel = stream.responseChannel()
191188
listOf(512_000, 16, 2_028, 65_536).forEach {
192189
val param = responseParameters { size = it }
193190
stream.send(
@@ -196,25 +193,14 @@ class Conformance(
196193
responseParameters += param
197194
},
198195
).getOrThrow()
199-
if (!readHeaders) {
200-
val headersResult = stream.resultChannel().receive()
201-
assertThat(headersResult).isInstanceOf(StreamResult.Headers::class.java)
202-
readHeaders = true
203-
}
204-
val result = stream.resultChannel().receive()
205-
assertThat(result).isInstanceOf(StreamResult.Message::class.java)
206-
val messageResult = result as StreamResult.Message
207-
val payload = messageResult.message.payload
196+
val response = responseChannel.receive()
197+
val payload = response.payload
208198
assertThat(payload.type).isEqualTo(PayloadType.COMPRESSABLE)
209199
assertThat(payload.body).hasSize(it)
210200
}
211201
stream.sendClose()
212-
val results = streamResults(stream.resultChannel())
213202
// We've already read all the messages
214-
assertThat(results.messages).isEmpty()
215-
assertThat(results.cause).isNull()
216-
assertThat(results.code).isEqualTo(Code.OK)
217-
stream.receiveClose()
203+
assertThat(responseChannel.receiveCatching().isClosed).isTrue()
218204
}
219205

220206
@Test
@@ -244,15 +230,17 @@ class Conformance(
244230
val countDownLatch = CountDownLatch(1)
245231
withContext(Dispatchers.IO) {
246232
val job = async {
233+
val responses = mutableListOf<StreamingOutputCallResponse>()
247234
try {
248-
val result = streamResults(stream.resultChannel())
249-
assertThat(result.messages.map { it.payload.body.size() }).isEqualTo(sizes)
250-
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
251-
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
252-
val connectException = result.cause as ConnectException
253-
assertThat(connectException.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
254-
assertThat(connectException.message).isEqualTo("soirée 🎉")
255-
assertThat(connectException.unpackedDetails(ErrorDetail::class)).containsExactly(
235+
for (response in stream.responseChannel()) {
236+
responses.add(response)
237+
}
238+
fail("expected call to fail with ConnectException")
239+
} catch (e: ConnectException) {
240+
assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
241+
assertThat(e.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
242+
assertThat(e.message).isEqualTo("soirée 🎉")
243+
assertThat(e.unpackedDetails(ErrorDetail::class)).containsExactly(
256244
expectedErrorDetail,
257245
)
258246
} finally {
@@ -363,10 +351,11 @@ class Conformance(
363351
withContext(Dispatchers.IO) {
364352
val job = launch {
365353
try {
366-
val result = streamResults(stream.resultChannel())
367-
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
368-
assertThat(result.code)
369-
.withFailMessage { "Expected Code.DEADLINE_EXCEEDED but got ${result.code}" }
354+
stream.responseChannel().receive()
355+
fail("unexpected ConnectException to be thrown")
356+
} catch (e: ConnectException) {
357+
assertThat(e.code)
358+
.withFailMessage { "Expected Code.DEADLINE_EXCEEDED but got ${e.code}" }
370359
.isEqualTo(Code.DEADLINE_EXCEEDED)
371360
} finally {
372361
countDownLatch.countDown()
@@ -437,11 +426,10 @@ class Conformance(
437426
withContext(Dispatchers.IO) {
438427
val job = async {
439428
try {
440-
val result = streamResults(stream.resultChannel())
441-
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
442-
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
443-
val exception = result.cause as ConnectException
444-
assertThat(exception.code).isEqualTo(Code.UNIMPLEMENTED)
429+
stream.responseChannel().receive()
430+
fail("expected call to fail with a ConnectException")
431+
} catch (e: ConnectException) {
432+
assertThat(e.code).isEqualTo(Code.UNIMPLEMENTED)
445433
} finally {
446434
countDownLatch.countDown()
447435
}
@@ -801,8 +789,8 @@ class Conformance(
801789
withContext(Dispatchers.IO) {
802790
val job = async {
803791
try {
804-
val result = stream.receiveAndClose().getOrThrow()
805-
assertThat(result.aggregatedPayloadSize).isEqualTo(sum)
792+
val response = stream.receiveAndClose()
793+
assertThat(response.aggregatedPayloadSize).isEqualTo(sum)
806794
} finally {
807795
countDownLatch.countDown()
808796
}
@@ -813,56 +801,6 @@ class Conformance(
813801
}
814802
}
815803

816-
private data class ServerStreamingResult<Output>(
817-
val headers: Headers,
818-
val messages: List<Output>,
819-
val code: Code,
820-
val trailers: Trailers,
821-
val cause: Throwable?,
822-
)
823-
824-
/*
825-
* Convenience method to return all results (with sanity checking) for calls which stream results from the server
826-
* (bidi and server streaming).
827-
*
828-
* This allows us to easily verify headers, messages, trailers, and errors without having to use fold/maybeFold
829-
* manually in each location.
830-
*/
831-
private suspend fun <Output> streamResults(channel: ReceiveChannel<StreamResult<Output>>): ServerStreamingResult<Output> {
832-
val seenHeaders = AtomicBoolean(false)
833-
var headers: Headers = emptyMap()
834-
val messages: MutableList<Output> = mutableListOf()
835-
val seenCompletion = AtomicBoolean(false)
836-
var code: Code = Code.UNKNOWN
837-
var trailers: Headers = emptyMap()
838-
var error: Throwable? = null
839-
for (response in channel) {
840-
response.maybeFold(
841-
onHeaders = {
842-
if (!seenHeaders.compareAndSet(false, true)) {
843-
throw IllegalStateException("multiple onHeaders callbacks")
844-
}
845-
headers = it.headers
846-
},
847-
onMessage = {
848-
messages.add(it.message)
849-
},
850-
onCompletion = {
851-
if (!seenCompletion.compareAndSet(false, true)) {
852-
throw IllegalStateException("multiple onCompletion callbacks")
853-
}
854-
code = it.code
855-
trailers = it.trailers
856-
error = it.cause
857-
},
858-
)
859-
}
860-
if (!seenCompletion.get()) {
861-
throw IllegalStateException("didn't get completion message")
862-
}
863-
return ServerStreamingResult(headers, messages, code, trailers, error)
864-
}
865-
866804
private fun b64Encode(trailingValue: ByteArray): String {
867805
return String(Base64.getEncoder().encode(trailingValue))
868806
}

examples/android/src/main/kotlin/com/connectrpc/examples/android/ElizaChatActivity.kt

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import android.widget.TextView
2121
import androidx.appcompat.app.AppCompatActivity
2222
import androidx.lifecycle.lifecycleScope
2323
import androidx.recyclerview.widget.RecyclerView
24+
import com.connectrpc.ConnectException
2425
import com.connectrpc.ProtocolClientConfig
2526
import com.connectrpc.eliza.v1.ConverseRequest
2627
import com.connectrpc.eliza.v1.ElizaServiceClient
@@ -135,29 +136,26 @@ class ElizaChatActivity : AppCompatActivity() {
135136
lifecycleScope.launch(Dispatchers.IO) {
136137
// Initialize a bidi stream with Eliza.
137138
val stream = elizaServiceClient.converse()
138-
139-
for (streamResult in stream.resultChannel()) {
140-
streamResult.maybeFold(
141-
onMessage = { result ->
142-
// A stream message is received: Eliza has said something to us.
143-
val elizaResponse = result.message.sentence
144-
if (elizaResponse?.isNotBlank() == true) {
145-
adapter.add(MessageData(elizaResponse, true))
146-
} else {
147-
// Something odd occurred.
148-
adapter.add(MessageData("...No response from Eliza...", true))
149-
}
150-
},
151-
onCompletion = {
152-
// This should only be called once.
153-
adapter.add(
154-
MessageData(
155-
"Session has ended.",
156-
true,
157-
),
158-
)
159-
},
139+
try {
140+
for (message in stream.responseChannel()) {
141+
// A stream message is received: Eliza has said something to us.
142+
val elizaResponse = message.sentence
143+
if (elizaResponse?.isNotBlank() == true) {
144+
adapter.add(MessageData(elizaResponse, true))
145+
} else {
146+
// Something odd occurred.
147+
adapter.add(MessageData("...No response from Eliza...", true))
148+
}
149+
}
150+
// This should only be called once.
151+
adapter.add(
152+
MessageData(
153+
"Session has ended.",
154+
true,
155+
),
160156
)
157+
} catch (e: ConnectException) {
158+
adapter.add(MessageData("Session failed with code ${e.code}", true))
161159
}
162160
lifecycleScope.launch(Dispatchers.Main) {
163161
buttonView.setOnClickListener {

examples/kotlin-google-java/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
package com.connectrpc.examples.kotlin
1616

17-
import com.connectrpc.Code
1817
import com.connectrpc.ConnectException
1918
import com.connectrpc.ProtocolClientConfig
2019
import com.connectrpc.eliza.v1.ElizaServiceClient
@@ -63,23 +62,13 @@ class Main {
6362
// Add the message the user is sending to the views.
6463
stream.send(converseRequest { sentence = "hello" })
6564
stream.sendClose()
66-
for (streamResult in stream.resultChannel()) {
67-
streamResult.maybeFold(
68-
onMessage = { result ->
69-
// Update the view with the response.
70-
val elizaResponse = result.message
71-
println(elizaResponse.sentence)
72-
},
73-
onCompletion = { result ->
74-
if (result.code != Code.OK) {
75-
val exception = result.connectException()
76-
if (exception != null) {
77-
throw exception
78-
}
79-
throw ConnectException(code = result.code, metadata = result.trailers)
80-
}
81-
},
82-
)
65+
try {
66+
for (streamResult in stream.responseChannel()) {
67+
// Update the view with the response.
68+
val elizaResponse = streamResult
69+
println(elizaResponse.sentence)
70+
}
71+
} catch (e: ConnectException) {
8372
}
8473
}
8574
}

examples/kotlin-google-javalite/src/main/kotlin/com/connectrpc/examples/kotlin/Main.kt

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
package com.connectrpc.examples.kotlin
1616

17-
import com.connectrpc.Code
18-
import com.connectrpc.ConnectException
1917
import com.connectrpc.ProtocolClientConfig
2018
import com.connectrpc.eliza.v1.ElizaServiceClient
2119
import com.connectrpc.eliza.v1.converseRequest
@@ -34,7 +32,7 @@ class Main {
3432
@JvmStatic
3533
fun main(args: Array<String>) {
3634
runBlocking {
37-
val host = "https://demo.connectrpc.com"
35+
val host = "https://demo.connectrpc.com:444"
3836
val okHttpClient = OkHttpClient()
3937
.newBuilder()
4038
.readTimeout(Duration.ofMinutes(10))
@@ -63,23 +61,8 @@ class Main {
6361
// Add the message the user is sending to the views.
6462
stream.send(converseRequest { sentence = "hello" })
6563
stream.sendClose()
66-
for (streamResult in stream.resultChannel()) {
67-
streamResult.maybeFold(
68-
onMessage = { result ->
69-
// Update the view with the response.
70-
val elizaResponse = result.message
71-
println(elizaResponse.sentence)
72-
},
73-
onCompletion = { result ->
74-
if (result.code != Code.OK) {
75-
val exception = result.connectException()
76-
if (exception != null) {
77-
throw exception
78-
}
79-
throw ConnectException(code = result.code, metadata = result.trailers)
80-
}
81-
},
82-
)
64+
for (response in stream.responseChannel()) {
65+
println(response.sentence)
8366
}
8467
}
8568
}

library/src/main/kotlin/com/connectrpc/BidirectionalStreamInterface.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ import kotlinx.coroutines.channels.ReceiveChannel
2121
*/
2222
interface BidirectionalStreamInterface<Input, Output> {
2323
/**
24-
* The Channel for received StreamResults.
24+
* The Channel for responses.
2525
*
26-
* @return ReceiveChannel for iterating over the received results.
26+
* @return ReceiveChannel for iterating over the responses.
2727
*/
28-
fun resultChannel(): ReceiveChannel<StreamResult<Output>>
28+
fun responseChannel(): ReceiveChannel<Output>
2929

3030
/**
3131
* Send a request to the server over the stream.

library/src/main/kotlin/com/connectrpc/ClientOnlyStreamInterface.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ interface ClientOnlyStreamInterface<Input, Output> {
2929
/**
3030
* Receive a single response and close the stream.
3131
*
32-
* @return the single response [ResponseMessage].
32+
* @return the single response [Output].
3333
*/
34-
suspend fun receiveAndClose(): ResponseMessage<Output>
34+
suspend fun receiveAndClose(): Output
3535

3636
/**
3737
* Close the stream. No calls to [send] are valid after calling [sendClose].

library/src/main/kotlin/com/connectrpc/ServerOnlyStreamInterface.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package com.connectrpc
1616

1717
import kotlinx.coroutines.channels.ReceiveChannel
18+
1819
/**
1920
* Represents a server-only stream (a stream where the server streams data to the client after
2021
* receiving an initial request) that can send request messages.
@@ -25,7 +26,7 @@ interface ServerOnlyStreamInterface<Input, Output> {
2526
*
2627
* @return ReceiveChannel for iterating over the received results.
2728
*/
28-
fun resultChannel(): ReceiveChannel<StreamResult<Output>>
29+
fun responseChannel(): ReceiveChannel<Output>
2930

3031
/**
3132
* Send a request to the server over the stream and closes the request.

0 commit comments

Comments
 (0)