Skip to content

Commit

Permalink
provide way to auto-close streams
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump committed Jan 31, 2024
1 parent 328110c commit c576586
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import com.connectrpc.conformance.client.adapt.Invoker
import com.connectrpc.conformance.client.adapt.ResponseStream
import com.connectrpc.conformance.client.adapt.ServerStreamClient
import com.connectrpc.conformance.client.adapt.UnaryClient
import com.connectrpc.conformance.client.adapt.execute
import com.connectrpc.http.HTTPClientInterface
import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
Expand All @@ -46,7 +47,6 @@ import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.tls.HandshakeCertificates
import okhttp3.tls.HeldCertificate
import java.security.KeyFactory
Expand Down Expand Up @@ -163,8 +163,7 @@ class Client(
) {
throw RuntimeException("client stream calls can only support `BeforeCloseSend` and 'AfterCloseSendMs' cancellation field, instead got ${req.cancel!!::class.simpleName}")
}
val stream = client.execute(req.requestHeaders)
try {
client.execute(req.requestHeaders) { stream ->
var numUnsent = 0
for (i in req.requestMessages.indices) {
if (req.requestDelayMs > 0) {
Expand All @@ -184,22 +183,20 @@ class Client(
}
when (val cancel = req.cancel) {
is Cancel.BeforeCloseSend -> {
stream.cancel()
stream.close()
}
is Cancel.AfterCloseSendMs -> {
launch {
delay(cancel.millis.toLong())
stream.cancel()
stream.close()
}
}
else -> {
// We already validated the case above.
// So this case means no cancellation.
}
}
return@coroutineScope unaryResult(numUnsent, stream.closeAndReceive())
} finally {
stream.cancel()
unaryResult(numUnsent, stream.closeAndReceive())
}
}

Expand All @@ -220,37 +217,34 @@ class Client(
throw RuntimeException("server stream calls can only support `AfterCloseSendMs` and 'AfterNumResponses' cancellation field, instead got ${req.cancel!!::class.simpleName}")
}
val msg = fromAny(req.requestMessages[0], client.reqTemplate, SERVER_STREAM_REQUEST_NAME)
val stream: ResponseStream<Resp>
var sent = false
try {
// TODO: should this throw? Maybe not...
// An alternative would be to have it return a
// stream that throws the relevant exception in
// calls to receive.
stream = client.execute(msg, req.requestHeaders)
return client.execute(msg, req.requestHeaders) { stream ->
sent = true
val cancel = req.cancel
if (cancel is Cancel.AfterCloseSendMs) {
delay(cancel.millis.toLong())
stream.close()
}
streamResult(0, stream, cancel)
}
} catch (ex: Throwable) {
val connEx = if (ex is ConnectException) {
ex
} else {
ConnectException(
code = Code.UNKNOWN,
message = ex.message,
exception = ex,
if (!sent) {
val connEx = if (ex is ConnectException) {
ex
} else {
ConnectException(
code = Code.UNKNOWN,
message = ex.message,
exception = ex,
)
}
return ClientResponseResult(
error = connEx,
numUnsentRequests = 1,
)
}
return ClientResponseResult(
error = connEx,
numUnsentRequests = 1,
)
}
try {
val cancel = req.cancel
if (cancel is Cancel.AfterCloseSendMs) {
delay(cancel.millis.toLong())
stream.close()
}
return streamResult(0, stream, cancel)
} finally {
stream.close()
throw ex
}
}

Expand All @@ -272,8 +266,7 @@ class Client(
client: BidiStreamClient<Req, Resp>,
req: ClientCompatRequest,
): ClientResponseResult {
val stream = client.execute(req.requestHeaders)
try {
return client.execute(req.requestHeaders) { stream ->
var numUnsent = 0
for (i in req.requestMessages.indices) {
if (req.requestDelayMs > 0) {
Expand All @@ -294,30 +287,27 @@ class Client(
val cancel = req.cancel
when (cancel) {
is Cancel.BeforeCloseSend -> {
stream.responses.close() // cancel
stream.close() // cancel
stream.requests.close() // close send
}
is Cancel.AfterCloseSendMs -> {
stream.requests.close() // close send
delay(cancel.millis.toLong())
stream.responses.close() // cancel
stream.close() // cancel
}
else -> {
stream.requests.close() // close send
}
}
return streamResult(numUnsent, stream.responses, cancel)
} finally {
stream.responses.close()
streamResult(numUnsent, stream.responses, cancel)
}
}

private suspend fun <Req : MessageLite, Resp : MessageLite> handleFullDuplexBidi(
client: BidiStreamClient<Req, Resp>,
req: ClientCompatRequest,
): ClientResponseResult {
val stream = client.execute(req.requestHeaders)
try {
return client.execute(req.requestHeaders) { stream ->
val cancel = req.cancel
val payloads: MutableList<MessageLite> = mutableListOf()
for (i in req.requestMessages.indices) {
Expand All @@ -338,16 +328,16 @@ class Client(
// In full-duplex mode, we read the response after writing request,
// to interleave the requests and responses.
if (i == 0 && cancel is Cancel.AfterNumResponses && cancel.num == 0) {
stream.responses.close()
stream.close()
}
try {
val resp = stream.responses.messages.receive()
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.responses.close()
stream.close()
}
} catch (ex: ConnectException) {
return ClientResponseResult(
return@execute ClientResponseResult(
headers = stream.responses.headers(),
payloads = payloads,
error = ex,
Expand All @@ -358,13 +348,13 @@ class Client(
}
when (cancel) {
is Cancel.BeforeCloseSend -> {
stream.responses.close() // cancel
stream.close() // cancel
stream.requests.close() // close send
}
is Cancel.AfterCloseSendMs -> {
stream.requests.close() // close send
delay(cancel.millis.toLong())
stream.responses.close() // cancel
stream.close() // cancel
}
else -> {
stream.requests.close() // close send
Expand All @@ -378,22 +368,20 @@ class Client(
for (resp in stream.responses.messages) {
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.responses.close()
stream.close()
}
}
trailers = stream.responses.trailers()
} catch (ex: ConnectException) {
connEx = ex
trailers = ex.metadata
}
return ClientResponseResult(
ClientResponseResult(
headers = stream.responses.headers(),
payloads = payloads,
error = connEx,
trailers = trailers,
)
} finally {
stream.responses.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ abstract class BidiStreamClient<Req : MessageLite, Resp : MessageLite>(
* @param Req The request message type
* @param Resp The response message type
*/
interface BidiStream<Req : MessageLite, Resp : MessageLite> {
interface BidiStream<Req : MessageLite, Resp : MessageLite> : Closeable {
val requests: RequestStream<Req>
val responses: ResponseStream<Resp>

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: BidirectionalStreamInterface<Req, Resp>): BidiStream<Req, Resp> {
val reqStream = RequestStream.new(underlying)
Expand All @@ -56,8 +57,25 @@ abstract class BidiStreamClient<Req : MessageLite, Resp : MessageLite>(

override val responses: ResponseStream<Resp>
get() = respStream

override suspend fun close() {
responses.close()
}
}
}
}
}
}

/**
* Executes the bidirectional-stream call inside the given block.
* The block is used to send requests and receive responses. The
* stream is automatically closed when the block returns or throws.
*/
suspend fun <Req : MessageLite, Resp : MessageLite, R> BidiStreamClient<Req, Resp>.execute(
headers: Headers,
block: suspend (BidiStreamClient.BidiStream<Req, Resp>) -> R,
): R {
val stream = execute(headers)
return stream.use(block)
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
* @param Req The request message type
* @param Resp The response message type
*/
interface ClientStream<Req : MessageLite, Resp : MessageLite> {
interface ClientStream<Req : MessageLite, Resp : MessageLite> : Closeable {
suspend fun send(req: Req)
suspend fun closeAndReceive(): ResponseMessage<Resp>
suspend fun cancel()

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: ClientOnlyStreamInterface<Req, Resp>): ClientStream<Req, Resp> {
Expand Down Expand Up @@ -83,11 +82,24 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
}
}

override suspend fun cancel() {
override suspend fun close() {
underlying.cancel()
}
}
}
}
}
}

/**
* Executes the client-stream call inside the given block. The block
* is used to send the requests and then retrieve the responses. The
* stream is automatically closed when the block returns or throws.
*/
suspend fun <Req : MessageLite, Resp : MessageLite, R> ClientStreamClient<Req, Resp>.execute(
headers: Headers,
block: suspend (ClientStreamClient.ClientStream<Req, Resp>) -> R,
): R {
val stream = execute(headers)
return stream.use(block)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.connectrpc.conformance.client.adapt

// Like java.io.Closeable, but the close operation is suspendable.
interface Closeable {
suspend fun close()
}

// Like the standard kotlin "use" extension function, but uses
// a suspending Closeable instead of java.io.Closeable and accepts
// a suspending block.
internal suspend fun <T : Closeable, R> T.use(block: suspend (T) -> R): R {
var exception: Throwable? = null
try {
return block(this)
} catch (ex: Throwable) {
exception = ex
throw exception
} finally {
try {
this.close()
} catch (ex: Throwable) {
if (exception != null) {
exception.addSuppressed(ex)
} else {
throw ex
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ package com.connectrpc.conformance.client.adapt
/**
* An RPC stub that allows for invoking RPC methods.
* Each method of Invoker corresponds to an RPC method
* and returns a client stub that can be used to actually
* invoke that RPC.
* of the conformance service and returns a client
* object that can be used to actually invoke that RPC.
*/
interface Invoker {
fun unaryClient(): UnaryClient<*, *>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,20 @@ import com.google.protobuf.MessageLite
* RequestStream is a stream that allows a client to upload
* zero or more request messages. When the client is done
* sending messages, it must close the stream.
*
* Note that closing the request stream is not strictly
* required if the RPC is cancelled or fails prematurely
* or if the response stream is closed first. Closing the
* requests "half-closes" the stream; closing the responses
* "fully closes" it.
*/
interface RequestStream<Req : MessageLite> {
interface RequestStream<Req : MessageLite> : Closeable {
/**
* Sends a message on the stream.
* @throws Exception when the request cannot be sent
* because of an error with the streaming call
*/
suspend fun send(req: Req)
fun close()

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: BidirectionalStreamInterface<Req, Resp>): RequestStream<Req> {
Expand All @@ -36,7 +46,7 @@ interface RequestStream<Req : MessageLite> {
}
}

override fun close() {
override suspend fun close() {
underlying.sendClose()
}
}
Expand Down
Loading

0 comments on commit c576586

Please sign in to comment.