Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide way to auto-close streams #213

Merged
merged 3 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ import com.connectrpc.okhttp.ConnectOkHttpClient
import com.connectrpc.protocols.GETConfiguration
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.tls.HandshakeCertificates
import okhttp3.tls.HeldCertificate
import java.security.KeyFactory
Expand Down Expand Up @@ -153,7 +151,7 @@ class Client(
private suspend fun <Req : MessageLite, Resp : MessageLite> handleClient(
client: ClientStreamClient<Req, Resp>,
req: ClientCompatRequest,
): ClientResponseResult = coroutineScope {
): ClientResponseResult {
if (req.streamType != StreamType.CLIENT_STREAM) {
throw RuntimeException("specified method ${req.method} is client-stream but stream type indicates ${req.streamType}")
}
Expand All @@ -163,8 +161,7 @@ class Client(
) {
throw RuntimeException("client stream calls can only support `BeforeCloseSend` and 'AfterCloseSendMs' cancellation field, instead got ${req.cancel!!::class.simpleName}")
}
val stream = client.execute(req.requestHeaders)
try {
return client.execute(req.requestHeaders) { stream ->
var numUnsent = 0
for (i in req.requestMessages.indices) {
if (req.requestDelayMs > 0) {
Expand All @@ -184,22 +181,20 @@ class Client(
}
when (val cancel = req.cancel) {
is Cancel.BeforeCloseSend -> {
stream.cancel()
stream.close()
}
is Cancel.AfterCloseSendMs -> {
launch {
delay(cancel.millis.toLong())
stream.cancel()
stream.close()
}
}
else -> {
// We already validated the case above.
// So this case means no cancellation.
}
}
return@coroutineScope unaryResult(numUnsent, stream.closeAndReceive())
} finally {
stream.cancel()
unaryResult(numUnsent, stream.closeAndReceive())
}
}

Expand All @@ -220,37 +215,34 @@ class Client(
throw RuntimeException("server stream calls can only support `AfterCloseSendMs` and 'AfterNumResponses' cancellation field, instead got ${req.cancel!!::class.simpleName}")
}
val msg = fromAny(req.requestMessages[0], client.reqTemplate, SERVER_STREAM_REQUEST_NAME)
val stream: ResponseStream<Resp>
var sent = false
try {
// TODO: should this throw? Maybe not...
// An alternative would be to have it return a
// stream that throws the relevant exception in
// calls to receive.
stream = client.execute(msg, req.requestHeaders)
return client.execute(msg, req.requestHeaders) { stream ->
sent = true
val cancel = req.cancel
if (cancel is Cancel.AfterCloseSendMs) {
delay(cancel.millis.toLong())
stream.close()
}
streamResult(0, stream, cancel)
}
} catch (ex: Throwable) {
val connEx = if (ex is ConnectException) {
ex
} else {
ConnectException(
code = Code.UNKNOWN,
message = ex.message,
exception = ex,
if (!sent) {
val connEx = if (ex is ConnectException) {
ex
} else {
ConnectException(
code = Code.UNKNOWN,
message = ex.message,
exception = ex,
)
}
return ClientResponseResult(
error = connEx,
numUnsentRequests = 1,
)
}
return ClientResponseResult(
error = connEx,
numUnsentRequests = 1,
)
}
try {
val cancel = req.cancel
if (cancel is Cancel.AfterCloseSendMs) {
delay(cancel.millis.toLong())
stream.close()
}
return streamResult(0, stream, cancel)
} finally {
stream.close()
throw ex
}
}

Expand All @@ -272,8 +264,7 @@ class Client(
client: BidiStreamClient<Req, Resp>,
req: ClientCompatRequest,
): ClientResponseResult {
val stream = client.execute(req.requestHeaders)
try {
return client.execute(req.requestHeaders) { stream ->
var numUnsent = 0
for (i in req.requestMessages.indices) {
if (req.requestDelayMs > 0) {
Expand All @@ -294,30 +285,27 @@ class Client(
val cancel = req.cancel
when (cancel) {
is Cancel.BeforeCloseSend -> {
stream.responses.close() // cancel
stream.close() // cancel
stream.requests.close() // close send
}
is Cancel.AfterCloseSendMs -> {
stream.requests.close() // close send
delay(cancel.millis.toLong())
stream.responses.close() // cancel
stream.close() // cancel
}
else -> {
stream.requests.close() // close send
}
}
return streamResult(numUnsent, stream.responses, cancel)
} finally {
stream.responses.close()
streamResult(numUnsent, stream.responses, cancel)
}
}

private suspend fun <Req : MessageLite, Resp : MessageLite> handleFullDuplexBidi(
client: BidiStreamClient<Req, Resp>,
req: ClientCompatRequest,
): ClientResponseResult {
val stream = client.execute(req.requestHeaders)
try {
return client.execute(req.requestHeaders) { stream ->
val cancel = req.cancel
val payloads: MutableList<MessageLite> = mutableListOf()
for (i in req.requestMessages.indices) {
Expand All @@ -338,16 +326,16 @@ class Client(
// In full-duplex mode, we read the response after writing request,
// to interleave the requests and responses.
if (i == 0 && cancel is Cancel.AfterNumResponses && cancel.num == 0) {
stream.responses.close()
stream.close()
}
try {
val resp = stream.responses.messages.receive()
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.responses.close()
stream.close()
}
} catch (ex: ConnectException) {
return ClientResponseResult(
return@execute ClientResponseResult(
headers = stream.responses.headers(),
payloads = payloads,
error = ex,
Expand All @@ -358,13 +346,13 @@ class Client(
}
when (cancel) {
is Cancel.BeforeCloseSend -> {
stream.responses.close() // cancel
stream.close() // cancel
stream.requests.close() // close send
}
is Cancel.AfterCloseSendMs -> {
stream.requests.close() // close send
delay(cancel.millis.toLong())
stream.responses.close() // cancel
stream.close() // cancel
}
else -> {
stream.requests.close() // close send
Expand All @@ -378,22 +366,20 @@ class Client(
for (resp in stream.responses.messages) {
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.responses.close()
stream.close()
}
}
trailers = stream.responses.trailers()
} catch (ex: ConnectException) {
connEx = ex
trailers = ex.metadata
}
return ClientResponseResult(
ClientResponseResult(
headers = stream.responses.headers(),
payloads = payloads,
error = connEx,
trailers = trailers,
)
} finally {
stream.responses.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package com.connectrpc.conformance.client.adapt
import com.connectrpc.BidirectionalStreamInterface
import com.connectrpc.Headers
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope

/**
* The client of a bidi-stream RPC operation. A bidi-stream
Expand All @@ -35,17 +37,33 @@ abstract class BidiStreamClient<Req : MessageLite, Resp : MessageLite>(
val reqTemplate: Req,
val respTemplate: Resp,
) {
abstract suspend fun execute(headers: Headers): BidiStream<Req, Resp>
/**
* Executes the bidirectional-stream call inside the given block.
* The block is used to send requests and receive responses. The
* stream is automatically closed when the block returns or throws.
*/
suspend fun <R> execute(
headers: Headers,
block: suspend CoroutineScope.(BidiStream<Req, Resp>) -> R,
): R {
val stream = execute(headers)
return stream.use {
coroutineScope { block(this, it) }
}
}

protected abstract suspend fun execute(headers: Headers): BidiStream<Req, Resp>

/**
* A BidiStream combines a request stream and a response stream.
*
* @param Req The request message type
* @param Resp The response message type
*/
interface BidiStream<Req : MessageLite, Resp : MessageLite> {
interface BidiStream<Req : MessageLite, Resp : MessageLite> : Closeable {
val requests: RequestStream<Req>
val responses: ResponseStream<Resp>

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: BidirectionalStreamInterface<Req, Resp>): BidiStream<Req, Resp> {
val reqStream = RequestStream.new(underlying)
Expand All @@ -56,6 +74,10 @@ abstract class BidiStreamClient<Req : MessageLite, Resp : MessageLite>(

override val responses: ResponseStream<Resp>
get() = respStream

override suspend fun close() {
responses.close()
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import com.connectrpc.ConnectException
import com.connectrpc.Headers
import com.connectrpc.ResponseMessage
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope

/**
* The client of a client-stream RPC operation. A client-stream
Expand All @@ -34,7 +36,22 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
val reqTemplate: Req,
val respTemplate: Resp,
) {
abstract suspend fun execute(headers: Headers): ClientStream<Req, Resp>
/**
* Executes the client-stream call inside the given block. The block
* is used to send the requests and then retrieve the responses. The
* stream is automatically closed when the block returns or throws.
*/
suspend fun <R> execute(
headers: Headers,
block: suspend CoroutineScope.(ClientStream<Req, Resp>) -> R,
): R {
val stream = execute(headers)
return stream.use {
coroutineScope { block(this, it) }
}
}

protected abstract suspend fun execute(headers: Headers): ClientStream<Req, Resp>

/**
* A ClientStream is just like a RequestStream, except that closing
Expand All @@ -43,10 +60,9 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
* @param Req The request message type
* @param Resp The response message type
*/
interface ClientStream<Req : MessageLite, Resp : MessageLite> {
interface ClientStream<Req : MessageLite, Resp : MessageLite> : Closeable {
suspend fun send(req: Req)
suspend fun closeAndReceive(): ResponseMessage<Resp>
suspend fun cancel()

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: ClientOnlyStreamInterface<Req, Resp>): ClientStream<Req, Resp> {
Expand Down Expand Up @@ -83,7 +99,7 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
}
}

override suspend fun cancel() {
override suspend fun close() {
underlying.cancel()
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2022-2023 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.connectrpc.conformance.client.adapt

// Like java.io.Closeable, but the close operation is suspendable.
interface Closeable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all looks great - might just want to name this SuspendCloseable or AsyncCloseable or something like that so people aren't confused with multiple Closeable types.

suspend fun close()
}

// Like the standard kotlin "use" extension function, but uses
// a suspending Closeable instead of java.io.Closeable and accepts
// a suspending block.
internal suspend fun <T : Closeable, R> T.use(block: suspend (T) -> R): R {
var exception: Throwable? = null
try {
return block(this)
} catch (ex: Throwable) {
exception = ex
throw exception
} finally {
try {
this.close()
} catch (ex: Throwable) {
if (exception != null) {
exception.addSuppressed(ex)
} else {
throw ex
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ package com.connectrpc.conformance.client.adapt
/**
* An RPC stub that allows for invoking RPC methods.
* Each method of Invoker corresponds to an RPC method
* and returns a client stub that can be used to actually
* invoke that RPC.
* of the conformance service and returns a client
* object that can be used to actually invoke that RPC.
*/
interface Invoker {
fun unaryClient(): UnaryClient<*, *>
Expand Down
Loading
Loading