@@ -16,15 +16,13 @@ package com.connectrpc.conformance
16
16
17
17
import com.connectrpc.Code
18
18
import com.connectrpc.ConnectException
19
- import com.connectrpc.Headers
20
19
import com.connectrpc.ProtocolClientConfig
21
20
import com.connectrpc.RequestCompression
22
- import com.connectrpc.StreamResult
23
- import com.connectrpc.Trailers
24
21
import com.connectrpc.compression.GzipCompressionPool
25
22
import com.connectrpc.conformance.ssl.sslContext
26
23
import com.connectrpc.conformance.v1.ErrorDetail
27
24
import com.connectrpc.conformance.v1.PayloadType
25
+ import com.connectrpc.conformance.v1.StreamingOutputCallResponse
28
26
import com.connectrpc.conformance.v1.TestServiceClient
29
27
import com.connectrpc.conformance.v1.UnimplementedServiceClient
30
28
import com.connectrpc.conformance.v1.echoStatus
@@ -43,7 +41,6 @@ import com.google.protobuf.ByteString
43
41
import com.google.protobuf.empty
44
42
import kotlinx.coroutines.Dispatchers
45
43
import kotlinx.coroutines.async
46
- import kotlinx.coroutines.channels.ReceiveChannel
47
44
import kotlinx.coroutines.launch
48
45
import kotlinx.coroutines.runBlocking
49
46
import kotlinx.coroutines.withContext
@@ -63,7 +60,6 @@ import java.time.Duration
63
60
import java.util.Base64
64
61
import java.util.concurrent.CountDownLatch
65
62
import java.util.concurrent.TimeUnit
66
- import java.util.concurrent.atomic.AtomicBoolean
67
63
68
64
@RunWith(Parameterized ::class )
69
65
class Conformance (
@@ -177,17 +173,18 @@ class Conformance(
177
173
responseParameters + = params
178
174
},
179
175
).getOrThrow()
180
- val results = streamResults(stream.resultChannel())
181
- assertThat(results.cause).isNull()
182
- assertThat(results.code).isEqualTo(Code .OK )
183
- assertThat(results.messages.map { it.payload.type }.toSet()).isEqualTo(setOf (PayloadType .COMPRESSABLE ))
184
- assertThat(results.messages.map { it.payload.body.size() }).isEqualTo(sizes)
176
+ val responses = mutableListOf<StreamingOutputCallResponse >()
177
+ for (response in stream.responseChannel()) {
178
+ responses.add(response)
179
+ }
180
+ assertThat(responses.map { it.payload.type }.toSet()).isEqualTo(setOf (PayloadType .COMPRESSABLE ))
181
+ assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
185
182
}
186
183
187
184
@Test
188
185
fun pingPong (): Unit = runBlocking {
189
186
val stream = testServiceConnectClient.fullDuplexCall()
190
- var readHeaders = false
187
+ val responseChannel = stream.responseChannel()
191
188
listOf (512_000 , 16 , 2_028 , 65_536 ).forEach {
192
189
val param = responseParameters { size = it }
193
190
stream.send(
@@ -196,25 +193,14 @@ class Conformance(
196
193
responseParameters + = param
197
194
},
198
195
).getOrThrow()
199
- if (! readHeaders) {
200
- val headersResult = stream.resultChannel().receive()
201
- assertThat(headersResult).isInstanceOf(StreamResult .Headers ::class .java)
202
- readHeaders = true
203
- }
204
- val result = stream.resultChannel().receive()
205
- assertThat(result).isInstanceOf(StreamResult .Message ::class .java)
206
- val messageResult = result as StreamResult .Message
207
- val payload = messageResult.message.payload
196
+ val response = responseChannel.receive()
197
+ val payload = response.payload
208
198
assertThat(payload.type).isEqualTo(PayloadType .COMPRESSABLE )
209
199
assertThat(payload.body).hasSize(it)
210
200
}
211
201
stream.sendClose()
212
- val results = streamResults(stream.resultChannel())
213
202
// We've already read all the messages
214
- assertThat(results.messages).isEmpty()
215
- assertThat(results.cause).isNull()
216
- assertThat(results.code).isEqualTo(Code .OK )
217
- stream.receiveClose()
203
+ assertThat(responseChannel.receiveCatching().isClosed).isTrue()
218
204
}
219
205
220
206
@Test
@@ -244,15 +230,17 @@ class Conformance(
244
230
val countDownLatch = CountDownLatch (1 )
245
231
withContext(Dispatchers .IO ) {
246
232
val job = async {
233
+ val responses = mutableListOf<StreamingOutputCallResponse >()
247
234
try {
248
- val result = streamResults(stream.resultChannel())
249
- assertThat(result.messages.map { it.payload.body.size() }).isEqualTo(sizes)
250
- assertThat(result.code).isEqualTo(Code .RESOURCE_EXHAUSTED )
251
- assertThat(result.cause).isInstanceOf(ConnectException ::class .java)
252
- val connectException = result.cause as ConnectException
253
- assertThat(connectException.code).isEqualTo(Code .RESOURCE_EXHAUSTED )
254
- assertThat(connectException.message).isEqualTo(" soirée 🎉" )
255
- assertThat(connectException.unpackedDetails(ErrorDetail ::class )).containsExactly(
235
+ for (response in stream.responseChannel()) {
236
+ responses.add(response)
237
+ }
238
+ fail(" expected call to fail with ConnectException" )
239
+ } catch (e: ConnectException ) {
240
+ assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
241
+ assertThat(e.code).isEqualTo(Code .RESOURCE_EXHAUSTED )
242
+ assertThat(e.message).isEqualTo(" soirée 🎉" )
243
+ assertThat(e.unpackedDetails(ErrorDetail ::class )).containsExactly(
256
244
expectedErrorDetail,
257
245
)
258
246
} finally {
@@ -363,10 +351,11 @@ class Conformance(
363
351
withContext(Dispatchers .IO ) {
364
352
val job = launch {
365
353
try {
366
- val result = streamResults(stream.resultChannel())
367
- assertThat(result.cause).isInstanceOf(ConnectException ::class .java)
368
- assertThat(result.code)
369
- .withFailMessage { " Expected Code.DEADLINE_EXCEEDED but got ${result.code} " }
354
+ stream.responseChannel().receive()
355
+ fail(" unexpected ConnectException to be thrown" )
356
+ } catch (e: ConnectException ) {
357
+ assertThat(e.code)
358
+ .withFailMessage { " Expected Code.DEADLINE_EXCEEDED but got ${e.code} " }
370
359
.isEqualTo(Code .DEADLINE_EXCEEDED )
371
360
} finally {
372
361
countDownLatch.countDown()
@@ -437,11 +426,10 @@ class Conformance(
437
426
withContext(Dispatchers .IO ) {
438
427
val job = async {
439
428
try {
440
- val result = streamResults(stream.resultChannel())
441
- assertThat(result.code).isEqualTo(Code .UNIMPLEMENTED )
442
- assertThat(result.cause).isInstanceOf(ConnectException ::class .java)
443
- val exception = result.cause as ConnectException
444
- assertThat(exception.code).isEqualTo(Code .UNIMPLEMENTED )
429
+ stream.responseChannel().receive()
430
+ fail(" expected call to fail with a ConnectException" )
431
+ } catch (e: ConnectException ) {
432
+ assertThat(e.code).isEqualTo(Code .UNIMPLEMENTED )
445
433
} finally {
446
434
countDownLatch.countDown()
447
435
}
@@ -801,8 +789,8 @@ class Conformance(
801
789
withContext(Dispatchers .IO ) {
802
790
val job = async {
803
791
try {
804
- val result = stream.receiveAndClose().getOrThrow ()
805
- assertThat(result .aggregatedPayloadSize).isEqualTo(sum)
792
+ val response = stream.receiveAndClose()
793
+ assertThat(response .aggregatedPayloadSize).isEqualTo(sum)
806
794
} finally {
807
795
countDownLatch.countDown()
808
796
}
@@ -813,56 +801,6 @@ class Conformance(
813
801
}
814
802
}
815
803
816
- private data class ServerStreamingResult <Output >(
817
- val headers : Headers ,
818
- val messages : List <Output >,
819
- val code : Code ,
820
- val trailers : Trailers ,
821
- val cause : Throwable ? ,
822
- )
823
-
824
- /*
825
- * Convenience method to return all results (with sanity checking) for calls which stream results from the server
826
- * (bidi and server streaming).
827
- *
828
- * This allows us to easily verify headers, messages, trailers, and errors without having to use fold/maybeFold
829
- * manually in each location.
830
- */
831
- private suspend fun <Output > streamResults (channel : ReceiveChannel <StreamResult <Output >>): ServerStreamingResult <Output > {
832
- val seenHeaders = AtomicBoolean (false )
833
- var headers: Headers = emptyMap()
834
- val messages: MutableList <Output > = mutableListOf ()
835
- val seenCompletion = AtomicBoolean (false )
836
- var code: Code = Code .UNKNOWN
837
- var trailers: Headers = emptyMap()
838
- var error: Throwable ? = null
839
- for (response in channel) {
840
- response.maybeFold(
841
- onHeaders = {
842
- if (! seenHeaders.compareAndSet(false , true )) {
843
- throw IllegalStateException (" multiple onHeaders callbacks" )
844
- }
845
- headers = it.headers
846
- },
847
- onMessage = {
848
- messages.add(it.message)
849
- },
850
- onCompletion = {
851
- if (! seenCompletion.compareAndSet(false , true )) {
852
- throw IllegalStateException (" multiple onCompletion callbacks" )
853
- }
854
- code = it.code
855
- trailers = it.trailers
856
- error = it.cause
857
- },
858
- )
859
- }
860
- if (! seenCompletion.get()) {
861
- throw IllegalStateException (" didn't get completion message" )
862
- }
863
- return ServerStreamingResult (headers, messages, code, trailers, error)
864
- }
865
-
866
804
private fun b64Encode (trailingValue : ByteArray ): String {
867
805
return String (Base64 .getEncoder().encode(trailingValue))
868
806
}
0 commit comments