diff --git a/library/src/main/kotlin/com/connectrpc/UnaryBlockingCall.kt b/library/src/main/kotlin/com/connectrpc/UnaryBlockingCall.kt index b27c55ce..f69701d3 100644 --- a/library/src/main/kotlin/com/connectrpc/UnaryBlockingCall.kt +++ b/library/src/main/kotlin/com/connectrpc/UnaryBlockingCall.kt @@ -14,57 +14,18 @@ package com.connectrpc -import java.util.concurrent.CountDownLatch -import java.util.concurrent.atomic.AtomicReference - /** * A [UnaryBlockingCall] contains the way to make a blocking RPC call and cancelling the RPC. */ -class UnaryBlockingCall { - private var executable: ((ResponseMessage) -> Unit) -> Unit = { } - private var cancelFn: () -> Unit = { } - +interface UnaryBlockingCall { /** - * Execute the underlying request. - * Subsequent calls will create a new request. + * Execute the underlying request. Can only be called once. + * Subsequent calls will throw IllegalStateException. */ - fun execute(): ResponseMessage { - val countDownLatch = CountDownLatch(1) - val reference = AtomicReference>() - executable { responseMessage -> - reference.set(responseMessage) - countDownLatch.countDown() - } - countDownLatch.await() - return reference.get() - } + fun execute(): ResponseMessage /** * Cancel the underlying request. */ - fun cancel() { - cancelFn() - } - - /** - * Gives the blocking call a cancellation function to cancel the - * underlying request. - * - * @param cancel The function to call in order to cancel the - * underlying request. - */ - internal fun setCancel(cancel: () -> Unit) { - this.cancelFn = cancel - } - - /** - * Gives the blocking call the execution function to initiate - * the underlying request. - * - * @param executable The function to call in order to initiate - * a request. - */ - internal fun setExecute(executable: ((ResponseMessage) -> Unit) -> Unit) { - this.executable = executable - } + fun cancel() } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 78e3c540..bd0f234e 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -41,7 +41,6 @@ import kotlinx.coroutines.suspendCancellableCoroutine import kotlinx.coroutines.withContext import okio.Buffer import java.net.URI -import java.util.concurrent.CountDownLatch import kotlin.coroutines.resume /** @@ -167,18 +166,9 @@ class ProtocolClient( headers: Headers, methodSpec: MethodSpec, ): UnaryBlockingCall { - val countDownLatch = CountDownLatch(1) - val call = UnaryBlockingCall() - // Set the unary synchronous executable. - call.setExecute { callback: (ResponseMessage) -> Unit -> - val cancellationFn = unary(request, headers, methodSpec) { responseMessage -> - callback(responseMessage) - countDownLatch.countDown() - } - // Set the cancellation function . - call.setCancel(cancellationFn) + return UnaryCall { callback -> + unary(request, headers, methodSpec, callback) } - return call } override suspend fun serverStream( diff --git a/library/src/main/kotlin/com/connectrpc/impl/UnaryCall.kt b/library/src/main/kotlin/com/connectrpc/impl/UnaryCall.kt new file mode 100644 index 00000000..a093d06b --- /dev/null +++ b/library/src/main/kotlin/com/connectrpc/impl/UnaryCall.kt @@ -0,0 +1,85 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.connectrpc.impl + +import com.connectrpc.ResponseMessage +import com.connectrpc.UnaryBlockingCall +import com.connectrpc.http.Cancelable +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference + +/** + * Callback that handles asynchronous response. + */ +internal typealias ResponseCallback = (ResponseMessage) -> Unit + +/** + * Represents a cancelable asynchronous operation. When the function + * is invoked, the operation is initiated. When that operation completes + * it MUST invoke the callback, even when canceled. The value returned + * from the function can be called to abort the operation and have it + * return early. + */ +internal typealias AsyncOperation = (callback: ResponseCallback) -> Cancelable + +/** + * Concrete implementation of [UnaryBlockingCall] which transforms + * the given async operation into a synchronous/blocking one. + */ +internal class UnaryCall( + private val block: AsyncOperation, +) : UnaryBlockingCall { + private val executed = AtomicBoolean() + + /** + * initialized to null and then replaced with non-null + * function when [execute] or [cancel] is called. + */ + private var cancelFunc = AtomicReference() + + /** + * Execute the underlying operation and block until it completes. + */ + override fun execute(): ResponseMessage { + check(executed.compareAndSet(false, true)) { "already executed" } + + val resultReady = CountDownLatch(1) + val result = AtomicReference>() + val cancelFn = block { responseMessage -> + result.set(responseMessage) + resultReady.countDown() + } + + if (!cancelFunc.compareAndSet(null, cancelFn)) { + // concurrently cancelled before we could set the + // cancel function, so we need to cancel what we + // just started + cancelFn() + } + resultReady.await() + return result.get() + } + + /** + * Cancel the underlying request. + */ + override fun cancel() { + val cancelFn = cancelFunc.getAndSet {} // set to (non-null) no-op + if (cancelFn != null) { + cancelFn() + } + } +} diff --git a/library/src/test/kotlin/com/connectrpc/impl/UnaryCallTest.kt b/library/src/test/kotlin/com/connectrpc/impl/UnaryCallTest.kt new file mode 100644 index 00000000..99d77a29 --- /dev/null +++ b/library/src/test/kotlin/com/connectrpc/impl/UnaryCallTest.kt @@ -0,0 +1,106 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.connectrpc.impl + +import com.connectrpc.Code +import com.connectrpc.ConnectException +import com.connectrpc.ResponseMessage +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors + +class UnaryCallTest { + @Test + fun testExecute() { + val executor = Executors.newSingleThreadExecutor() + try { + val result = Object() + val call = UnaryCall { callback -> + executor.execute { + callback.invoke( + ResponseMessage.Success( + result, + headers = emptyMap(), + trailers = emptyMap(), + ), + ) + } + return@UnaryCall { } + } + val resp = call.execute() + assertThat(resp).isInstanceOf(ResponseMessage.Success::class.java) + val msg = resp.success { it.message }!! + assertThat(msg).isEqualTo(result) + } finally { + assertThat(executor.shutdownNow()).isEmpty() + } + } + + @Test + fun testCancelAfterExecute() { + testCancel(false) + } + + @Test + fun testCancelBeforeExecute() { + testCancel(true) + } + + private fun testCancel(cancelFirst: Boolean) { + val executor = Executors.newFixedThreadPool(2) + try { + // Indicates when the async task has begun. + val taskRunning = CountDownLatch(1) + // Indicates when the async task has been canceled. + val taskCanceled = CountDownLatch(1) + + val call = UnaryCall { callback -> + executor.execute { + taskRunning.countDown() + taskCanceled.await() + callback.invoke( + ResponseMessage.Failure( + headers = emptyMap(), + trailers = emptyMap(), + cause = ConnectException(code = Code.CANCELED), + ), + ) + } + return@UnaryCall { + taskCanceled.countDown() + } + } + if (cancelFirst) { + // When we execute the task below, the call will observe + // that it has already been canceled and immediately + // cancel the just-started task. + call.cancel() + } else { + // This will cancel the task right after it has started running. + executor.execute { + taskRunning.await() + call.cancel() + } + } + val resp = call.execute() + assertThat(resp).isInstanceOf(ResponseMessage.Failure::class.java) + val connEx = resp.failure { it.cause }!! + assertThat(connEx.code).isEqualTo(Code.CANCELED) + } finally { + assertThat(executor.shutdownNow()).isEmpty() + } + } +}