From 5addfb58e6535be85c43922c946701446c929fdb Mon Sep 17 00:00:00 2001 From: Flavio Brasil Date: Sat, 12 Oct 2024 18:20:53 -0700 Subject: [PATCH] core: Queue/Channel/Meter improvements and fixes --- .../src/main/scala/kyo/bench/Bench.scala | 2 +- .../scala/kyo/bench/CollectParBench.scala | 4 +- .../scala/kyo/bench/EnqueueDequeueBench.scala | 2 +- .../scala/kyo/bench/HttpClientBench.scala | 2 +- .../kyo/bench/HttpClientContentionBench.scala | 4 +- .../bench/HttpClientRaceContentionBench.scala | 4 +- .../src/main/scala/kyo/bench/MtlBench.scala | 12 +- .../main/scala/kyo/bench/PingPongBench.scala | 6 +- .../kyo/bench/ProducerConsumerBench.scala | 2 +- .../main/scala/kyo/bench/SemaphoreBench.scala | 2 +- .../kyo/bench/SemaphoreContentionBench.scala | 2 +- .../js/src/main/scala/kyo/queuesStubs.scala | 18 +- .../shared/src/main/scala/kyo/Channel.scala | 584 ++++++++--------- .../shared/src/main/scala/kyo/Closed.scala | 4 +- kyo-core/shared/src/main/scala/kyo/Hub.scala | 34 +- .../shared/src/main/scala/kyo/Meter.scala | 251 +++++--- .../shared/src/main/scala/kyo/Queue.scala | 604 ++++++++++-------- .../shared/src/main/scala/kyo/Resource.scala | 8 +- .../src/test/scala/kyo/ChannelTest.scala | 180 +++++- .../shared/src/test/scala/kyo/HubTest.scala | 30 +- .../shared/src/test/scala/kyo/MeterTest.scala | 122 +++- .../shared/src/test/scala/kyo/QueueTest.scala | 228 +++++-- .../shared/src/main/scala/kyo/Result.scala | 28 + .../src/test/scala/kyo/ResultTest.scala | 62 ++ .../main/scala/examples/ledger/db/Log.scala | 2 +- .../src/main/scala/kyo/PlatformBackend.scala | 2 +- .../src/main/scala/kyo/PlatformBackend.scala | 4 +- .../scala/sttp/client3/KyoSequencer.scala | 3 +- .../scala/sttp/client3/KyoSimpleQueue.scala | 5 +- .../shared/src/main/scala/kyo/Requests.scala | 6 +- .../src/test/scala/kyo/RequestsTest.scala | 7 +- 31 files changed, 1418 insertions(+), 806 deletions(-) diff --git a/kyo-bench/src/main/scala/kyo/bench/Bench.scala b/kyo-bench/src/main/scala/kyo/bench/Bench.scala index f0c9b9592..5251bb06a 100644 --- a/kyo-bench/src/main/scala/kyo/bench/Bench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/Bench.scala @@ -48,7 +48,7 @@ object Bench: abstract class Base[A](expectedResult: A) extends Bench[A](expectedResult): def zioBench(): zio.UIO[A] - def kyoBenchFiber(): kyo.<[A, kyo.Async] = kyoBench() + def kyoBenchFiber(): kyo.<[A, kyo.Async & kyo.Abort[Throwable]] = kyoBench() def kyoBench(): kyo.<[A, kyo.IO] def catsBench(): cats.effect.IO[A] end Base diff --git a/kyo-bench/src/main/scala/kyo/bench/CollectParBench.scala b/kyo-bench/src/main/scala/kyo/bench/CollectParBench.scala index bab9f2713..93018935f 100644 --- a/kyo-bench/src/main/scala/kyo/bench/CollectParBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/CollectParBench.scala @@ -10,7 +10,9 @@ class CollectParBench extends Bench.ForkOnly(Seq.fill(1000)(1)): override def kyoBenchFiber() = import kyo.* - Async.parallel(kyoTasks) + // TODO inference issue + val x = Async.parallel(kyoTasks) + x end kyoBenchFiber def catsBench() = diff --git a/kyo-bench/src/main/scala/kyo/bench/EnqueueDequeueBench.scala b/kyo-bench/src/main/scala/kyo/bench/EnqueueDequeueBench.scala index 511194284..ee9908600 100644 --- a/kyo-bench/src/main/scala/kyo/bench/EnqueueDequeueBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/EnqueueDequeueBench.scala @@ -22,7 +22,7 @@ class EnqueueDequeueBench extends Bench.ForkOnly(()): import kyo.Access - def loop(c: Channel[Unit], i: Int): Unit < Async = + def loop(c: Channel[Unit], i: Int): Unit < (Async & Abort[Closed]) = if 
i >= depth then IO.unit else diff --git a/kyo-bench/src/main/scala/kyo/bench/HttpClientBench.scala b/kyo-bench/src/main/scala/kyo/bench/HttpClientBench.scala index fc9d08067..2380ae06a 100644 --- a/kyo-bench/src/main/scala/kyo/bench/HttpClientBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/HttpClientBench.scala @@ -35,7 +35,7 @@ class HttpClientBench extends Bench.ForkOnly("pong"): override def kyoBenchFiber() = import kyo.* - Abort.run(Requests(_.get(kyoUrl))).map(_.getOrThrow) + Requests(_.get(kyoUrl)) end kyoBenchFiber val zioUrl = diff --git a/kyo-bench/src/main/scala/kyo/bench/HttpClientContentionBench.scala b/kyo-bench/src/main/scala/kyo/bench/HttpClientContentionBench.scala index b68e08644..39f56b763 100644 --- a/kyo-bench/src/main/scala/kyo/bench/HttpClientContentionBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/HttpClientContentionBench.scala @@ -37,7 +37,9 @@ class HttpClientContentionBench override def kyoBenchFiber() = import kyo.* - Async.parallel(Seq.fill(concurrency)(Abort.run(Requests(_.get(kyoUrl))).map(_.getOrThrow))) + // TODO inference issue + val x = Async.parallel(Seq.fill(concurrency)(Requests(_.get(kyoUrl)))) + x end kyoBenchFiber val zioUrl = diff --git a/kyo-bench/src/main/scala/kyo/bench/HttpClientRaceContentionBench.scala b/kyo-bench/src/main/scala/kyo/bench/HttpClientRaceContentionBench.scala index 1d6546f3d..9c23c97a6 100644 --- a/kyo-bench/src/main/scala/kyo/bench/HttpClientRaceContentionBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/HttpClientRaceContentionBench.scala @@ -58,7 +58,9 @@ class HttpClientRaceContentionBench override def kyoBenchFiber() = import kyo.* - Async.race(Seq.fill(concurrency)(Requests.let(kyoClient)(Abort.run(Requests(_.get(kyoUrl)))).map(_.getOrThrow))) + // TODO inference issue + val x = Async.race(Seq.fill(concurrency)(Requests.let(kyoClient)(Requests(_.get(kyoUrl))))) + x end kyoBenchFiber val zioUrl = diff --git a/kyo-bench/src/main/scala/kyo/bench/MtlBench.scala b/kyo-bench/src/main/scala/kyo/bench/MtlBench.scala index c71128922..510017e81 100644 --- a/kyo-bench/src/main/scala/kyo/bench/MtlBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/MtlBench.scala @@ -21,15 +21,13 @@ class MtlBench extends Bench(()): _ <- Var.update((state: State) => state.copy(value = state.value + 1)) yield () ) - Abort.run[Throwable]( - Var.run(State(2))( - Emit.run( - Env.run(EnvValue("config"))( - testKyo.andThen(Var.get[State]) - ) + Var.run(State(2))( + Emit.run( + Env.run(EnvValue("config"))( + testKyo.andThen(Var.get[State]) ) ) - ).eval + ) end syncKyo @Benchmark diff --git a/kyo-bench/src/main/scala/kyo/bench/PingPongBench.scala b/kyo-bench/src/main/scala/kyo/bench/PingPongBench.scala index ea7b4ed71..19e465fa9 100644 --- a/kyo-bench/src/main/scala/kyo/bench/PingPongBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/PingPongBench.scala @@ -37,11 +37,11 @@ class PingPongBench extends Bench.ForkOnly(()): override def kyoBenchFiber() = import kyo.* - def repeat[A](n: Int)(io: A < Async): A < Async = + def repeat[A](n: Int)(io: A < (Async & Abort[Closed])): A < (Async & Abort[Closed]) = if n <= 1 then io else io.flatMap(_ => repeat(n - 1)(io)) - def iterate(promise: Promise[Nothing, Unit], n: Int): Unit < Async = + def iterate(promise: Promise[Nothing, Unit], n: Int): Unit < (Async & Abort[Closed]) = for ref <- AtomicInt.init(n) chan <- Channel.init[Unit](1) @@ -52,7 +52,7 @@ class PingPongBench extends Bench.ForkOnly(()): n <- ref.decrementAndGet _ <- if n == 0 then promise.complete(Result.unit).unit else IO.unit yield () - _ 
<- repeat(depth)(Async.run[Nothing, Unit, Any](effect)) + _ <- repeat(depth)(Async.run[Closed, Unit, Any](effect)) yield () for diff --git a/kyo-bench/src/main/scala/kyo/bench/ProducerConsumerBench.scala b/kyo-bench/src/main/scala/kyo/bench/ProducerConsumerBench.scala index e1c06d18d..38b9fdf8e 100644 --- a/kyo-bench/src/main/scala/kyo/bench/ProducerConsumerBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/ProducerConsumerBench.scala @@ -30,7 +30,7 @@ class ProducerConsumerBench extends Bench.ForkOnly(()): import kyo.Access - def repeat[A](n: Int)(io: A < Async): A < Async = + def repeat[A](n: Int)(io: A < (Async & Abort[Closed])): A < (Async & Abort[Closed]) = if n <= 1 then io else io.flatMap(_ => repeat(n - 1)(io)) diff --git a/kyo-bench/src/main/scala/kyo/bench/SemaphoreBench.scala b/kyo-bench/src/main/scala/kyo/bench/SemaphoreBench.scala index 1e31dc7b1..93c8ca1fb 100644 --- a/kyo-bench/src/main/scala/kyo/bench/SemaphoreBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/SemaphoreBench.scala @@ -22,7 +22,7 @@ class SemaphoreBench extends Bench.ForkOnly(()): override def kyoBenchFiber() = import kyo.* - def loop(s: Meter, i: Int): Unit < Async = + def loop(s: Meter, i: Int): Unit < (Async & Abort[Closed]) = if i >= depth then IO.unit else diff --git a/kyo-bench/src/main/scala/kyo/bench/SemaphoreContentionBench.scala b/kyo-bench/src/main/scala/kyo/bench/SemaphoreContentionBench.scala index e104e6e81..22d7abdaf 100644 --- a/kyo-bench/src/main/scala/kyo/bench/SemaphoreContentionBench.scala +++ b/kyo-bench/src/main/scala/kyo/bench/SemaphoreContentionBench.scala @@ -40,7 +40,7 @@ class SemaphoreContentionBench extends Bench.ForkOnly(()): if n <= 1 then io else io.flatMap(_ => repeat(n - 1)(io)) - def loop(sem: Meter, cdl: Latch, i: Int = 0): Unit < Async = + def loop(sem: Meter, cdl: Latch, i: Int = 0): Unit < (Async & Abort[Closed]) = if i >= depth then cdl.release else diff --git a/kyo-core/js/src/main/scala/kyo/queuesStubs.scala b/kyo-core/js/src/main/scala/kyo/queuesStubs.scala index ba9e24ff9..624d51262 100644 --- a/kyo-core/js/src/main/scala/kyo/queuesStubs.scala +++ b/kyo-core/js/src/main/scala/kyo/queuesStubs.scala @@ -1,9 +1,21 @@ package org.jctools.queues import java.util.ArrayDeque +import scala.annotation.tailrec class StubQueue[A](capacity: Int) extends ArrayDeque[A]: def isFull = size() >= capacity + def drain(f: A => Unit): Unit = + given [B]: CanEqual[B, B] = CanEqual.derived + @tailrec def loop(): Unit = + super.poll() match + case null => + case value => + f(value) + loop() + end loop + loop() + end drain override def offer(e: A): Boolean = !isFull && super.offer(e) end StubQueue @@ -16,8 +28,8 @@ case class SpmcArrayQueue[A](capacity: Int) extends StubQueue[A](capacity) case class SpscArrayQueue[A](capacity: Int) extends StubQueue[A](capacity) -case class MpmcUnboundedXaddArrayQueue[A](chunkSize: Int) extends ArrayDeque[A] {} +case class MpmcUnboundedXaddArrayQueue[A](chunkSize: Int) extends StubQueue[A](Int.MaxValue) -case class MpscUnboundedArrayQueue[A](chunkSize: Int) extends ArrayDeque[A] {} +case class MpscUnboundedArrayQueue[A](chunkSize: Int) extends StubQueue[A](Int.MaxValue) -case class SpscUnboundedArrayQueue[A](chunkSize: Int) extends ArrayDeque[A] {} +case class SpscUnboundedArrayQueue[A](chunkSize: Int) extends StubQueue[A](Int.MaxValue) diff --git a/kyo-core/shared/src/main/scala/kyo/Channel.scala b/kyo-core/shared/src/main/scala/kyo/Channel.scala index 64a9ffa97..17641e7cb 100644 --- a/kyo-core/shared/src/main/scala/kyo/Channel.scala +++ 
b/kyo-core/shared/src/main/scala/kyo/Channel.scala @@ -1,5 +1,6 @@ package kyo +import org.jctools.queues.MessagePassingQueue.Consumer import org.jctools.queues.MpmcUnboundedXaddArrayQueue import scala.annotation.tailrec @@ -8,334 +9,295 @@ import scala.annotation.tailrec * @tparam A * The type of elements in the channel */ -abstract class Channel[A]: - self => +opaque type Channel[A] = Channel.Unsafe[A] - /** Returns the current size of the channel. - * - * @return - * The number of elements currently in the channel - */ - def size(using Frame): Int < IO - - /** Attempts to offer an element to the channel without blocking. - * - * @param v - * The element to offer - * @return - * true if the element was added to the channel, false otherwise - */ - def offer(v: A)(using Frame): Boolean < IO - - /** Offers an element to the channel without returning a result. - * - * @param v - * The element to offer - */ - def offerDiscard(v: A)(using Frame): Unit < IO - - /** Attempts to poll an element from the channel without blocking. - * - * @return - * Maybe containing the polled element, or empty if the channel is empty - */ - def poll(using Frame): Maybe[A] < IO - - private[kyo] def unsafePoll(using AllowUnsafe): Maybe[A] - - /** Checks if the channel is empty. - * - * @return - * true if the channel is empty, false otherwise - */ - def empty(using Frame): Boolean < IO - - /** Checks if the channel is full. - * - * @return - * true if the channel is full, false otherwise - */ - def full(using Frame): Boolean < IO - - /** Creates a fiber that puts an element into the channel. - * - * @param v - * The element to put - * @return - * A fiber that completes when the element is put into the channel - */ - def putFiber(v: A)(using Frame): Fiber[Nothing, Unit] < IO - - /** Creates a fiber that takes an element from the channel. - * - * @return - * A fiber that completes with the taken element - */ - def takeFiber(using Frame): Fiber[Nothing, A] < IO - - /** Puts an element into the channel, asynchronously blocking if necessary. - * - * @param v - * The element to put - */ - def put(v: A)(using Frame): Unit < Async - - /** Takes an element from the channel, asynchronously blocking if necessary. - * - * @return - * The taken element - */ - def take(using Frame): A < Async - - /** Checks if the channel is closed. - * - * @return - * true if the channel is closed, false otherwise - */ - def closed(using Frame): Boolean < IO +object Channel: - /** Drains all elements from the channel. - * - * @return - * A sequence containing all elements that were in the channel - */ - def drain(using Frame): Seq[A] < IO + extension [A](self: Channel[A]) + + /** Returns the capacity of the channel. + * + * @return + * The capacity of the channel + */ + def capacity: Int = self.capacity + + /** Returns the current size of the channel. + * + * @return + * The number of elements currently in the channel + */ + def size(using Frame): Int < (Abort[Closed] & IO) = IO.Unsafe(Abort.get(self.size())) + + /** Attempts to offer an element to the channel without blocking. + * + * @param value + * The element to offer + * @return + * true if the element was added to the channel, false otherwise + */ + def offer(value: A)(using Frame): Boolean < (Abort[Closed] & IO) = IO.Unsafe(Abort.get(self.offer(value))) + + /** Offers an element to the channel without returning a result. 
+ * + * @param v + * The element to offer + */ + def offerDiscard(value: A)(using Frame): Unit < (Abort[Closed] & IO) = IO.Unsafe(Abort.get(self.offer(value).unit)) + + /** Attempts to poll an element from the channel without blocking. + * + * @return + * Maybe containing the polled element, or empty if the channel is empty + */ + def poll(using Frame): Maybe[A] < (Abort[Closed] & IO) = IO.Unsafe(Abort.get(self.poll())) + + /** Puts an element into the channel, asynchronously blocking if necessary. + * + * @param value + * The element to put + */ + def put(value: A)(using Frame): Unit < (Abort[Closed] & Async) = + IO.Unsafe { + self.offer(value).fold(Abort.error) { + case true => () + case false => self.putFiber(value).safe.get + } + } - /** Closes the channel. - * - * @return - * Maybe containing a sequence of remaining elements, or empty if the channel was already closed - */ - def close(using Frame): Maybe[Seq[A]] < IO -end Channel + /** Takes an element from the channel, asynchronously blocking if necessary. + * + * @return + * The taken element + */ + def take(using Frame): A < (Abort[Closed] & Async) = + IO.Unsafe { + self.poll().fold(Abort.error) { + case Present(value) => value + case Absent => self.takeFiber().safe.get + } + } -object Channel: + /** Creates a fiber that puts an element into the channel. + * + * @param value + * The element to put + * @return + * A fiber that completes when the element is put into the channel + */ + def putFiber(value: A)(using Frame): Fiber[Closed, Unit] < IO = IO.Unsafe(self.putFiber(value).safe) + + /** Creates a fiber that takes an element from the channel. + * + * @return + * A fiber that completes with the taken element + */ + def takeFiber(using Frame): Fiber[Closed, A] < IO = IO.Unsafe(self.takeFiber().safe) + + /** Drains all elements from the channel. + * + * @return + * A sequence containing all elements that were in the channel + */ + def drain(using Frame): Seq[A] < (Abort[Closed] & IO) = IO.Unsafe(Abort.get(self.drain())) + + /** Closes the channel. + * + * @return + * A sequence of remaining elements + */ + def close(using Frame): Maybe[Seq[A]] < IO = IO.Unsafe(self.close()) + + /** Checks if the channel is closed. + * + * @return + * true if the channel is closed, false otherwise + */ + def closed(using Frame): Boolean < IO = IO.Unsafe(self.closed()) + + /** Checks if the channel is empty. + * + * @return + * true if the channel is empty, false otherwise + */ + def empty(using Frame): Boolean < (Abort[Closed] & IO) = IO.Unsafe(Abort.get(self.empty())) + + /** Checks if the channel is full. + * + * @return + * true if the channel is full, false otherwise + */ + def full(using Frame): Boolean < (Abort[Closed] & IO) = IO.Unsafe(Abort.get(self.full())) + + def unsafe: Unsafe[A] = self + end extension /** Initializes a new Channel. * * @param capacity - * The capacity of the channel + * The capacity of the channel. Note that this will be rounded up to the next power of two. * @param access * The access mode for the channel (default is MPMC) - * @param initFrame - * The initial frame for the channel * @tparam A * The type of elements in the channel * @return * A new Channel instance + * + * @note + * The actual capacity will be rounded up to the next power of two. + * @warning + * The actual capacity may be larger than the specified capacity due to rounding. 
*/ - def init[A]( - capacity: Int, - access: Access = Access.MultiProducerMultiConsumer - )(using initFrame: Frame): Channel[A] < IO = - Queue.init[A](capacity, access).map { queue => - IO { - new Channel[A]: - - def u = queue.unsafe - val takes = new MpmcUnboundedXaddArrayQueue[Promise.Unsafe[Nothing, A]](8) - val puts = new MpmcUnboundedXaddArrayQueue[(A, Promise.Unsafe[Nothing, Unit])](8) - - def size(using Frame) = op(u.size()) - def empty(using Frame) = op(u.empty()) - def full(using Frame) = op(u.full()) - - def offer(v: A)(using Frame) = - IO.Unsafe { - !u.closed() && { - try u.offer(v) - finally flush() - } - } - - def offerDiscard(v: A)(using Frame) = - IO.Unsafe { - if !u.closed() then - try discard(u.offer(v)) - finally flush() - } - - def unsafePoll(using AllowUnsafe): Maybe[A] = - if u.closed() then - Maybe.empty - else - try u.poll() - finally flush() - - def poll(using Frame) = - IO.Unsafe(unsafePoll) - - def put(v: A)(using Frame) = - IO.Unsafe { - try - if u.closed() then - throw closedException - else if u.offer(v) then - () - else - val p = Promise.Unsafe.init[Nothing, Unit]() - puts.add((v, p)) - p.safe.get - end if - finally - flush() - } - - def putFiber(v: A)(using frame: Frame) = - IO.Unsafe { - try - if u.closed() then - throw closedException - else if u.offer(v) then - Fiber.unit - else - val p = Promise.Unsafe.init[Nothing, Unit]() - puts.add((v, p)) - p.safe - end if - finally - flush() - } - - def take(using Frame) = - IO.Unsafe { - try - if u.closed() then - throw closedException - else - u.poll() match - case Absent => - val p = Promise.Unsafe.init[Nothing, A]() - takes.add(p) - p.safe.get - case Present(v) => - v - finally - flush() - } - - def takeFiber(using frame: Frame) = - IO.Unsafe { - try - if u.closed() then - throw closedException - else - u.poll() match - case Absent => - val p = Promise.Unsafe.init[Nothing, A]() - takes.add(p) - p.safe - case Present(v) => - Fiber.success(v) - finally - flush() + def init[A](capacity: Int, access: Access = Access.MultiProducerMultiConsumer)(using Frame): Channel[A] < IO = + IO.Unsafe(Unsafe.init(capacity, access)) + + /** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */ + abstract class Unsafe[A]: + def capacity: Int + def size()(using AllowUnsafe): Result[Closed, Int] + + def offer(value: A)(using AllowUnsafe): Result[Closed, Boolean] + def poll()(using AllowUnsafe): Result[Closed, Maybe[A]] + + def putFiber(value: A)(using AllowUnsafe): Fiber.Unsafe[Closed, Unit] + def takeFiber()(using AllowUnsafe): Fiber.Unsafe[Closed, A] + + def drain()(using AllowUnsafe): Result[Closed, Seq[A]] + def close()(using Frame, AllowUnsafe): Maybe[Seq[A]] + + def empty()(using AllowUnsafe): Result[Closed, Boolean] + def full()(using AllowUnsafe): Result[Closed, Boolean] + def closed()(using AllowUnsafe): Boolean + + def safe: Channel[A] = this + end Unsafe + + /** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. 
*/
+ object Unsafe:
+ def init[A](
+ _capacity: Int,
+ access: Access = Access.MultiProducerMultiConsumer
+ )(using initFrame: Frame, allow: AllowUnsafe): Unsafe[A] =
+ new Unsafe[A]:
+ import AllowUnsafe.embrace.danger
+
+ val queue = Queue.Unsafe.init[A](_capacity, access)
+ val takes = new MpmcUnboundedXaddArrayQueue[Promise.Unsafe[Closed, A]](8)
+ val puts = new MpmcUnboundedXaddArrayQueue[(A, Promise.Unsafe[Closed, Unit])](8)
+
+ def capacity = _capacity
+
+ def size()(using AllowUnsafe) = queue.size()
+
+ def offer(value: A)(using AllowUnsafe) =
+ val result = queue.offer(value)
+ if result.contains(true) then flush()
+ result
+ end offer
+
+ def poll()(using AllowUnsafe) =
+ val result = queue.poll()
+ if result.exists(_.nonEmpty) then flush()
+ result
+ end poll
+
+ def putFiber(value: A)(using AllowUnsafe): Fiber.Unsafe[Closed, Unit] =
+ val promise = Promise.Unsafe.init[Closed, Unit]()
+ val tuple = (value, promise)
+ puts.add(tuple)
+ flush()
+ promise
+ end putFiber
+
+ def takeFiber()(using AllowUnsafe): Fiber.Unsafe[Closed, A] =
+ val promise = Promise.Unsafe.init[Closed, A]()
+ takes.add(promise)
+ flush()
+ promise
+ end takeFiber
+
+ def drain()(using AllowUnsafe) =
+ val result = queue.drain()
+ if result.exists(_.nonEmpty) then flush()
+ result
+ end drain
+
+ def close()(using frame: Frame, allow: AllowUnsafe) =
+ queue.close().map { backlog =>
+ flush()
+ backlog
+ }
+
+ def empty()(using AllowUnsafe) = queue.empty()
+ def full()(using AllowUnsafe) = queue.full()
+ def closed()(using AllowUnsafe) = queue.closed()
+
+ @tailrec private def flush(): Unit =
+ import AllowUnsafe.embrace.danger
+
+ // This method ensures that all values are processed
+ // and handles interrupted fibers by discarding them.
+ val queueClosed = queue.closed()
+ val queueSize = queue.size().getOrElse(0)
+ val takesEmpty = takes.isEmpty()
+ val putsEmpty = puts.isEmpty()
+
+ if queueClosed && (!takesEmpty || !putsEmpty) then
+ // Queue is closed, drain all takes and puts
+ val fail = queue.size() // Obtain the failed Result
+ takes.drain(_.completeDiscard(fail))
+ puts.drain(_._2.completeDiscard(fail.unit))
+ flush()
+ else if queueSize > 0 && !takesEmpty then
+ // Attempt to transfer a value from the queue to
+ // a waiting take operation.
+ Maybe(takes.poll()).foreach { promise =>
+ queue.poll() match
+ case Result.Success(Present(value)) =>
+ if !promise.complete(Result.success(value)) && !queue.offer(value).contains(true) then
+ // If completing the take fails and the queue
+ // cannot accept the value back, enqueue a
+ // placeholder put operation
+ val placeholder = Promise.Unsafe.init[Nothing, Unit]()
+ discard(puts.add((value, placeholder)))
+ case _ =>
+ // Queue became empty, enqueue the take again
+ discard(takes.add(promise)) }
-
- def closedException(using frame: Frame): Closed = Closed("Channel", initFrame, frame)
-
- inline def op[A](inline v: AllowUnsafe ?=> A)(using inline frame: Frame): A < IO =
- IO.Unsafe {
- if u.closed() then
- throw closedException
+ flush()
+ else if queueSize < capacity && !putsEmpty then
+ // Attempt to transfer a value from a waiting
+ // put operation to the queue.
+ Maybe(puts.poll()).foreach { tuple =>
+ val (value, promise) = tuple
+ if queue.offer(value).contains(true) then
+ // Queue accepted the value, complete the put
+ discard(promise.complete(Result.unit))
else
- v
+ // Queue became full, enqueue the put again
+ discard(puts.add(tuple))
+ end if }
-
- def closed(using Frame) = queue.closed
-
- def drain(using Frame) = queue.drain
-
- def close(using frame: Frame) =
- IO.Unsafe {
- u.close() match
- case Absent => Maybe.empty
- case r =>
- val c = Result.panic(closedException)
- def dropTakes(): Unit =
- val p = takes.poll()
- if !isNull(p) then
- p.completeDiscard(c)
- dropTakes()
- end dropTakes
- def dropPuts(): Unit =
- puts.poll() match
- case null => ()
- case (_, p) =>
- p.completeDiscard(c)
- dropPuts()
- dropTakes()
- dropPuts()
- r
+ flush()
+ else if queueSize == 0 && !putsEmpty && !takesEmpty then
+ // Directly transfer a value from a producer to a
+ // consumer when the queue is empty.
+ Maybe(puts.poll()).foreach { putTuple =>
+ val (value, putPromise) = putTuple
+ Maybe(takes.poll()) match
+ case Present(takePromise) if takePromise.complete(Result.success(value)) =>
+ // Value transferred to the pending take, complete put
+ putPromise.completeDiscard(Result.unit)
+ case _ =>
+ // No pending take available or the pending take is already
+ // completed due to interruption. Enqueue the put again.
+ discard(puts.add(putTuple))
+ end match }
-
- @tailrec private def flush(): Unit =
- import AllowUnsafe.embrace.danger
-
- // This method ensures that all values are processed
- // and handles interrupted fibers by discarding them.
- val queueSize = u.size()
- val takesEmpty = takes.isEmpty()
- val putsEmpty = puts.isEmpty()
-
- if queueSize > 0 && !takesEmpty then
- // Attempt to transfer a value from the queue to
- // a waiting consumer (take).
- val p = takes.poll()
- if !isNull(p) then
- u.poll() match
- case Absent =>
- // If the queue has been emptied before the
- // transfer, requeue the consumer's promise.
- discard(takes.add(p))
- case Present(v) =>
- if !p.complete(Result.success(v)) && !u.offer(v) then
- // If completing the take fails and the queue
- // cannot accept the value back, enqueue a
- // placeholder put operation to preserve the value.
- val placeholder = Promise.Unsafe.init[Nothing, Unit]()
- discard(puts.add((v, placeholder)))
- end if
- flush()
- else if queueSize < capacity && !putsEmpty then
- // Attempt to transfer a value from a waiting
- // producer (put) to the queue.
- val t = puts.poll()
- if t != null then
- val (v, p) = t
- if u.offer(v) then
- // Complete the put's promise if the value is
- // successfully enqueued. If the fiber became
- // interrupted, the completion will be ignored.
- discard(p.complete(Result.success(())))
- else
- // If the queue becomes full before the transfer,
- // requeue the producer's operation.
- discard(puts.add(t))
- end if
- end if
- flush()
- else if queueSize == 0 && !putsEmpty && !takesEmpty then
- // Directly transfer a value from a producer to a
- // consumer when the queue is empty.
- val t = puts.poll()
- if t != null then
- val (v, p) = t
- val p2 = takes.poll()
- if !isNull(p2) && p2.complete(Result.success(v)) then
- // If the transfer is successful, complete
- // the put's promise. If the consumer's fiber
- // became interrupted, the completion will be
- // ignored.
- discard(p.complete(Result.success(())))
- else
- // If the transfer to the consumer fails, requeue
- // the producer's operation.
- discard(puts.add(t)) - end if - end if - flush() - end if - end flush - } - } + flush() + end if + end flush + end new + end init + end Unsafe end Channel diff --git a/kyo-core/shared/src/main/scala/kyo/Closed.scala b/kyo-core/shared/src/main/scala/kyo/Closed.scala index 090cc7974..2c2d3e55f 100644 --- a/kyo-core/shared/src/main/scala/kyo/Closed.scala +++ b/kyo-core/shared/src/main/scala/kyo/Closed.scala @@ -2,6 +2,6 @@ package kyo import scala.util.control.NoStackTrace -case class Closed(resource: String, createdAt: Frame, failedAt: Frame) - extends Exception(s"$resource created at ${createdAt.parse.position} is closed. Failure at ${failedAt.parse.position}") +case class Closed(message: String, createdAt: Frame, failedAt: Frame) + extends Exception(s"Resource created at ${createdAt.parse.position} is closed. Failure at ${failedAt.parse.position}: $message") with NoStackTrace diff --git a/kyo-core/shared/src/main/scala/kyo/Hub.scala b/kyo-core/shared/src/main/scala/kyo/Hub.scala index e832c315a..80fb0a803 100644 --- a/kyo-core/shared/src/main/scala/kyo/Hub.scala +++ b/kyo-core/shared/src/main/scala/kyo/Hub.scala @@ -10,7 +10,7 @@ import java.util.concurrent.CopyOnWriteArraySet */ class Hub[A] private[kyo] ( ch: Channel[A], - fiber: Fiber[Nothing, Unit], + fiber: Fiber[Closed, Unit], listeners: CopyOnWriteArraySet[Channel[A]] )(using initFrame: Frame): @@ -19,7 +19,7 @@ class Hub[A] private[kyo] ( * @return * the number of elements currently in the Hub */ - def size(using Frame): Int < IO = ch.size + def size(using Frame): Int < (IO & Abort[Closed]) = ch.size /** Attempts to offer an element to the Hub without blocking. * @@ -28,28 +28,28 @@ class Hub[A] private[kyo] ( * @return * true if the element was added, false otherwise */ - def offer(v: A)(using Frame): Boolean < IO = ch.offer(v) + def offer(v: A)(using Frame): Boolean < (IO & Abort[Closed]) = ch.offer(v) /** Offers an element to the Hub without returning a result. * * @param v * the element to offer */ - def offerDiscard(v: A)(using Frame): Unit < IO = ch.offerDiscard(v) + def offerDiscard(v: A)(using Frame): Unit < (IO & Abort[Closed]) = ch.offerDiscard(v) /** Checks if the Hub is empty. * * @return * true if the Hub is empty, false otherwise */ - def empty(using Frame): Boolean < IO = ch.empty + def empty(using Frame): Boolean < (IO & Abort[Closed]) = ch.empty /** Checks if the Hub is full. * * @return * true if the Hub is full, false otherwise */ - def full(using Frame): Boolean < IO = ch.full + def full(using Frame): Boolean < (IO & Abort[Closed]) = ch.full /** Creates a fiber that puts an element into the Hub. * @@ -58,14 +58,14 @@ class Hub[A] private[kyo] ( * @return * a Fiber that, when run, will put the element into the Hub */ - def putFiber(v: A)(using Frame): Fiber[Nothing, Unit] < IO = ch.putFiber(v) + def putFiber(v: A)(using Frame): Fiber[Closed, Unit] < IO = ch.putFiber(v) /** Puts an element into the Hub, potentially blocking if the Hub is full. * * @param v * the element to put */ - def put(v: A)(using Frame): Unit < Async = ch.put(v) + def put(v: A)(using Frame): Unit < (Async & Abort[Closed]) = ch.put(v) /** Checks if the Hub is closed. * @@ -100,7 +100,7 @@ class Hub[A] private[kyo] ( * @return * a new Listener */ - def listen(using Frame): Listener[A] < IO = + def listen(using Frame): Listener[A] < (IO & Abort[Closed]) = listen(0) /** Creates a new listener for this Hub with specified buffer size. 
@@ -110,8 +110,8 @@ class Hub[A] private[kyo] ( * @return * a new Listener */ - def listen(bufferSize: Int)(using frame: Frame): Listener[A] < IO = - def fail = IO(throw Closed("Hub", initFrame, frame)) + def listen(bufferSize: Int)(using frame: Frame): Listener[A] < (IO & Abort[Closed]) = + def fail = Abort.fail(Closed("Hub", initFrame, frame)) closed.map { case true => fail case false => @@ -185,42 +185,42 @@ object Hub: * @return * the number of elements currently in the Listener's buffer */ - def size(using Frame): Int < IO = child.size + def size(using Frame): Int < (IO & Abort[Closed]) = child.size /** Checks if the Listener's buffer is empty. * * @return * true if the Listener's buffer is empty, false otherwise */ - def empty(using Frame): Boolean < IO = child.empty + def empty(using Frame): Boolean < (IO & Abort[Closed]) = child.empty /** Checks if the Listener's buffer is full. * * @return * true if the Listener's buffer is full, false otherwise */ - def full(using Frame): Boolean < IO = child.full + def full(using Frame): Boolean < (IO & Abort[Closed]) = child.full /** Attempts to retrieve and remove the head of the Listener's buffer without blocking. * * @return * a Maybe containing the head element if available, or empty if the buffer is empty */ - def poll(using Frame): Maybe[A] < IO = child.poll + def poll(using Frame): Maybe[A] < (IO & Abort[Closed]) = child.poll /** Creates a fiber that takes an element from the Listener's buffer. * * @return * a Fiber that, when run, will take an element from the Listener's buffer */ - def takeFiber(using Frame): Fiber[Nothing, A] < IO = child.takeFiber + def takeFiber(using Frame): Fiber[Closed, A] < IO = child.takeFiber /** Takes an element from the Listener's buffer, potentially blocking if the buffer is empty. * * @return * the next element from the Listener's buffer */ - def take(using Frame): A < Async = child.take + def take(using Frame): A < (Async & Abort[Closed]) = child.take /** Checks if the Listener is closed. * diff --git a/kyo-core/shared/src/main/scala/kyo/Meter.scala b/kyo-core/shared/src/main/scala/kyo/Meter.scala index 1b31a396a..1d56ebc0a 100644 --- a/kyo-core/shared/src/main/scala/kyo/Meter.scala +++ b/kyo-core/shared/src/main/scala/kyo/Meter.scala @@ -1,25 +1,13 @@ package kyo +import org.jctools.queues.MpmcUnboundedXaddArrayQueue +import scala.annotation.tailrec + /** A Meter is an abstract class that represents a mechanism for controlling concurrency and rate limiting. */ abstract class Meter: self => - /** Returns the number of available permits. - * - * @return - * The number of available permits as an Int effect. - */ - def available(using Frame): Int < IO - - /** Checks if there are any available permits. - * - * @return - * A Boolean effect indicating whether permits are available. - */ - def isAvailable(using Frame): Boolean < IO = - available.map(_ > 0) - /** Runs an effect after acquiring a permit. * * @param v @@ -31,7 +19,7 @@ abstract class Meter: * @return * The result of running the effect. */ - def run[A, S](v: => A < S)(using Frame): A < (S & Async) + def run[A, S](v: => A < S)(using Frame): A < (S & Async & Abort[Closed]) /** Attempts to run an effect if a permit is available. * @@ -44,7 +32,21 @@ abstract class Meter: * @return * A Maybe containing the result of running the effect, or Absent if no permit was available. 
*/ - def tryRun[A, S](v: => A < S)(using Frame): Maybe[A] < (IO & S) + def tryRun[A, S](v: => A < S)(using Frame): Maybe[A] < (S & Async & Abort[Closed]) + + /** Returns the number of available permits. + * + * @return + * The number of available permits. + */ + def availablePermits(using Frame): Int < (Async & Abort[Closed]) + + /** Returns the number of fibers waiting for a permit. + * + * @return + * The number of fibers waiting for a permit. + */ + def pendingWaiters(using Frame): Int < (Async & Abort[Closed]) /** Closes the Meter. * @@ -52,16 +54,26 @@ abstract class Meter: * A Boolean effect indicating whether the Meter was successfully closed. */ def close(using Frame): Boolean < IO + + /** Checks if the Meter is closed. + * + * @return + * A Boolean effect indicating whether the Meter is closed. + */ + def closed(using Frame): Boolean < IO + end Meter object Meter: - /** A no-op Meter that always allows operations. */ + /** A no-op Meter that always allows operations and can't be closed. */ case object Noop extends Meter: - def available(using Frame) = Int.MaxValue + def availablePermits(using Frame) = Int.MaxValue + def pendingWaiters(using Frame) = 0 def run[A, S](v: => A < S)(using Frame) = v def tryRun[A, S](v: => A < S)(using Frame) = v.map(Maybe(_)) def close(using Frame) = false + def closed(using Frame): Boolean < IO = false end Noop /** Creates a Meter that acts as a mutex (binary semaphore). @@ -80,30 +92,11 @@ object Meter: * A Meter effect that represents a semaphore. */ def initSemaphore(concurrency: Int)(using Frame): Meter < IO = - Channel.init[Unit](concurrency).map { chan => - offer(concurrency, chan, ()).map { _ => - new Meter: - def available(using Frame) = chan.size - def release(using Frame) = chan.offerDiscard(()) - - def run[A, S](v: => A < S)(using Frame) = - IO.ensure(release) { - chan.take.andThen(v) - } - - def tryRun[A, S](v: => A < S)(using Frame) = - IO.Unsafe { - chan.unsafePoll match - case Absent => Maybe.empty - case _ => - IO.ensure(release) { - v.map(Maybe(_)) - } - } - - def close(using Frame) = - chan.close.map(_.isDefined) - } + IO.Unsafe { + new Base(concurrency): + def dispatch[A, S](v: => A < S) = + // Release the permit right after the computation + IO.ensure(release())(v) } /** Creates a Meter that acts as a rate limiter. @@ -115,25 +108,23 @@ object Meter: * @return * A Meter effect that represents a rate limiter. */ - def initRateLimiter(rate: Int, period: Duration)(using Frame): Meter < IO = - Channel.init[Unit](rate).map { chan => - Timer.scheduleAtFixedRate(period)(offer(rate, chan, ())).map { _ => - new Meter: - - def available(using Frame) = chan.size - def run[A, S](v: => A < S)(using Frame) = chan.take.map(_ => v) - - def tryRun[A, S](v: => A < S)(using Frame) = - chan.poll.map { - case Absent => - Maybe.empty - case _ => - v.map(Maybe(_)) - } - - def close(using Frame) = - chan.close.map(_.isDefined) - } + def initRateLimiter(rate: Int, period: Duration)(using initFrame: Frame): Meter < IO = + IO.Unsafe { + new Base(rate): + val timerTask = + // Schedule periodic task to replenish permits + Timer.live.unsafe.scheduleAtFixedRate(period, period)(replenish()) + + def dispatch[A, S](v: => A < S) = + // Don't release a permit since it's managed by the timer task + v + + @tailrec def replenish(i: Int = 0): Unit = + if i < rate then + release() + replenish(i + 1) + + override def onClose() = discard(timerTask.cancel()) } /** Combines two Meters into a pipeline. 
@@ -198,22 +189,27 @@ object Meter: Kyo.collect(meters).map { seq => val meters = seq.toIndexedSeq new Meter: + def availablePermits(using Frame) = + Loop.indexed(0) { (idx, acc) => + if idx == meters.length then Loop.done(acc) + else meters(idx).availablePermits.map(v => Loop.continue(acc + v)) + } - def available(using Frame) = + def pendingWaiters(using Frame) = Loop.indexed(0) { (idx, acc) => if idx == meters.length then Loop.done(acc) - else meters(idx).available.map(v => Loop.continue(acc + v)) + else meters(idx).pendingWaiters.map(v => Loop.continue(acc + v)) } def run[A, S](v: => A < S)(using Frame) = - def loop(idx: Int = 0): A < (S & Async) = + def loop(idx: Int = 0): A < (S & Async & Abort[Closed]) = if idx == meters.length then v else meters(idx).run(loop(idx + 1)) loop() end run def tryRun[A, S](v: => A < S)(using Frame) = - def loop(idx: Int = 0): Maybe[A] < (S & IO) = + def loop(idx: Int = 0): Maybe[A] < (S & Async & Abort[Closed]) = if idx == meters.length then v.map(Maybe(_)) else meters(idx).tryRun(loop(idx + 1)).map { @@ -225,16 +221,123 @@ object Meter: def close(using Frame): Boolean < IO = Kyo.foreach(meters)(_.close).map(_.exists(identity)) + + def closed(using Frame): Boolean < IO = + Kyo.foreach(meters)(_.closed).map(_.exists(identity)) end new } - private def offer[A](n: Int, chan: Channel[A], v: A)(using Frame): Unit < IO = - Loop.indexed { idx => - if idx == n then Loop.done + abstract private class Base(permits: Int)(using initFrame: Frame, allow: AllowUnsafe) extends Meter: + + // MinValue => closed + // >= 0 => # of permits + // < 0 => # of waiters + val state = AtomicInt.Unsafe.init(permits) + val waiters = new MpmcUnboundedXaddArrayQueue[Promise.Unsafe[Closed, Unit]](8) + val closed = Promise.Unsafe.init[Closed, Nothing]() + + def dispatch[A, S](v: => A < S): A < (S & IO) + + final def run[A, S](v: => A < S)(using Frame) = + IO.Unsafe { + @tailrec def loop(): A < (S & Async & Abort[Closed]) = + val st = state.get() + if st == Int.MinValue then + // Meter is closed + closed.safe.get + else if state.cas(st, st - 1) then + if st > 0 then + // Permit available, dispatch immediately + dispatch(v) + else + // No permit available, add to waiters queue + val p = Promise.Unsafe.init[Closed, Unit]() + waiters.add(p) + p.safe.use(_ => dispatch(v)) + else + // CAS failed, retry + loop() + end if + end loop + loop() + } + end run + + final def tryRun[A, S](v: => A < S)(using Frame): Maybe[A] < (S & Async & Abort[Closed]) = + IO.Unsafe { + @tailrec def loop(): Maybe[A] < (S & Async & Abort[Closed]) = + val st = state.get() + if st == Int.MinValue then + // Meter is closed + closed.safe.get + else if st <= 0 then + // No permit available, return empty + Maybe.empty + else if state.cas(st, st - 1) then + // Permit available, dispatch + dispatch(v.map(Maybe(_))) + else + // CAS failed, retry + loop() + end if + end loop + loop() + } + end tryRun + + final def availablePermits(using Frame) = + IO.Unsafe { + state.get() match + case Int.MinValue => closed.safe.get + case st => Math.max(0, st) + } + + final def pendingWaiters(using Frame) = + IO.Unsafe { + state.get() match + case Int.MinValue => closed.safe.get + case st => Math.min(0, st).abs + } + + protected def onClose(): Unit = () + + final def close(using frame: Frame): Boolean < IO = + IO.Unsafe { + val st = state.getAndSet(Int.MinValue) + val ok = st != Int.MinValue // The meter wasn't already closed + if ok then + val fail = Result.fail(Closed("Semaphore is closed", initFrame, frame)) + // Complete the closed promise 
to fail new operations + closed.completeDiscard(fail) + // Drain the pending waiters + def drain(st: Int): Unit = + if st < 0 then + // Use pollWaiter to ensure all pending waiters + // as indicated by the state are drained + pollWaiter().completeDiscard(fail) + drain(st + 1) + drain(st) + onClose() + end if + ok + } + end close + + final def closed(using Frame) = IO(state.get() == Int.MinValue) + + final protected def release(): Unit = + // if the state was negative, indicating that there are waiters, + // release a waiter from the queue. + if state.incrementAndGet() <= 0 && !pollWaiter().complete(Result.unit) then + // the waiter is already completed due to interruption, + // try to release a waiter + release() + + @tailrec final private def pollWaiter(w: Promise.Unsafe[Closed, Unit] = waiters.poll()): Promise.Unsafe[Closed, Unit] = + if !isNull(w) then w else - chan.offer(v).map { - case true => Loop.continue - case false => Loop.done - } - } + // If no waiter is found, retry the poll operation + // This handles the race condition between state change and waiter queuing + pollWaiter() + end Base end Meter diff --git a/kyo-core/shared/src/main/scala/kyo/Queue.scala b/kyo-core/shared/src/main/scala/kyo/Queue.scala index ca211f882..7f2209243 100644 --- a/kyo-core/shared/src/main/scala/kyo/Queue.scala +++ b/kyo-core/shared/src/main/scala/kyo/Queue.scala @@ -1,6 +1,5 @@ package kyo -import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference import org.jctools.queues.* import scala.annotation.tailrec @@ -13,189 +12,332 @@ import scala.annotation.tailrec * @tparam A * the type of elements in the queue */ -class Queue[A] private[kyo] (initFrame: Frame, val unsafe: Queue.Unsafe[A]): +opaque type Queue[A] = Queue.Unsafe[A] - /** Returns the capacity of the queue. - * - * @return - * the capacity of the queue - */ - final def capacity(using Frame): Int = unsafe.capacity +object Queue: - /** Returns the current size of the queue. - * - * @return - * the current size of the queue - */ - final def size(using Frame): Int < IO = op(unsafe.size()) + extension [A](self: Queue[A]) + /** Returns the capacity of the queue. + * + * @return + * the capacity of the queue + */ + def capacity: Int = self.capacity - /** Checks if the queue is empty. - * - * @return - * true if the queue is empty, false otherwise - */ - final def empty(using Frame): Boolean < IO = op(unsafe.empty()) + /** Returns the current size of the queue. + * + * @return + * the current size of the queue + */ + def size(using Frame): Int < (IO & Abort[Closed]) = IO.Unsafe(Abort.get(self.size())) - /** Checks if the queue is full. - * - * @return - * true if the queue is full, false otherwise - */ - final def full(using Frame): Boolean < IO = op(unsafe.full()) + /** Checks if the queue is empty. + * + * @return + * true if the queue is empty, false otherwise + */ + def empty(using Frame): Boolean < (IO & Abort[Closed]) = IO.Unsafe(Abort.get(self.empty())) - /** Offers an element to the queue. - * - * @param v - * the element to offer - * @return - * true if the element was added, false if the queue is full or closed - */ - final def offer(v: A)(using Frame): Boolean < IO = IO.Unsafe(!unsafe.closed() && unsafe.offer(v)) + /** Checks if the queue is full. + * + * @return + * true if the queue is full, false otherwise + */ + def full(using Frame): Boolean < (IO & Abort[Closed]) = IO.Unsafe(Abort.get(self.full())) - /** Polls an element from the queue. 
- * - * @return - * Maybe containing the polled element, or empty if the queue is empty - */ - final def poll(using Frame): Maybe[A] < IO = op(unsafe.poll()) + /** Offers an element to the queue. + * + * @param v + * the element to offer + * @return + * true if the element was added, false if the queue is full or closed + */ + def offer(v: A)(using Frame): Boolean < (IO & Abort[Closed]) = IO.Unsafe(Abort.get(self.offer(v))) - /** Peeks at the first element in the queue without removing it. - * - * @return - * Maybe containing the first element, or empty if the queue is empty - */ - final def peek(using Frame): Maybe[A] < IO = op(unsafe.peek()) + /** Offers an element to the queue and discards the result + * + * @param v + * the element to offer + */ + def offerDiscard(v: A)(using Frame): Unit < (IO & Abort[Closed]) = IO.Unsafe(Abort.get(self.offer(v).unit)) - /** Drains all elements from the queue. - * - * @return - * a sequence of all elements in the queue - */ - final def drain(using Frame): Seq[A] < IO = op(unsafe.drain()) + /** Polls an element from the queue. + * + * @return + * Maybe containing the polled element, or empty if the queue is empty + */ + def poll(using Frame): Maybe[A] < (IO & Abort[Closed]) = IO.Unsafe(Abort.get(self.poll())) - /** Checks if the queue is closed. - * - * @return - * true if the queue is closed, false otherwise - */ - final def closed(using Frame): Boolean < IO = IO.Unsafe(unsafe.closed()) + /** Peeks at the first element in the queue without removing it. + * + * @return + * Maybe containing the first element, or empty if the queue is empty + */ + def peek(using Frame): Maybe[A] < (IO & Abort[Closed]) = IO.Unsafe(Abort.get(self.peek())) - /** Closes the queue and returns any remaining elements. - * - * @return - * Maybe containing a sequence of remaining elements, or empty if already closed - */ - final def close(using Frame): Maybe[Seq[A]] < IO = IO.Unsafe(unsafe.close()) - - protected inline def op[A, S](inline v: AllowUnsafe ?=> A < (IO & S))(using frame: Frame): A < (IO & S) = - IO.Unsafe { - if unsafe.closed() then - throw Closed("Queue", initFrame, frame) - else - v - } -end Queue + /** Drains all elements from the queue. + * + * @return + * a sequence of all elements in the queue + */ + def drain(using Frame): Seq[A] < (IO & Abort[Closed]) = IO.Unsafe(Abort.get(self.drain())) -/** Companion object for Queue, containing factory methods and nested classes. - * - * This object provides various initialization methods for different types of queues, all based on JCTools' concurrent queue - * implementations. - */ -object Queue: + /** Closes the queue and returns any remaining elements. + * + * @return + * a sequence of remaining elements + */ + def close(using Frame): Maybe[Seq[A]] < IO = IO.Unsafe(self.close()) - /** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. 
*/ - abstract class Unsafe[A] - extends AtomicBoolean(false): - def capacity: Int - def size()(using AllowUnsafe): Int - def empty()(using AllowUnsafe): Boolean - def full()(using AllowUnsafe): Boolean - def offer(v: A)(using AllowUnsafe): Boolean - def poll()(using AllowUnsafe): Maybe[A] - def peek()(using AllowUnsafe): Maybe[A] - final def drain()(using AllowUnsafe): Seq[A] = - val b = Seq.newBuilder[A] - @tailrec def loop(): Unit = - val v = poll() - v match - case Absent => - case Present(v) => - b += v - loop() - end match - end loop - loop() - b.result() - end drain - - final def closed()(using AllowUnsafe): Boolean = - super.get() - - final def close()(using AllowUnsafe): Maybe[Seq[A]] = - super.compareAndSet(false, true) match - case false => - Maybe.empty - case true => - Maybe(drain()) - - final def safe(using frame: Frame): Queue[A] = Queue(frame, this) + /** Checks if the queue is closed. + * + * @return + * true if the queue is closed, false otherwise + */ + def closed(using Frame): Boolean < IO = IO.Unsafe(self.closed()) - end Unsafe + /** Returns the unsafe version of the queue. + * + * @return + * the unsafe version of the queue + */ + def unsafe: Unsafe[A] = self + end extension + + /** Initializes a new queue with the specified capacity and access pattern. The actual capacity will be rounded up to the next power of + * two. + * + * @param capacity + * the desired capacity of the queue. Note that this will be rounded up to the next power of two. + * @param access + * the access pattern (default is MPMC) + * @return + * a new Queue instance with a capacity that is the next power of two greater than or equal to the specified capacity + * + * @note + * The actual capacity will be rounded up to the next power of two. + * @warning + * The actual capacity may be larger than the specified capacity due to rounding. + */ + def init[A](capacity: Int, access: Access = Access.MultiProducerMultiConsumer)(using Frame): Queue[A] < IO = + IO.Unsafe(Unsafe.init(capacity, access)) /** An unbounded queue that can grow indefinitely. * * @tparam A * the type of elements in the queue */ - class Unbounded[A] private[kyo] (initFrame: Frame, unsafe: Queue.Unsafe[A]) extends Queue[A](initFrame, unsafe): - /** Adds an element to the unbounded queue. + opaque type Unbounded[A] <: Queue[A] = Queue[A] + + object Unbounded: + extension [A](self: Unbounded[A]) + /** Adds an element to the unbounded queue. + * + * @param value + * the element to add + */ + def add(value: A)(using Frame): Unit < IO = IO.Unsafe(Unsafe.add(self)(value)) + + def unsafe: Unsafe[A] = self + end extension + + /** Initializes a new unbounded queue with the specified access pattern and chunk size. * - * @param v - * the element to add + * @param access + * the access pattern (default is MPMC) + * @param chunkSize + * the chunk size for internal array allocation (default is 8) + * @return + * a new Unbounded Queue instance + */ + def init[A](access: Access = Access.MultiProducerMultiConsumer, chunkSize: Int = 8)(using Frame): Unbounded[A] < IO = + IO.Unsafe(Unsafe.init(access, chunkSize)) + + /** Initializes a new dropping queue with the specified capacity and access pattern. + * + * @param capacity + * the capacity of the queue. Note that this will be rounded up to the next power of two. + * @param access + * the access pattern (default is MPMC) + * @return + * a new Unbounded Queue instance that drops elements when full + * + * @note + * The actual capacity will be rounded up to the next power of two. 
+ * @warning + * The actual capacity may be larger than the specified capacity due to rounding. + */ + def initDropping[A](capacity: Int, access: Access = Access.MultiProducerMultiConsumer)(using Frame): Unbounded[A] < IO = + IO.Unsafe(Unsafe.initDropping(capacity, access)) + + /** Initializes a new sliding queue with the specified capacity and access pattern. + * + * @param capacity + * the capacity of the queue. Note that this will be rounded up to the next power of two. + * @param access + * the access pattern (default is MPMC) + * @return + * a new Unbounded Queue instance that slides elements when full + * + * @note + * The actual capacity will be rounded up to the next power of two. + * @warning + * The actual capacity may be larger than the specified capacity due to rounding. */ - final def add[S](v: A < S)(using Frame): Unit < (IO & S) = - op(v.map(offer).unit) + def initSliding[A](capacity: Int, access: Access = Access.MultiProducerMultiConsumer)(using Frame): Unbounded[A] < IO = + IO.Unsafe(Unsafe.initSliding(capacity, access)) + + /** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */ + opaque type Unsafe[A] <: Queue.Unsafe[A] = Queue[A] + + /** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */ + object Unsafe: + extension [A](self: Unsafe[A]) + def add(value: A)(using AllowUnsafe, Frame): Unit = discard(self.offer(value)) + def safe: Unbounded[A] = self + + def init[A](access: Access = Access.MultiProducerMultiConsumer, chunkSize: Int = 8)( + using + Frame, + AllowUnsafe + ): Unsafe[A] = + access match + case Access.MultiProducerMultiConsumer => + Queue.Unsafe.fromJava(new MpmcUnboundedXaddArrayQueue[A](chunkSize)) + case Access.MultiProducerSingleConsumer => + Queue.Unsafe.fromJava(new MpscUnboundedArrayQueue[A](chunkSize)) + case Access.SingleProducerMultiConsumer => + Queue.Unsafe.fromJava(new MpmcUnboundedXaddArrayQueue[A](chunkSize)) + case Access.SingleProducerSingleConsumer => + Queue.Unsafe.fromJava(new SpscUnboundedArrayQueue[A](chunkSize)) + + def initDropping[A](_capacity: Int, access: Access = Access.MultiProducerMultiConsumer)( + using + frame: Frame, + allow: AllowUnsafe + ): Unsafe[A] = + new Unsafe[A]: + val underlying = Queue.Unsafe.init[A](_capacity, access) + def capacity = _capacity + def size()(using AllowUnsafe) = underlying.size() + def empty()(using AllowUnsafe) = underlying.empty() + def full()(using AllowUnsafe) = underlying.full().map(_ => false) + def offer(v: A)(using AllowUnsafe) = underlying.offer(v).map(_ => true) + def poll()(using AllowUnsafe) = underlying.poll() + def peek()(using AllowUnsafe) = underlying.peek() + def drain()(using AllowUnsafe) = underlying.drain() + def close()(using Frame, AllowUnsafe) = underlying.close() + def closed()(using AllowUnsafe): Boolean = underlying.closed() + end new + end initDropping + + def initSliding[A](_capacity: Int, access: Access = Access.MultiProducerMultiConsumer)( + using + frame: Frame, + allow: AllowUnsafe + ): Unsafe[A] = + new Unsafe[A]: + val underlying = Queue.Unsafe.init[A](_capacity, access) + def capacity = _capacity + def size()(using AllowUnsafe) = underlying.size() + def empty()(using AllowUnsafe) = underlying.empty() + def full()(using AllowUnsafe) = underlying.full().map(_ => false) + def offer(v: A)(using AllowUnsafe) = + @tailrec def loop(v: A): Result[Closed, Boolean] = + underlying.offer(v) match + case Result.Success(false) => + 
discard(underlying.poll()) + loop(v) + case result => + result + end loop + loop(v) + end offer + def poll()(using AllowUnsafe) = underlying.poll() + def peek()(using AllowUnsafe) = underlying.peek() + def drain()(using AllowUnsafe) = underlying.drain() + def close()(using Frame, AllowUnsafe) = underlying.close() + def closed()(using AllowUnsafe): Boolean = underlying.closed() + end new + end initSliding + end Unsafe end Unbounded - /** Initializes a new queue with the specified capacity and access pattern. - * - * @param capacity - * the capacity of the queue - * @param access - * the access pattern (default is MPMC) - * @return - * a new Queue instance - */ - def init[A](capacity: Int, access: Access = Access.MultiProducerMultiConsumer)(using frame: Frame): Queue[A] < IO = - IO { + /** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */ + abstract class Unsafe[A]: + def capacity: Int + def size()(using AllowUnsafe): Result[Closed, Int] + def empty()(using AllowUnsafe): Result[Closed, Boolean] + def full()(using AllowUnsafe): Result[Closed, Boolean] + def offer(v: A)(using AllowUnsafe): Result[Closed, Boolean] + def poll()(using AllowUnsafe): Result[Closed, Maybe[A]] + def peek()(using AllowUnsafe): Result[Closed, Maybe[A]] + def drain()(using AllowUnsafe): Result[Closed, Seq[A]] + def close()(using Frame, AllowUnsafe): Maybe[Seq[A]] + def closed()(using AllowUnsafe): Boolean + final def safe: Queue[A] = this + end Unsafe + + /** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */ + object Unsafe: + + abstract private class Closeable[A](initFrame: Frame) extends Unsafe[A]: + import AllowUnsafe.embrace.danger + final protected val _closed = AtomicRef.Unsafe.init(Maybe.empty[Result.Error[Closed]]) + + final def close()(using frame: Frame, allow: AllowUnsafe) = + val fail = Result.Fail(Closed("Queue", initFrame, frame)) + Maybe.when(_closed.cas(Maybe.empty, Maybe(fail)))(_drain()) + end close + + final def closed()(using AllowUnsafe) = _closed.get().isDefined + + final def drain()(using AllowUnsafe): Result[Closed, Seq[A]] = op(_drain()) + + protected def _drain(): Seq[A] + + protected inline def op[A](inline f: => A): Result[Closed, A] = + _closed.get().getOrElse(Result(f)) + + protected inline def offerOp[A](inline f: => Boolean, inline raceRepair: => Boolean): Result[Closed, Boolean] = + _closed.get().getOrElse { + val result = f + if result && _closed.get().isDefined then + Result(raceRepair) + else + Result(result) + end if + } + end Closeable + + def init[A](capacity: Int, access: Access = Access.MultiProducerMultiConsumer)(using + initFrame: Frame, + allow: AllowUnsafe + ): Unsafe[A] = capacity match - case c if (c <= 0) => - new Queue( - frame, - new Unsafe[A]: - def capacity = 0 - def size()(using AllowUnsafe) = 0 - def empty()(using AllowUnsafe) = true - def full()(using AllowUnsafe) = true - def offer(v: A)(using AllowUnsafe) = false - def poll()(using AllowUnsafe) = Maybe.empty - def peek()(using AllowUnsafe) = Maybe.empty - ) + case _ if capacity <= 0 => + new Closeable[A](initFrame): + def capacity = 0 + def size()(using AllowUnsafe) = op(0) + def empty()(using AllowUnsafe) = op(true) + def full()(using AllowUnsafe) = op(true) + def offer(v: A)(using AllowUnsafe) = op(false) + def poll()(using AllowUnsafe) = op(Maybe.empty) + def peek()(using AllowUnsafe) = op(Maybe.empty) + def _drain() = Seq.empty case 1 => - new Queue( - frame, 
- new Unsafe[A]: - val state = new AtomicReference[A] - def capacity = 1 - def size()(using AllowUnsafe) = if isNull(state.get()) then 0 else 1 - def empty()(using AllowUnsafe) = isNull(state.get()) - def full()(using AllowUnsafe) = !isNull(state.get()) - def offer(v: A)(using AllowUnsafe) = state.compareAndSet(null.asInstanceOf[A], v) - def poll()(using AllowUnsafe) = Maybe(state.getAndSet(null.asInstanceOf[A])) - def peek()(using AllowUnsafe) = Maybe(state.get()) - ) + new Closeable[A](initFrame): + private val state = AtomicRef.Unsafe.init(Maybe.empty[A]) + def capacity = 1 + def empty()(using AllowUnsafe) = op(state.get().isEmpty) + def size()(using AllowUnsafe) = op(if state.get().isEmpty then 0 else 1) + def full()(using AllowUnsafe) = op(state.get().isDefined) + def offer(v: A)(using AllowUnsafe) = offerOp(state.cas(Maybe.empty, Maybe(v)), !state.cas(Maybe(v), Maybe.empty)) + def poll()(using AllowUnsafe) = op(state.getAndSet(Maybe.empty)) + def peek()(using AllowUnsafe) = op(state.get()) + def _drain() = state.getAndSet(Maybe.empty).toList case Int.MaxValue => - initUnbounded(access) + Unbounded.Unsafe.init(access).safe case _ => access match case Access.MultiProducerMultiConsumer => @@ -210,118 +352,38 @@ object Queue: else // Spsc queue doesn't support capacity < 4 fromJava(new SpmcArrayQueue[A](capacity), capacity) - } - /** Initializes a new unbounded queue with the specified access pattern and chunk size. - * - * @param access - * the access pattern (default is MPMC) - * @param chunkSize - * the chunk size for internal array allocation (default is 8) - * @return - * a new Unbounded Queue instance - */ - def initUnbounded[A](access: Access = Access.MultiProducerMultiConsumer, chunkSize: Int = 8)(using Frame): Unbounded[A] < IO = - IO { - access match - case Access.MultiProducerMultiConsumer => - fromJava(new MpmcUnboundedXaddArrayQueue[A](chunkSize)) - case Access.MultiProducerSingleConsumer => - fromJava(new MpscUnboundedArrayQueue[A](chunkSize)) - case Access.SingleProducerMultiConsumer => - fromJava(new MpmcUnboundedXaddArrayQueue[A](chunkSize)) - case Access.SingleProducerSingleConsumer => - fromJava(new SpscUnboundedArrayQueue[A](chunkSize)) - } - - /** Initializes a new dropping queue with the specified capacity and access pattern. - * - * @param capacity - * the capacity of the queue - * @param access - * the access pattern (default is MPMC) - * @return - * a new Unbounded Queue instance that drops elements when full - */ - def initDropping[A](capacity: Int, access: Access = Access.MultiProducerMultiConsumer)(using frame: Frame): Unbounded[A] < IO = - init[A](capacity, access).map { q => - val u = q.unsafe - val c = capacity - new Unbounded( - frame, - new Unsafe[A]: - def capacity = c - def size()(using AllowUnsafe) = u.size() - def empty()(using AllowUnsafe) = u.empty() - def full()(using AllowUnsafe) = false - def offer(v: A)(using AllowUnsafe) = - discard(u.offer(v)) - true - def poll()(using AllowUnsafe) = u.poll() - def peek()(using AllowUnsafe) = u.peek() - ) - } - - /** Initializes a new sliding queue with the specified capacity and access pattern. 
- * - * @param capacity - * the capacity of the queue - * @param access - * the access pattern (default is MPMC) - * @return - * a new Unbounded Queue instance that slides elements when full - */ - def initSliding[A](capacity: Int, access: Access = Access.MultiProducerMultiConsumer)(using frame: Frame): Unbounded[A] < IO = - init[A](capacity, access).map { q => - val u = q.unsafe - val c = capacity - new Unbounded( - frame, - new Unsafe[A]: - def capacity = c - def size()(using AllowUnsafe) = u.size() - def empty()(using AllowUnsafe) = u.empty() - def full()(using AllowUnsafe) = false - def offer(v: A)(using AllowUnsafe) = - @tailrec def loop(v: A): Unit = - val u = q.unsafe - if u.offer(v) then () - else - discard(u.poll()) - loop(v) - end if - end loop - loop(v) - true - end offer - def poll()(using AllowUnsafe) = u.poll() - def peek()(using AllowUnsafe) = u.peek() - ) - } - - private def fromJava[A](q: java.util.Queue[A])(using frame: Frame): Unbounded[A] = - new Unbounded( - frame, - new Unsafe[A]: - def capacity = Int.MaxValue - def size()(using AllowUnsafe) = q.size - def empty()(using AllowUnsafe) = q.isEmpty() - def full()(using AllowUnsafe) = false - def offer(v: A)(using AllowUnsafe) = q.offer(v) - def poll()(using AllowUnsafe) = Maybe(q.poll) - def peek()(using AllowUnsafe) = Maybe(q.peek) - ) - - private def fromJava[A](q: java.util.Queue[A], _capacity: Int)(using frame: Frame): Queue[A] = - new Queue( - frame, - new Unsafe[A]: - def capacity = _capacity - def size()(using AllowUnsafe) = q.size - def empty()(using AllowUnsafe) = q.isEmpty() - def full()(using AllowUnsafe) = q.size >= _capacity - def offer(v: A)(using AllowUnsafe) = q.offer(v) - def poll()(using AllowUnsafe) = Maybe(q.poll) - def peek()(using AllowUnsafe) = Maybe(q.peek) - ) + def fromJava[A](q: java.util.Queue[A], _capacity: Int = Int.MaxValue)(using initFrame: Frame, allow: AllowUnsafe): Unsafe[A] = + new Closeable[A](initFrame): + def capacity = _capacity + def size()(using AllowUnsafe) = op(q.size()) + def empty()(using AllowUnsafe) = op(q.isEmpty()) + def full()(using AllowUnsafe) = op(q.size() >= _capacity) + def offer(v: A)(using AllowUnsafe) = + offerOp( + q.offer(v), + try !q.remove(v) + catch + case _: UnsupportedOperationException => + // TODO the race repair should use '!q.remove(v)' but JCTools doesn't support the operation. + // In rare cases, items may be left in the queue permanently after closing due to this limitation. + // The item will only be removed when the queue object itself is garbage collected. + !q.contains(v) + ) + def poll()(using AllowUnsafe) = op(Maybe(q.poll())) + def peek()(using AllowUnsafe) = op(Maybe(q.peek())) + def _drain() = + val b = Seq.newBuilder[A] + @tailrec def loop(): Unit = + val value = q.poll() + if !isNull(value) then + b.addOne(value) + loop() + end loop + loop() + b.result() + end _drain + + end Unsafe + end Queue diff --git a/kyo-core/shared/src/main/scala/kyo/Resource.scala b/kyo-core/shared/src/main/scala/kyo/Resource.scala index 0117ad568..e9a768ac1 100644 --- a/kyo-core/shared/src/main/scala/kyo/Resource.scala +++ b/kyo-core/shared/src/main/scala/kyo/Resource.scala @@ -28,9 +28,9 @@ object Resource: */ def ensure(v: => Unit < Async)(using frame: Frame): Unit < (Resource & IO) = ContextEffect.suspendMap(Tag[Resource]) { finalizer => - finalizer.queue.offer(IO(v)).map { - case true => () - case false => + Abort.run(finalizer.queue.offer(IO(v))).map { + case Result.Success(_) => () + case _ => throw new Closed( "Resource finalizer queue already closed. 
This may happen if " + "a background fiber escapes the scope of a 'Resource.run' call.", @@ -78,7 +78,7 @@ object Resource: * The result of the effect wrapped in Async and S effects. */ def run[A, S](v: A < (Resource & S))(using frame: Frame): A < (Async & S) = - Queue.initUnbounded[Unit < Async](Access.MultiProducerSingleConsumer).map { q => + Queue.Unbounded.init[Unit < Async](Access.MultiProducerSingleConsumer).map { q => Promise.init[Nothing, Unit].map { p => val finalizer = Finalizer(frame, q) def close: Unit < IO = diff --git a/kyo-core/shared/src/test/scala/kyo/ChannelTest.scala b/kyo-core/shared/src/test/scala/kyo/ChannelTest.scala index 1bd3cfb0d..c3c8ec65b 100644 --- a/kyo-core/shared/src/test/scala/kyo/ChannelTest.scala +++ b/kyo-core/shared/src/test/scala/kyo/ChannelTest.scala @@ -95,8 +95,8 @@ class ChannelTest extends Test: for c <- Channel.init[Int](2) r <- c.close - t <- c.offer(1) - yield assert(r == Maybe(Seq()) && !t) + t <- Abort.run(c.offer(1)) + yield assert(r == Maybe(Seq()) && t.isFail) } "non-empty" in runJVM { for @@ -104,7 +104,7 @@ class ChannelTest extends Test: _ <- c.put(1) _ <- c.put(2) r <- c.close - t <- Abort.run[Throwable](c.empty) + t <- Abort.run(c.empty) yield assert(r == Maybe(Seq(1, 2)) && t.isFail) } "pending take" in runJVM { @@ -113,8 +113,8 @@ class ChannelTest extends Test: f <- c.takeFiber r <- c.close d <- f.getResult - t <- Abort.run[Throwable](c.full) - yield assert(r == Maybe(Seq()) && d.isPanic && t.isFail) + t <- Abort.run(c.full) + yield assert(r == Maybe(Seq()) && d.isFail && t.isFail) } "pending put" in runJVM { for @@ -124,8 +124,8 @@ class ChannelTest extends Test: f <- c.putFiber(3) r <- c.close d <- f.getResult - _ <- c.offerDiscard(1) - yield assert(r == Maybe(Seq(1, 2)) && d.isPanic) + e <- Abort.run(c.offer(1)) + yield assert(r == Maybe(Seq(1, 2)) && d.isFail && e.isFail) } "no buffer w/ pending put" in runJVM { for @@ -133,8 +133,8 @@ class ChannelTest extends Test: f <- c.putFiber(1) r <- c.close d <- f.getResult - t <- c.poll - yield assert(r == Maybe(Seq()) && d.isPanic && t.isEmpty) + t <- Abort.run(c.poll) + yield assert(r == Maybe(Seq()) && d.isFail && t.isFail) } "no buffer w/ pending take" in runJVM { for @@ -143,7 +143,7 @@ class ChannelTest extends Test: r <- c.close d <- f.getResult t <- Abort.run[Throwable](c.put(1)) - yield assert(r == Maybe(Seq()) && d.isPanic && t.isFail) + yield assert(r == Maybe(Seq()) && d.isFail && t.isFail) } } "no buffer" in runJVM { @@ -178,4 +178,164 @@ class ChannelTest extends Test: yield assert(b) } } + + "concurrency" - { + + val repeats = 100 + + "offer and close" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + channel <- Channel.init[Int](size) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(channel.offer(i))))) + ) + closeFiber <- Async.run(latch.await.andThen(channel.close)) + _ <- latch.release + offered <- offerFiber.get + backlog <- closeFiber.get + closedChannel <- channel.close + drained <- Abort.run(channel.drain) + isClosed <- channel.closed + yield + assert(backlog.isDefined) + assert(offered.count(_.contains(true)) == backlog.get.size) + assert(closedChannel.isEmpty) + assert(drained.isFail) + assert(isClosed) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + + "offer and poll" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + channel <- Channel.init[Int](size) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 
100).map(i => Abort.run(channel.offer(i))))) + ) + pollFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(_ => Abort.run(channel.poll)))) + ) + _ <- latch.release + offered <- offerFiber.get + polled <- pollFiber.get + channelSize <- channel.size + yield assert(offered.count(_.contains(true)) == polled.count(_.toMaybe.flatten.isDefined) + channelSize)) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + + "put and take" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + channel <- Channel.init[Int](size) + latch <- Latch.init(1) + putFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(channel.put(i))))) + ) + takeFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(_ => Abort.run(channel.take)))) + ) + _ <- latch.release + puts <- putFiber.get + takes <- takeFiber.get + yield assert(puts.count(_.isSuccess) == takes.count(_.isSuccess) && takes.flatMap(_.toMaybe.toList).toSet == (1 to 100).toSet)) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + + "offer to full channel during close" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + channel <- Channel.init[Int](size) + _ <- Kyo.foreach(1 to size)(i => channel.offer(i)) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(channel.offer(i))))) + ) + closeFiber <- Async.run(latch.await.andThen(channel.close)) + _ <- latch.release + offered <- offerFiber.get + backlog <- closeFiber.get + isClosed <- channel.closed + yield + assert(backlog.isDefined) + assert(offered.count(_.contains(true)) == backlog.get.size - size) + assert(isClosed) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + + "concurrent close attempts" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + channel <- Channel.init[Int](size) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(channel.offer(i))))) + ) + closeFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(_ => channel.close))) + ) + _ <- latch.release + offered <- offerFiber.get + backlog <- closeFiber.get + isClosed <- channel.closed + yield + assert(backlog.count(_.isDefined) == 1) + assert(backlog.flatMap(_.toList.flatten).size == offered.count(_.contains(true))) + assert(isClosed) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + + "offer, poll, put, take, and close" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + channel <- Channel.init[Int](size) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 50).map(i => Abort.run(channel.offer(i))))) + ) + pollFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 50).map(_ => Abort.run(channel.poll)))) + ) + putFiber <- Async.run( + latch.await.andThen(Async.parallel((51 to 100).map(i => Abort.run(channel.put(i))))) + ) + takeFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 50).map(_ => Abort.run(channel.take)))) + ) + closeFiber <- Async.run(latch.await.andThen(channel.close)) + _ <- latch.release + offered <- offerFiber.get + polled <- pollFiber.get + puts <- putFiber.get + takes <- takeFiber.get + backlog <- closeFiber.get + isClosed <- channel.closed + yield + val totalOffered = offered.count(_.contains(true)) + puts.count(_.isSuccess) + val totalTaken = polled.count(_.toMaybe.flatten.isDefined) + takes.count(_.isSuccess) + assert(backlog.isDefined) + 
assert(totalOffered - totalTaken == backlog.get.size) + assert(isClosed) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + } + end ChannelTest diff --git a/kyo-core/shared/src/test/scala/kyo/HubTest.scala b/kyo-core/shared/src/test/scala/kyo/HubTest.scala index ab7617843..47693707c 100644 --- a/kyo-core/shared/src/test/scala/kyo/HubTest.scala +++ b/kyo-core/shared/src/test/scala/kyo/HubTest.scala @@ -57,12 +57,12 @@ class HubTest extends Test: _ <- untilTrue(h.empty) // wait transfer l <- h.listen c1 <- h.close - v1 <- Abort.run[Throwable](h.listen) - v2 <- h.offer(2) - v3 <- l.poll - c2 <- l.close + v1 <- Abort.run(h.listen) + v2 <- Abort.run(h.offer(2)) + v3 <- Abort.run(l.poll) + v4 <- l.close yield assert( - b && c1 == Maybe(Seq()) && v1.isFail && !v2 && v3.isEmpty && c2.isEmpty + b && c1 == Maybe(Seq()) && v1.isFail && v2.isFail && v3.isFail && v4.isEmpty ) } "close listener w/ buffer" in runJVM { @@ -107,22 +107,22 @@ class HubTest extends Test: } "listener removal" in runJVM { for - h <- Hub.init[Int](2) - l <- h.listen - _ <- h.offer(1) - _ <- untilTrue(h.empty) - c <- l.close - _ <- h.offer(2) - v <- l.poll - yield assert(c == Maybe(Seq()) && v.isEmpty) + h <- Hub.init[Int](2) + l <- h.listen + _ <- h.offer(1) + _ <- untilTrue(h.empty) + c <- l.close + v1 <- h.offer(2) + v2 <- Abort.run(l.poll) + yield assert(c == Maybe(Seq()) && v1 && v2.isFail) } "hub closure with pending offers" in runJVM { for h <- Hub.init[Int](2) _ <- h.offer(1) _ <- h.close - v <- h.offer(2) - yield assert(!v) + v <- Abort.run(h.offer(2)) + yield assert(v.isFail) } "create listener on empty hub" in runJVM { for diff --git a/kyo-core/shared/src/test/scala/kyo/MeterTest.scala b/kyo-core/shared/src/test/scala/kyo/MeterTest.scala index 000d1afed..5a0617d22 100644 --- a/kyo-core/shared/src/test/scala/kyo/MeterTest.scala +++ b/kyo-core/shared/src/test/scala/kyo/MeterTest.scala @@ -17,33 +17,37 @@ class MeterTest extends Test: b1 <- Promise.init[Nothing, Unit] f1 <- Async.run(t.run(b1.complete(Result.unit).map(_ => p.block(Duration.Infinity)))) _ <- b1.get - a1 <- t.isAvailable + a1 <- t.availablePermits + w1 <- t.pendingWaiters b2 <- Promise.init[Nothing, Unit] f2 <- Async.run(b2.complete(Result.unit).map(_ => t.run(2))) _ <- b2.get - a2 <- t.isAvailable + a2 <- t.availablePermits + w2 <- t.pendingWaiters d1 <- f1.done d2 <- f2.done _ <- p.complete(Result.success(1)) v1 <- f1.get v2 <- f2.get - a3 <- t.isAvailable - yield assert(!a1 && !d1 && !d2 && !a2 && v1 == Result.success(1) && v2 == 2 && a3) + a3 <- t.availablePermits + w3 <- t.pendingWaiters + yield assert(a1 == 0 && w1 == 0 && !d1 && !d2 && a2 == 0 && w2 == 1 && v1 == Result.success(1) && v2 == 2 && a3 == 1 && w3 == 0) } "tryRun" in runJVM { for - sem <- Meter.initSemaphore(1) + sem <- Meter.initMutex p <- Promise.init[Nothing, Int] b1 <- Promise.init[Nothing, Unit] f1 <- Async.run(sem.tryRun(b1.complete(Result.unit).map(_ => p.block(Duration.Infinity)))) _ <- b1.get - a1 <- sem.isAvailable + a1 <- sem.availablePermits + w1 <- sem.pendingWaiters b1 <- sem.tryRun(2) b2 <- f1.done _ <- p.complete(Result.success(1)) v1 <- f1.get - yield assert(!a1 && b1.isEmpty && !b2 && v1.contains(Result.success(1))) + yield assert(a1 == 0 && w1 == 0 && b1.isEmpty && !b2 && v1.contains(Result.success(1))) } } @@ -66,18 +70,24 @@ class MeterTest extends Test: b2 <- Promise.init[Nothing, Unit] f2 <- Async.run(t.run(b2.complete(Result.unit).map(_ => p.block(Duration.Infinity)))) _ <- b2.get - a1 <- t.isAvailable + a1 <- t.availablePermits + w1 <- 
t.pendingWaiters b3 <- Promise.init[Nothing, Unit] - f2 <- Async.run(b3.complete(Result.unit).map(_ => t.run(2))) + f3 <- Async.run(b3.complete(Result.unit).map(_ => t.run(2))) _ <- b3.get - a2 <- t.isAvailable + a2 <- t.availablePermits + w2 <- t.pendingWaiters d1 <- f1.done d2 <- f2.done + d3 <- f3.done _ <- p.complete(Result.success(1)) v1 <- f1.get v2 <- f2.get - a3 <- t.isAvailable - yield assert(!a1 && !d1 && !d2 && !a2 && v1 == Result.success(1) && v2 == 2 && a3) + v3 <- f3.get + a3 <- t.availablePermits + w3 <- t.pendingWaiters + yield assert(a1 == 0 && w1 == 0 && !d1 && !d2 && !d3 && a2 == 0 && w2 == 1 && + v1 == Result.success(1) && v2 == Result.success(1) && v3 == 2 && a3 == 2 && w3 == 0) } "tryRun" in runJVM { @@ -87,21 +97,100 @@ class MeterTest extends Test: b1 <- Promise.init[Nothing, Unit] f1 <- Async.run(sem.tryRun(b1.complete(Result.unit).map(_ => p.block(Duration.Infinity)))) _ <- b1.get + a1 <- sem.availablePermits + w1 <- sem.pendingWaiters b2 <- Promise.init[Nothing, Unit] f2 <- Async.run(sem.tryRun(b2.complete(Result.unit).map(_ => p.block(Duration.Infinity)))) _ <- b2.get - a1 <- sem.isAvailable + a2 <- sem.availablePermits + w2 <- sem.pendingWaiters b3 <- sem.tryRun(2) b4 <- f1.done b5 <- f2.done _ <- p.complete(Result.success(1)) v1 <- f1.get v2 <- f2.get - yield assert(!a1 && b3.isEmpty && !b4 && !b5 && v1.contains(Result.success(1)) && v2.contains(Result.success(1))) + yield assert(a1 == 1 && w1 == 0 && b3.isEmpty && !b4 && !b5 && v1.contains(Result.success(1)) && v2.contains(Result.success(1))) + } + + "concurrency" - { + + val repeats = 100 + + "run" in run { + (for + size <- Choice.get(Seq(1, 2, 3, 50, 100)) + meter <- Meter.initSemaphore(size) + counter <- AtomicInt.init(0) + results <- + Async.parallel((1 to 100).map(_ => + Abort.run(meter.run(counter.incrementAndGet)) + )) + count <- counter.get + permits <- meter.availablePermits + yield + assert(results.count(_.isFail) == 0) + assert(count == 100) + assert(permits == size) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + + "close" in run { + (for + size <- Choice.get(Seq(1, 2, 3, 50, 100)) + meter <- Meter.initSemaphore(size) + latch <- Latch.init(1) + counter <- AtomicInt.init(0) + runFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(_ => + Abort.run(meter.run(counter.incrementAndGet)) + ))) + ) + closeFiber <- Async.run(latch.await.andThen(meter.close)) + _ <- latch.release + closed <- closeFiber.get + completed <- runFiber.get + count <- counter.get + available <- Abort.run(meter.availablePermits) + yield + assert(closed) + assert(completed.count(_.isSuccess) <= 100) + assert(count <= 100) + assert(available.isFail) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + + "with interruptions" in run { + (for + size <- Choice.get(Seq(1, 2, 3, 50, 100)) + meter <- Meter.initSemaphore(size) + latch <- Latch.init(1) + counter <- AtomicInt.init(0) + runFibers <- Kyo.foreach(1 to 100)(_ => + Async.run(latch.await.andThen(meter.run(counter.incrementAndGet))) + ) + interruptFiber <- Async.run(latch.await.andThen(Async.parallel( + runFibers.take(50).map(_.interrupt(panic)) + ))) + _ <- latch.release + interrupted <- interruptFiber.get + completed <- Kyo.foreach(runFibers)(_.getResult) + count <- counter.get + yield assert(interrupted.count(identity) + completed.count(_.isSuccess) == 100)) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } } } - def loop(meter: Meter, counter: AtomicInt): Unit < Async = + def loop(meter: Meter, counter: 
AtomicInt): Unit < (Async & Abort[Closed]) = meter.run(counter.incrementAndGet).map(_ => loop(meter, counter)) val panic = Result.Panic(new Exception) @@ -159,11 +248,12 @@ class MeterTest extends Test: counter <- AtomicInt.init(0) f1 <- Async.run(loop(meter, counter)) _ <- Async.sleep(5.millis) - _ <- untilTrue(meter.isAvailable.map(!_)) + _ <- untilTrue(meter.availablePermits.map(_ == 0)) _ <- Async.sleep(5.millis) r <- meter.tryRun(()) _ <- f1.interrupt(panic) yield assert(r.isEmpty) } } + end MeterTest diff --git a/kyo-core/shared/src/test/scala/kyo/QueueTest.scala b/kyo-core/shared/src/test/scala/kyo/QueueTest.scala index 1f9382551..8a3bcfcac 100644 --- a/kyo-core/shared/src/test/scala/kyo/QueueTest.scala +++ b/kyo-core/shared/src/test/scala/kyo/QueueTest.scala @@ -1,5 +1,7 @@ package kyo +import kyo.kernel.Platform + class QueueTest extends Test: val access = Access.values.toList @@ -61,20 +63,20 @@ class QueueTest extends Test: q <- Queue.init[Int](2) b <- q.offer(1) c1 <- q.close - v1 <- Abort.run[Throwable](q.size) - v2 <- Abort.run[Throwable](q.empty) - v3 <- Abort.run[Throwable](q.full) - v4 <- q.offer(2) - v5 <- Abort.run[Throwable](q.poll) - v6 <- Abort.run[Throwable](q.peek) - v7 <- Abort.run[Throwable](q.drain) + v1 <- Abort.run(q.size) + v2 <- Abort.run(q.empty) + v3 <- Abort.run(q.full) + v4 <- Abort.run(q.offer(2)) + v5 <- Abort.run(q.poll) + v6 <- Abort.run(q.peek) + v7 <- Abort.run(q.drain) c2 <- q.close yield assert( b && c1 == Maybe(Seq(1)) && v1.isFail && v2.isFail && v3.isFail && - !v4 && + v4.isFail && v5.isFail && v6.isFail && v7.isFail && @@ -96,27 +98,27 @@ class QueueTest extends Test: access.toString() - { "isEmpty" in run { for - q <- Queue.initUnbounded[Int](access) + q <- Queue.Unbounded.init[Int](access) b <- q.empty yield assert(b) } "offer and poll" in run { for - q <- Queue.initUnbounded[Int](access) + q <- Queue.Unbounded.init[Int](access) b <- q.offer(1) v <- q.poll yield assert(b && v == Maybe(1)) } "peek" in run { for - q <- Queue.initUnbounded[Int](access) + q <- Queue.Unbounded.init[Int](access) _ <- q.offer(1) v <- q.peek yield assert(v == Maybe(1)) } "add and poll" in run { for - q <- Queue.initUnbounded[Int](access) + q <- Queue.Unbounded.init[Int](access) _ <- q.add(1) v <- q.poll yield assert(v == Maybe(1)) @@ -129,7 +131,7 @@ class QueueTest extends Test: access.foreach { access => access.toString() in run { for - q <- Queue.initDropping[Int](2) + q <- Queue.Unbounded.initDropping[Int](2) _ <- q.add(1) _ <- q.add(2) _ <- q.add(3) @@ -145,7 +147,7 @@ class QueueTest extends Test: access.foreach { access => access.toString() in run { for - q <- Queue.initSliding[Int](2) + q <- Queue.Unbounded.initSliding[Int](2) _ <- q.add(1) _ <- q.add(2) _ <- q.add(3) @@ -160,37 +162,37 @@ class QueueTest extends Test: "unsafe" - { import AllowUnsafe.embrace.danger - def withQueue[A](f: TestUnsafeQueue[Int] => A): A = - f(TestUnsafeQueue[Int](2)) + def withQueue[A](f: Queue.Unsafe[Int] => A): A = + f(Queue.Unsafe.init[Int](2)) "should offer and poll correctly" in withQueue { testUnsafe => - assert(testUnsafe.offer(1)) - assert(testUnsafe.poll() == Maybe(1)) + assert(testUnsafe.offer(1).contains(true)) + assert(testUnsafe.poll().contains(Maybe(1))) } "should peek correctly" in withQueue { testUnsafe => testUnsafe.offer(2) - assert(testUnsafe.peek() == Maybe(2)) + assert(testUnsafe.peek().contains(Maybe(2))) } "should report empty correctly" in withQueue { testUnsafe => - assert(testUnsafe.empty()) + assert(testUnsafe.empty().contains(true)) 
testUnsafe.offer(3) - assert(!testUnsafe.empty()) + assert(testUnsafe.empty().contains(false)) } "should report size correctly" in withQueue { testUnsafe => - assert(testUnsafe.size() == 0) + assert(testUnsafe.size().contains(0)) testUnsafe.offer(3) - assert(testUnsafe.size() == 1) + assert(testUnsafe.size().contains(1)) } "should drain correctly" in withQueue { testUnsafe => testUnsafe.offer(3) testUnsafe.offer(4) val drained = testUnsafe.drain() - assert(drained == Seq(3, 4)) - assert(testUnsafe.empty()) + assert(drained == Result.success(Seq(3, 4))) + assert(testUnsafe.empty().contains(true)) } "should close correctly" in withQueue { testUnsafe => @@ -201,39 +203,161 @@ class QueueTest extends Test: } } - case class TestUnsafeQueue[A](capacity: Int) extends Queue.Unsafe[A]: - private var elements = scala.collection.mutable.Queue[A]() - private var closed = false + "concurrency" - { - def offer(a: A)(using AllowUnsafe): Boolean = - if closed then throw new IllegalStateException("Queue is closed") - else if elements.size >= capacity then false - else - elements.enqueue(a) - true + val repeats = 100 - def poll()(using AllowUnsafe): Maybe[A] = - if closed then throw new IllegalStateException("Queue is closed") - else if elements.isEmpty then Maybe.empty - else Maybe(elements.dequeue()) + "offer and close" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + queue <- Queue.init[Int](size) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(queue.offer(i))))) + ) + closeFiber <- Async.run(latch.await.andThen(queue.close)) + _ <- latch.release + offered <- offerFiber.get + backlog <- closeFiber.get + closedQueue <- queue.close + drained <- Abort.run(queue.drain) + isClosed <- queue.closed + yield + assert(backlog.isDefined) + assert(offered.count(_.contains(true)) == backlog.get.size) + assert(closedQueue.isEmpty) + assert(drained.isFail) + assert(isClosed) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } - def peek()(using AllowUnsafe): Maybe[A] = - if closed then throw new IllegalStateException("Queue is closed") - else if elements.isEmpty then Maybe.empty - else Maybe(elements.head) + "offer and poll" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + queue <- Queue.init[Int](size) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(queue.offer(i))))) + ) + pollFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(_ => Abort.run(queue.poll)))) + ) + _ <- latch.release + offered <- offerFiber.get + polled <- pollFiber.get + left <- queue.size + yield assert(offered.count(_.contains(true)) == polled.count(_.toMaybe.flatten.isDefined) + left)) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } - def empty()(using AllowUnsafe): Boolean = - if closed then throw new IllegalStateException("Queue is closed") - else elements.isEmpty + "offer to full queue during close" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + queue <- Queue.init[Int](size) + _ <- Kyo.foreach(1 to size)(i => queue.offer(i)) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(queue.offer(i))))) + ) + closeFiber <- Async.run(latch.await.andThen(queue.close)) + _ <- latch.release + offered <- offerFiber.get + backlog <- closeFiber.get + isClosed <- queue.closed + yield + assert(backlog.isDefined) + assert(offered.count(_.contains(true)) == 
backlog.get.size - size) + assert(isClosed) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } - def size()(using AllowUnsafe): Int = - if closed then throw new IllegalStateException("Queue is closed") - else elements.size + "concurrent close attempts" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + queue <- Queue.init[Int](size) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(queue.offer(i))))) + ) + closeFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(_ => queue.close))) + ) + _ <- latch.release + offered <- offerFiber.get + backlog <- closeFiber.get + isClosed <- queue.closed + yield + assert(backlog.count(_.isDefined) == 1) + assert(backlog.flatMap(_.toList.flatten).size == offered.count(_.contains(true))) + assert(isClosed) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } - def full()(using AllowUnsafe): Boolean = - if closed then throw new IllegalStateException("Queue is closed") - else elements.size == capacity + "offer, poll and close" in run { + (for + size <- Choice.get(Seq(0, 1, 2, 10, 100)) + queue <- Queue.init[Int](size) + latch <- Latch.init(1) + offerFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(i => Abort.run(queue.offer(i))))) + ) + pollFiber <- Async.run( + latch.await.andThen(Async.parallel((1 to 100).map(_ => Abort.run(queue.poll)))) + ) + closeFiber <- Async.run(latch.await.andThen(queue.close)) + _ <- latch.release + offered <- offerFiber.get + polled <- pollFiber.get + backlog <- closeFiber.get + isClosed <- queue.closed + yield + assert(backlog.isDefined) + assert(offered.count(_.contains(true)) - polled.count(_.toMaybe.flatten.isDefined) == backlog.get.size) + assert(isClosed) + ) + .pipe(Choice.run).unit + .repeat(repeats) + .as(succeed) + } + } - end TestUnsafeQueue + if Platform.isJVM then + "non-power-of-two capacity" - { + import AllowUnsafe.embrace.danger + def testCapacity(accessType: Access) = + List(3, 7, 15, 31).foreach { requestedCapacity => + s"$accessType with requested capacity $requestedCapacity" in pendingUntilFixed { + val queue = Queue.Unsafe.init[Int](requestedCapacity, accessType) + val actualCapacity = queue.capacity + + (1 to requestedCapacity).foreach { i => + assert(queue.offer(i).contains(true), s"Failed to offer item $i") + } + + assert(!queue.offer(requestedCapacity + 1).contains(true), "Should not be able to offer beyond requested capacity") + assert( + actualCapacity >= requestedCapacity, + s"Actual capacity $actualCapacity is less than requested capacity $requestedCapacity" + ) + () + } + } + + Access.values.foreach { access => + access.toString - testCapacity(access) + } + } + end if end QueueTest diff --git a/kyo-data/shared/src/main/scala/kyo/Result.scala b/kyo-data/shared/src/main/scala/kyo/Result.scala index 0b35a409a..2dc18a6bc 100644 --- a/kyo-data/shared/src/main/scala/kyo/Result.scala +++ b/kyo-data/shared/src/main/scala/kyo/Result.scala @@ -568,6 +568,14 @@ object Result: ): Try[A] = fold(e => scala.util.Failure(e.getFailure.asInstanceOf[Throwable]))(scala.util.Success(_)) + /** Converts the Result to a Result[E, Unit]. + * + * @return + * A new Result with the same error type E and Unit as the success type + */ + def unit: Result[E, Unit] = + map(_ => ()) + /** Swaps the success and failure cases of the Result. 
* * @return @@ -591,6 +599,26 @@ object Result: case Success(`value`) => true case _ => false + /** Checks if the Result is a Success and the predicate holds for its value. + * + * @param pred + * The predicate function to apply to the successful value + * @return + * true if the Result is a Success and the predicate holds, false otherwise + */ + def exists(pred: A => Boolean): Boolean = + fold(_ => false)(pred) + + /** Checks if the Result is a Success and the predicate holds for its value, or if the Result is a Failure. + * + * @param pred + * The predicate function to apply to the successful value + * @return + * true if the Result is a Failure, or if it's a Success and the predicate holds + */ + def forall(pred: A => Boolean): Boolean = + fold(_ => true)(pred) + /** Returns a string representation of the Result. * * @return diff --git a/kyo-data/shared/src/test/scala/kyo/ResultTest.scala b/kyo-data/shared/src/test/scala/kyo/ResultTest.scala index 4d8bab04d..8aa750058 100644 --- a/kyo-data/shared/src/test/scala/kyo/ResultTest.scala +++ b/kyo-data/shared/src/test/scala/kyo/ResultTest.scala @@ -922,6 +922,68 @@ class ResultTest extends Test: } } + "unit" - { + "should convert Success to Success(())" in { + val result = Result.success(42).unit + assert(result == Success(())) + } + + "should not change Fail" in { + val result = Result.fail[String, Int]("error").unit + assert(result == Fail("error")) + } + + "should not change Panic" in { + val ex = new Exception("test") + val result = Result.panic[String, Int](ex).unit + assert(result == Panic(ex)) + } + } + + "exists" - { + "should return true for Success when predicate holds" in { + val result = Result.success(42) + assert(result.exists(_ > 0)) + } + + "should return false for Success when predicate doesn't hold" in { + val result = Result.success(42) + assert(!result.exists(_ < 0)) + } + + "should return false for Fail" in { + val result = Result.fail[String, Int]("error") + assert(!result.exists(_ => true)) + } + + "should return false for Panic" in { + val result = Result.panic[String, Int](new Exception("test")) + assert(!result.exists(_ => true)) + } + } + + "forall" - { + "should return true for Success when predicate holds" in { + val result = Result.success(42) + assert(result.forall(_ > 0)) + } + + "should return false for Success when predicate doesn't hold" in { + val result = Result.success(42) + assert(!result.forall(_ < 0)) + } + + "should return true for Fail" in { + val result = Result.fail[String, Int]("error") + assert(result.forall(_ => false)) + } + + "should return true for Panic" in { + val result = Result.panic[String, Int](new Exception("test")) + assert(result.forall(_ => false)) + } + } + "show" - { "Success" in { assert(Result.success(42).show == "Success(42)") diff --git a/kyo-examples/jvm/src/main/scala/examples/ledger/db/Log.scala b/kyo-examples/jvm/src/main/scala/examples/ledger/db/Log.scala index 7c1564b5d..adeddb8e3 100644 --- a/kyo-examples/jvm/src/main/scala/examples/ledger/db/Log.scala +++ b/kyo-examples/jvm/src/main/scala/examples/ledger/db/Log.scala @@ -21,7 +21,7 @@ object Log: val init: Log < (Env[DB.Config] & IO) = defer { val cfg = await(Env.get[DB.Config]) - val q = await(Queue.initUnbounded[Entry](Access.MultiProducerSingleConsumer)) + val q = await(Queue.Unbounded.init[Entry](Access.MultiProducerSingleConsumer)) val log = await(IO(Live(cfg.workingDir + "/log.dat", q))) val _ = await(Async.run(log.flushLoop(cfg.flushInterval))) log diff --git 
a/kyo-sttp/js/src/main/scala/kyo/PlatformBackend.scala b/kyo-sttp/js/src/main/scala/kyo/PlatformBackend.scala index c05d211a4..5e2f311dc 100644 --- a/kyo-sttp/js/src/main/scala/kyo/PlatformBackend.scala +++ b/kyo-sttp/js/src/main/scala/kyo/PlatformBackend.scala @@ -7,7 +7,7 @@ object PlatformBackend: val default = new Backend: val b = FetchBackend() - def send[A](r: Request[A, Any]) = + def send[A: Flat](r: Request[A, Any]) = given Frame = Frame.internal Abort.run(Async.fromFuture(r.send(b))) .map(_.fold(ex => Abort.fail(FailedRequest(ex.getFailure)))(identity)) diff --git a/kyo-sttp/jvm/src/main/scala/kyo/PlatformBackend.scala b/kyo-sttp/jvm/src/main/scala/kyo/PlatformBackend.scala index 956af6c64..5e984c6f6 100644 --- a/kyo-sttp/jvm/src/main/scala/kyo/PlatformBackend.scala +++ b/kyo-sttp/jvm/src/main/scala/kyo/PlatformBackend.scala @@ -10,7 +10,7 @@ object PlatformBackend: def apply(backend: SttpBackend[KyoSttpMonad.M, WebSockets])(using Frame): Backend = new Backend: - def send[A](r: Request[A, Any]) = + def send[A: Flat](r: Request[A, Any]) = r.send(backend) def apply(client: HttpClient)(using Frame): Backend = @@ -19,6 +19,6 @@ object PlatformBackend: val default = new Backend: val b = HttpClientKyoBackend() - def send[A](r: Request[A, Any]) = + def send[A: Flat](r: Request[A, Any]) = r.send(b) end PlatformBackend diff --git a/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoSequencer.scala b/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoSequencer.scala index ceaa3f9bd..301f783ad 100644 --- a/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoSequencer.scala +++ b/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoSequencer.scala @@ -7,5 +7,6 @@ import sttp.client3.internal.httpclient.Sequencer class KyoSequencer(mutex: Meter) extends Sequencer[KyoSttpMonad.M]: def apply[A](t: => KyoSttpMonad.M[A]) = - mutex.run(t) + import Flat.unsafe.bypass + Abort.run(mutex.run(t)).map(_.getOrThrow) // safe since the meter will never be closed end KyoSequencer diff --git a/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoSimpleQueue.scala b/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoSimpleQueue.scala index f8eeb131e..cea0e4db1 100644 --- a/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoSimpleQueue.scala +++ b/kyo-sttp/jvm/src/main/scala/sttp/client3/KyoSimpleQueue.scala @@ -9,10 +9,11 @@ class KyoSimpleQueue[A](ch: Channel[A]) extends SimpleQueue[KyoSttpMonad.M, A]: def offer(t: A): Unit = import kyo.AllowUnsafe.embrace.danger - if !IO.Unsafe.run(ch.offer(t)).eval then + if ch.unsafe.offer(t).contains(false) then throw WebSocketBufferFull(Int.MaxValue) end offer def poll = - ch.take + import kyo.AllowUnsafe.embrace.danger + ch.unsafe.takeFiber().mapResult(_.fold(e => throw e.exception)(Result.success)).safe.get end KyoSimpleQueue diff --git a/kyo-sttp/shared/src/main/scala/kyo/Requests.scala b/kyo-sttp/shared/src/main/scala/kyo/Requests.scala index bc09d1714..a8479efb5 100644 --- a/kyo-sttp/shared/src/main/scala/kyo/Requests.scala +++ b/kyo-sttp/shared/src/main/scala/kyo/Requests.scala @@ -31,7 +31,7 @@ object Requests: * @return * The response wrapped in an effect */ - def send[A](r: Request[A, Any]): Response[A] < (Async & Abort[FailedRequest]) + def send[A: Flat](r: Request[A, Any]): Response[A] < (Async & Abort[FailedRequest]) /** Wraps the Backend with a meter * @@ -42,8 +42,8 @@ object Requests: */ def withMeter(m: Meter)(using Frame): Backend = new Backend: - def send[A](r: Request[A, Any]) = - m.run(self.send(r)) + def send[A: Flat](r: Request[A, Any]) = + Abort.run(m.run(self.send(r))).map(r => 
Abort.get(r.mapFail(FailedRequest(_)))) end Backend /** The default live backend implementation */ diff --git a/kyo-sttp/shared/src/test/scala/kyo/RequestsTest.scala b/kyo-sttp/shared/src/test/scala/kyo/RequestsTest.scala index 1816db106..f0fd6c80b 100644 --- a/kyo-sttp/shared/src/test/scala/kyo/RequestsTest.scala +++ b/kyo-sttp/shared/src/test/scala/kyo/RequestsTest.scala @@ -7,7 +7,7 @@ class RequestsTest extends Test: class TestBackend extends Requests.Backend: var calls = 0 - def send[A](r: Request[A, Any]) = + def send[A: Flat](r: Request[A, Any]) = calls += 1 Response.ok(Right("mocked")).asInstanceOf[Response[A]] end TestBackend @@ -49,7 +49,10 @@ class RequestsTest extends Test: "with meter" in run { var calls = 0 val meter = new Meter: - def available(using Frame) = ??? + def capacity = ??? + def availablePermits(using Frame) = ??? + def pendingWaiters(using Frame) = ??? + def closed(using Frame) = ??? def tryRun[A, S](v: => A < S)(using Frame) = ??? def run[A, S](v: => A < S)(using Frame) = calls += 1