From 9cc208852b820332f1e19513b6815413cff47a69 Mon Sep 17 00:00:00 2001 From: Youssef Shoaib Date: Fri, 1 Nov 2024 16:16:50 +0000 Subject: [PATCH] Fix non-local returns in Bracket.kt, and add relevant tests Also hide ResourceScopeImpl --- .../api/arrow-fx-coroutines.api | 15 +--- .../api/arrow-fx-coroutines.klib.api | 13 ++- .../kotlin/arrow/fx/coroutines/Bracket.kt | 84 +++++++------------ .../kotlin/arrow/fx/coroutines/Resource.kt | 17 +--- .../arrow/fx/coroutines/BracketCaseTest.kt | 57 +++++++++++-- 5 files changed, 86 insertions(+), 100 deletions(-) diff --git a/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api b/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api index b805e377794..8842951eb7f 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api +++ b/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api @@ -5,6 +5,8 @@ public final class arrow/fx/coroutines/AcquireStep { public final class arrow/fx/coroutines/BracketKt { public static final fun bracket (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun bracketCase (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun finalizeCase (Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; + public static final fun getErrorOrNull (Larrow/fx/coroutines/ExitCase;)Ljava/lang/Throwable; public static final fun guarantee (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun guaranteeCase (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun onCancel (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -263,19 +265,6 @@ public final class arrow/fx/coroutines/ResourceScope$DefaultImpls { public static fun releaseCase (Larrow/fx/coroutines/ResourceScope;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } -public final class arrow/fx/coroutines/ResourceScopeImpl : arrow/fx/coroutines/ResourceScope { - public fun ()V - public fun autoClose (Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object; - public fun bind (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun cancelAll (Larrow/fx/coroutines/ExitCase;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun install (Ljava/lang/AutoCloseable;)Ljava/lang/AutoCloseable; - public fun install (Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun onClose (Lkotlin/jvm/functions/Function1;)V - public fun onRelease (Lkotlin/jvm/functions/Function2;)V - public fun release (Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun releaseCase (Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; -} - public abstract interface annotation class arrow/fx/coroutines/ScopeDSL : java/lang/annotation/Annotation { } diff --git a/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.klib.api b/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.klib.api index 69074c3bafb..0c0c0cd3592 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.klib.api +++ b/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.klib.api @@ -65,13 +65,6 @@ final class arrow.fx.coroutines/CyclicBarrierCancellationException : kotlin.coro constructor () // arrow.fx.coroutines/CyclicBarrierCancellationException.|(){}[0] } -final class arrow.fx.coroutines/ResourceScopeImpl : arrow.fx.coroutines/ResourceScope { // arrow.fx.coroutines/ResourceScopeImpl|null[0] - constructor () // arrow.fx.coroutines/ResourceScopeImpl.|(){}[0] - - final fun onRelease(kotlin.coroutines/SuspendFunction1) // arrow.fx.coroutines/ResourceScopeImpl.onRelease|onRelease(kotlin.coroutines.SuspendFunction1){}[0] - final suspend fun cancelAll(arrow.fx.coroutines/ExitCase) // arrow.fx.coroutines/ResourceScopeImpl.cancelAll|cancelAll(arrow.fx.coroutines.ExitCase){}[0] -} - sealed class <#A: out kotlin/Any?, #B: out kotlin/Any?, #C: out kotlin/Any?> arrow.fx.coroutines/Race3 { // arrow.fx.coroutines/Race3|null[0] final inline fun <#A1: kotlin/Any?> fold(kotlin/Function1<#A, #A1>, kotlin/Function1<#B, #A1>, kotlin/Function1<#C, #A1>): #A1 // arrow.fx.coroutines/Race3.fold|fold(kotlin.Function1<1:0,0:0>;kotlin.Function1<1:1,0:0>;kotlin.Function1<1:2,0:0>){0§}[0] @@ -153,6 +146,9 @@ sealed class arrow.fx.coroutines/ExitCase { // arrow.fx.coroutines/ExitCase|null final object arrow.fx.coroutines/AcquireStep // arrow.fx.coroutines/AcquireStep|null[0] +final val arrow.fx.coroutines/errorOrNull // arrow.fx.coroutines/errorOrNull|@arrow.fx.coroutines.ExitCase{}errorOrNull[0] + final fun (arrow.fx.coroutines/ExitCase).(): kotlin/Throwable? // arrow.fx.coroutines/errorOrNull.|@arrow.fx.coroutines.ExitCase(){}[0] + final fun <#A: kotlin/Any?> (kotlin.coroutines/SuspendFunction1).arrow.fx.coroutines/asFlow(): kotlinx.coroutines.flow/Flow<#A> // arrow.fx.coroutines/asFlow|asFlow@kotlin.coroutines.SuspendFunction1(){0§}[0] final fun <#A: kotlin/Any?> (kotlinx.coroutines.flow/Flow<#A>).arrow.fx.coroutines/metered(kotlin.time/Duration): kotlinx.coroutines.flow/Flow<#A> // arrow.fx.coroutines/metered|metered@kotlinx.coroutines.flow.Flow<0:0>(kotlin.time.Duration){0§}[0] final fun <#A: kotlin/Any?> (kotlinx.coroutines.flow/Flow<#A>).arrow.fx.coroutines/metered(kotlin/Long): kotlinx.coroutines.flow/Flow<#A> // arrow.fx.coroutines/metered|metered@kotlinx.coroutines.flow.Flow<0:0>(kotlin.Long){0§}[0] @@ -166,6 +162,7 @@ final inline fun <#A: kotlin/Any?, #B: kotlin/Any?> (kotlinx.coroutines.flow/Flo final inline fun <#A: kotlin/Any?, #B: kotlin/Any?> (kotlinx.coroutines.flow/Flow<#A>).arrow.fx.coroutines/parMap(kotlin/Int = ..., crossinline kotlin.coroutines/SuspendFunction2): kotlinx.coroutines.flow/Flow<#B> // arrow.fx.coroutines/parMap|parMap@kotlinx.coroutines.flow.Flow<0:0>(kotlin.Int;kotlin.coroutines.SuspendFunction2){0§;1§}[0] final inline fun <#A: kotlin/Any?, #B: kotlin/Any?> (kotlinx.coroutines.flow/Flow<#A>).arrow.fx.coroutines/parMapNotNullUnordered(kotlin/Int = ..., crossinline kotlin.coroutines/SuspendFunction1<#A, #B?>): kotlinx.coroutines.flow/Flow<#B> // arrow.fx.coroutines/parMapNotNullUnordered|parMapNotNullUnordered@kotlinx.coroutines.flow.Flow<0:0>(kotlin.Int;kotlin.coroutines.SuspendFunction1<0:0,0:1?>){0§;1§}[0] final inline fun <#A: kotlin/Any?, #B: kotlin/Any?> (kotlinx.coroutines.flow/Flow<#A>).arrow.fx.coroutines/parMapUnordered(kotlin/Int = ..., crossinline kotlin.coroutines/SuspendFunction1<#A, #B>): kotlinx.coroutines.flow/Flow<#B> // arrow.fx.coroutines/parMapUnordered|parMapUnordered@kotlinx.coroutines.flow.Flow<0:0>(kotlin.Int;kotlin.coroutines.SuspendFunction1<0:0,0:1>){0§;1§}[0] +final inline fun <#A: kotlin/Any?> arrow.fx.coroutines/finalizeCase(kotlin/Function0<#A>, kotlin/Function1): #A // arrow.fx.coroutines/finalizeCase|finalizeCase(kotlin.Function0<0:0>;kotlin.Function1){0§}[0] final suspend fun <#A: kotlin/Any?, #B: kotlin/Any?, #C: kotlin/Any?> (kotlin.collections/Iterable<#B>).arrow.fx.coroutines/parMapOrAccumulate(kotlin.coroutines/CoroutineContext = ..., kotlin.coroutines/SuspendFunction2, #B, #C>): arrow.core/Either, kotlin.collections/List<#C>> // arrow.fx.coroutines/parMapOrAccumulate|parMapOrAccumulate@kotlin.collections.Iterable<0:1>(kotlin.coroutines.CoroutineContext;kotlin.coroutines.SuspendFunction2,0:1,0:2>){0§;1§;2§}[0] final suspend fun <#A: kotlin/Any?, #B: kotlin/Any?, #C: kotlin/Any?> (kotlin.collections/Iterable<#B>).arrow.fx.coroutines/parMapOrAccumulate(kotlin.coroutines/CoroutineContext = ..., kotlin/Function2<#A, #A, #A>, kotlin.coroutines/SuspendFunction2, #B, #C>): arrow.core/Either<#A, kotlin.collections/List<#C>> // arrow.fx.coroutines/parMapOrAccumulate|parMapOrAccumulate@kotlin.collections.Iterable<0:1>(kotlin.coroutines.CoroutineContext;kotlin.Function2<0:0,0:0,0:0>;kotlin.coroutines.SuspendFunction2,0:1,0:2>){0§;1§;2§}[0] final suspend fun <#A: kotlin/Any?, #B: kotlin/Any?, #C: kotlin/Any?> (kotlin.collections/Iterable<#B>).arrow.fx.coroutines/parMapOrAccumulate(kotlin.coroutines/CoroutineContext = ..., kotlin/Int, kotlin.coroutines/SuspendFunction2, #B, #C>): arrow.core/Either, kotlin.collections/List<#C>> // arrow.fx.coroutines/parMapOrAccumulate|parMapOrAccumulate@kotlin.collections.Iterable<0:1>(kotlin.coroutines.CoroutineContext;kotlin.Int;kotlin.coroutines.SuspendFunction2,0:1,0:2>){0§;1§;2§}[0] @@ -238,4 +235,4 @@ final suspend inline fun <#A: kotlin/Any?> arrow.fx.coroutines/guarantee(kotlin. final suspend inline fun <#A: kotlin/Any?> arrow.fx.coroutines/guaranteeCase(kotlin.coroutines/SuspendFunction0<#A>, crossinline kotlin.coroutines/SuspendFunction1): #A // arrow.fx.coroutines/guaranteeCase|guaranteeCase(kotlin.coroutines.SuspendFunction0<0:0>;kotlin.coroutines.SuspendFunction1){0§}[0] final suspend inline fun <#A: kotlin/Any?> arrow.fx.coroutines/onCancel(kotlin.coroutines/SuspendFunction0<#A>, crossinline kotlin.coroutines/SuspendFunction0): #A // arrow.fx.coroutines/onCancel|onCancel(kotlin.coroutines.SuspendFunction0<0:0>;kotlin.coroutines.SuspendFunction0){0§}[0] final suspend inline fun <#A: kotlin/Any?> arrow.fx.coroutines/resourceScope(kotlin.coroutines/SuspendFunction1): #A // arrow.fx.coroutines/resourceScope|resourceScope(kotlin.coroutines.SuspendFunction1){0§}[0] -final suspend inline fun arrow.fx.coroutines/runReleaseAndRethrow(kotlin/Throwable, crossinline kotlin.coroutines/SuspendFunction0): kotlin/Nothing // arrow.fx.coroutines/runReleaseAndRethrow|runReleaseAndRethrow(kotlin.Throwable;kotlin.coroutines.SuspendFunction0){}[0] +final suspend inline fun arrow.fx.coroutines/runReleaseAndRethrow(kotlin/Throwable?, crossinline kotlin.coroutines/SuspendFunction0) // arrow.fx.coroutines/runReleaseAndRethrow|runReleaseAndRethrow(kotlin.Throwable?;kotlin.coroutines.SuspendFunction0){}[0] diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/Bracket.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/Bracket.kt index 6c87f379a71..fe296f85744 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/Bracket.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/Bracket.kt @@ -20,7 +20,8 @@ public sealed class ExitCase { } } -internal val ExitCase.errorOrNull +@PublishedApi +internal val ExitCase.errorOrNull: Throwable? get() = when (this) { ExitCase.Completed -> null is ExitCase.Cancelled -> exception @@ -66,17 +67,7 @@ public suspend inline fun onCancel( public suspend inline fun guarantee( fa: suspend () -> A, crossinline finalizer: suspend () -> Unit -): A { - val res = try { - fa.invoke() - } catch (e: CancellationException) { - runReleaseAndRethrow(e) { finalizer() } - } catch (t: Throwable) { - runReleaseAndRethrow(t.nonFatalOrThrow()) { finalizer() } - } - withContext(NonCancellable) { finalizer() } - return res -} +): A = guaranteeCase(fa) { finalizer() } /** * Guarantees execution of a given [finalizer] after [fa] regardless of success, error or cancellation, allowing @@ -96,16 +87,8 @@ public suspend inline fun guarantee( public suspend inline fun guaranteeCase( fa: suspend () -> A, crossinline finalizer: suspend (ExitCase) -> Unit -): A { - val res = try { - fa() - } catch (e: CancellationException) { - runReleaseAndRethrow(e) { finalizer(ExitCase.Cancelled(e)) } - } catch (t: Throwable) { - runReleaseAndRethrow(t.nonFatalOrThrow()) { finalizer(ExitCase.Failure(t)) } - } - withContext(NonCancellable) { finalizer(ExitCase.Completed) } - return res +): A = finalizeCase({ fa() }) { ex -> + runReleaseAndRethrow(ex.errorOrNull) { finalizer(ex) } } /** @@ -153,22 +136,7 @@ public suspend inline fun bracket( crossinline acquire: suspend () -> A, use: suspend (A) -> B, crossinline release: suspend (A) -> Unit -): B { - val acquired = withContext(NonCancellable) { - acquire() - } - - val res = try { - use(acquired) - } catch (e: CancellationException) { - runReleaseAndRethrow(e) { release(acquired) } - } catch (t: Throwable) { - runReleaseAndRethrow(t.nonFatalOrThrow()) { release(acquired) } - } - - withContext(NonCancellable) { release(acquired) } - return res -} +): B = bracketCase(acquire, use) { acquired, _ -> release(acquired) } /** * A way to safely acquire a resource and release in the face of errors and cancellation. @@ -238,31 +206,35 @@ public suspend inline fun bracketCase( use: suspend (A) -> B, crossinline release: suspend (A, ExitCase) -> Unit ): B { - val acquired = withContext(NonCancellable) { - acquire() - } - - val res = try { - use(acquired) - } catch (e: CancellationException) { - runReleaseAndRethrow(e) { release(acquired, ExitCase.Cancelled(e)) } - } catch (t: Throwable) { - runReleaseAndRethrow(t.nonFatalOrThrow()) { release(acquired, ExitCase.Failure(t.nonFatalOrThrow())) } - } - - withContext(NonCancellable) { release(acquired, ExitCase.Completed) } - - return res + val acquired = withContext(NonCancellable) { acquire() } + return guaranteeCase({ use(acquired) }) { release(acquired, it) } } @PublishedApi -internal suspend inline fun runReleaseAndRethrow(original: Throwable, crossinline f: suspend () -> Unit): Nothing { +internal suspend inline fun runReleaseAndRethrow(original: Throwable?, crossinline f: suspend () -> Unit) { try { withContext(NonCancellable) { f() } } catch (e: Throwable) { - original.addSuppressed(e.nonFatalOrThrow()) + original?.addSuppressed(e.nonFatalOrThrow()) ?: throw e + } + original?.let { throw it } +} + +@PublishedApi +internal inline fun finalizeCase(block: () -> R, finalizer: (ExitCase) -> Unit): R { + var finished = false + return try { + block() + } catch (e: Throwable) { + finished = true + if (e !is CancellationException) e.nonFatalOrThrow() + finalizer(ExitCase.ExitCase(e)) + throw e + } finally { + if (!finished) { + finalizer(ExitCase.Completed) + } } - throw original } diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/Resource.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/Resource.kt index 19d2f092f02..e9659c9b50c 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/Resource.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/Resource.kt @@ -351,21 +351,11 @@ public fun resource(block: suspend ResourceScope.() -> A): Resource = blo * */ @ScopeDSL +@OptIn(DelicateCoroutinesApi::class) @Suppress("REDUNDANT_INLINE_SUSPEND_FUNCTION_TYPE") public suspend inline fun resourceScope(action: suspend ResourceScope.() -> A): A { - val scope = ResourceScopeImpl() - var finished = false - return try { - action(scope) - } catch (e: Throwable) { - finished = true - scope.cancelAll(ExitCase.ExitCase(e)) - throw e - } finally { - if (!finished) { - scope.cancelAll(ExitCase.Completed) - } - } + val (scope, cancelAll) = resource { this }.allocated() + return finalizeCase({ scope.action() }) { cancelAll(it) } } @Suppress("REDUNDANT_INLINE_SUSPEND_FUNCTION_TYPE") @@ -471,7 +461,6 @@ public suspend fun Resource.allocated(): Pair Un bind() to this::cancelAll } -@PublishedApi internal class ResourceScopeImpl : ResourceScope { private val finalizers: Atomic Unit>> = Atomic(emptyList()) override fun onRelease(release: suspend (ExitCase) -> Unit) { diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/BracketCaseTest.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/BracketCaseTest.kt index b5a881959a1..217f0f4dd16 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/BracketCaseTest.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/BracketCaseTest.kt @@ -11,6 +11,7 @@ import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.async import kotlinx.coroutines.awaitCancellation import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.time.ExperimentalTime @@ -54,7 +55,7 @@ class BracketCaseTest { bracketCase( acquire = { throw e }, use = { 5 }, - release = { _, _ -> Unit } + release = { _, _ -> } ) } should leftException(e) } @@ -67,7 +68,7 @@ class BracketCaseTest { bracketCase( acquire = { e.suspend() }, use = { 5 }, - release = { _, _ -> Unit } + release = { _, _ -> } ) } should leftException(e) } @@ -122,6 +123,44 @@ class BracketCaseTest { } } + @Test + fun bracketCaseMustRunReleaseTaskOnUseEarlyReturn() = runTest { + checkAll(10, Arb.int()) { i -> + val promise = CompletableDeferred() + + run { + bracketCase( + acquire = { i }, + use = { return@run it }, + release = { _, ex -> + require(promise.complete(ex)) { "Release should only be called once, called again with $ex" } + } + ) + } shouldBe i + + promise.await() shouldBe ExitCase.Completed + } + } + + @Test + fun bracketCaseMustRunReleaseTaskOnUseSuspendedEarlyReturn() = runTest { + checkAll(10, Arb.int()) { i -> + val promise = CompletableDeferred() + + run { + bracketCase( + acquire = { i }, + use = { return@run it.suspend() }, + release = { _, ex -> + require(promise.complete(ex)) { "Release should only be called once, called again with $ex" } + } + ) + } shouldBe i + + promise.await() shouldBe ExitCase.Completed + } + } + @Test fun bracketCaseMustRunReleaseTaskOnUseSuspendedError() = runTest { checkAll(10, Arb.int(), Arb.throwable()) { x, e -> @@ -218,7 +257,7 @@ class BracketCaseTest { use = { throw e }, release = { _, _ -> throw e2 } ) - } shouldBe Either.Left(e + e2) + } shouldBe Either.Left(e) } } @@ -231,7 +270,7 @@ class BracketCaseTest { use = { e.suspend() }, release = { _, _ -> throw e2 } ) - } shouldBe Either.Left(e + e2) + } shouldBe Either.Left(e) } } @@ -244,7 +283,7 @@ class BracketCaseTest { use = { throw e }, release = { _, _ -> e2.suspend() } ) - } shouldBe Either.Left(e + e2) + } shouldBe Either.Left(e) } } @@ -257,7 +296,7 @@ class BracketCaseTest { use = { e.suspend() }, release = { _, _ -> e2.suspend() } ) - } shouldBe Either.Left(e + e2) + } shouldBe Either.Left(e) } } @@ -318,7 +357,7 @@ class BracketCaseTest { val f = async { bracketCase( - acquire = { Unit }, + acquire = { }, use = { Unit.suspend() }, release = { _, exitCase -> require(exit.complete(exitCase)) { "Release should only be called once, called again with $exitCase" } @@ -357,7 +396,7 @@ class BracketCaseTest { // Wait until acquire started latch.await() - async { fiber.cancel() } + launch { fiber.cancel() } mVar.receive() shouldBe x mVar.receive() shouldBe y @@ -380,7 +419,7 @@ class BracketCaseTest { } latch.await() - async { fiber.cancel() } + launch { fiber.cancel() } mVar.receive() shouldBe x // If release was cancelled this hangs since the buffer is empty