Skip to content

Groupwithin improvements #3186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 59 additions & 98 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1378,12 +1378,12 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
}

/** Splits this stream into a stream of chunks of elements, such that
* 1. each chunk in the output has at most `outputSize` elements, and
* 1. each chunk in the output has at most `chunkSize` elements, and
* 2. the concatenation of those chunks, which is obtained by calling
* `unchunks`, yields the same element sequence as this stream.
*
* As `this` stream emits input elements, the result stream collects them in a
* waiting buffer, until it has enough elements to emit the next chunk.
* As `this` stream ingests input elements, they will be collected in a
* waiting buffer, until it has enough elements to emit the next chunk.
*
* To avoid holding input elements for too long, this method takes a
* `timeout`. This timeout is reset after each output chunk is emitted.
Expand All @@ -1403,113 +1403,74 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
* When the input stream terminates, any accumulated elements are emitted
* immediately in a chunk, even if `timeout` has not expired.
*
* If the chunkSize is equal to zero the stream will block until the
* timeout expires at which point it will terminate.
*
* @param chunkSize the maximum size of chunks emitted by resulting stream.
* @param timeout maximum time that input elements are held in the buffer
* before being emitted by the resulting stream.
*/
def groupWithin[F2[x] >: F[x]](
chunkSize: Int,
timeout: FiniteDuration
)(implicit F: Temporal[F2]): Stream[F2, Chunk[O]] = {

case class JunctionBuffer[T](
data: Vector[T],
endOfSupply: Option[Either[Throwable, Unit]],
endOfDemand: Option[Either[Throwable, Unit]]
) {
def splitAt(n: Int): (JunctionBuffer[T], JunctionBuffer[T]) =
if (this.data.size >= n) {
val (head, tail) = this.data.splitAt(n.toInt)
(this.copy(tail), this.copy(head))
} else {
(this.copy(Vector.empty), this)
}
}

val outputLong = chunkSize.toLong
fs2.Stream.force {
for {
demand <- Semaphore[F2](outputLong)
supply <- Semaphore[F2](0L)
buffer <- Ref[F2].of(
JunctionBuffer[O](Vector.empty[O], endOfSupply = None, endOfDemand = None)
)
} yield {
/* - Buffer: stores items from input to be sent on next output chunk
* - Demand Semaphore: to avoid adding too many items to buffer
* - Supply: counts filled positions for next output chunk */
def enqueue(t: O): F2[Boolean] =
for {
_ <- demand.acquire
buf <- buffer.modify(buf => (buf.copy(buf.data :+ t), buf))
_ <- supply.release
} yield buf.endOfDemand.isEmpty

val dequeueNextOutput: F2[Option[Vector[O]]] = {
// Trigger: waits until the supply buffer is full (with acquireN)
val waitSupply = supply.acquireN(outputLong).guaranteeCase {
case Outcome.Succeeded(_) => supply.releaseN(outputLong)
case _ => F.unit
}
)(implicit F: Temporal[F2]): Stream[F2, Chunk[O]] =
if (chunkSize == 0) Stream.sleep_[F2](timeout)
else if (timeout.toNanos == 0 || chunkSize == 1) chunkN(chunkSize)
else
Stream.force {
for {
supply <- Semaphore[F2](0)
buffer <- Ref[F2].empty[Vector[O]]
backpressure <- Semaphore[F2](chunkSize.toLong)
supplyEnded <- SignallingRef[F2].of(false)
} yield {

val onTimeout: F2[Long] =
for {
_ <- supply.acquire // waits until there is at least one element in buffer
m <- supply.available
k = m.min(outputLong - 1)
b <- supply.tryAcquireN(k)
} yield if (b) k + 1 else 1

// in JS cancellation doesn't always seem to run, so race conditions should restore state on their own
for {
acq <- F.race(F.sleep(timeout), waitSupply).flatMap {
case Left(_) => onTimeout
case Right(_) => supply.acquireN(outputLong).as(outputLong)
}
buf <- buffer.modify(_.splitAt(acq.toInt))
_ <- demand.releaseN(buf.data.size.toLong)
res <- buf.endOfSupply match {
case Some(Left(error)) => F.raiseError(error)
case Some(Right(_)) if buf.data.isEmpty => F.pure(None)
case _ => F.pure(Some(buf.data))
def push(o: O): F2[Unit] =
backpressure.acquire *> buffer.update(_ :+ o)

val flush: F2[Vector[O]] =
buffer.getAndSet(Vector.empty).flatTap(os => backpressure.releaseN(os.size.toLong))

// wait until at least one supply permit can be acquired (one permit is released per
// enqueued element) or until the end of the input stream is signalled via `supplyEnded`.
val awaitSupply: F2[Unit] =
Stream.exec(supply.acquire).interruptWhen(supplyEnded).compile.drain

// in order to ensure prompt termination on interruption when the timeout has not kicked
// in yet or if we haven't seen enough elements, we need to provide enough supply for 2 iterations
val endSupply: F2[Unit] = supplyEnded.set(true) *> supply.releaseN(chunkSize * 2L)

// flush immediately or wait before doing so, subsequently lowering the supply by however
// many elements have been flushed (excluding the element already awaited, if needed)
def flushOnSupplyReceived(noSupply: Boolean): F2[Vector[O]] = for {
flushed <- awaitSupply.whenA(noSupply) *> flush
awaitedCount = if (noSupply) 1L else 0L
_ <- supply.acquireN((flushed.size - awaitedCount).max(0))
} yield flushed

// edge case: supply semaphore loses the race, but acquires the permits. In such scenario
// we flush the buffer without lowering the supply, since it has already been lowered
val onTimeout: F2[Vector[O]] = for {
bufferFull <- buffer.get.map(_.size == chunkSize)
noSupply <- supply.available.map(_ == 0)
edgeCase = bufferFull && noSupply
flushed <- if (edgeCase) flush else flushOnSupplyReceived(noSupply)
} yield flushed

val enqueue: F2[Unit] =
foreach(push(_) *> supply.release).compile.drain.guarantee(endSupply)

val dequeue: F2[Vector[O]] =
F.race(supply.acquireN(chunkSize.toLong), F.sleep(timeout)).flatMap {
case Left(_) => flush
case Right(_) => onTimeout
}
} yield res
}

def endSupply(result: Either[Throwable, Unit]): F2[Unit] =
buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(Int.MaxValue)

def endDemand(result: Either[Throwable, Unit]): F2[Unit] =
buffer.update(_.copy(endOfDemand = Some(result))) *> demand.releaseN(Int.MaxValue)

def toEnding(ec: ExitCase): Either[Throwable, Unit] = ec match {
case ExitCase.Succeeded => Right(())
case ExitCase.Errored(e) => Left(e)
case ExitCase.Canceled => Right(())
}

val enqueueAsync = F.start {
this
.evalMap(enqueue)
.forall(identity)
.onFinalizeCase(ec => endSupply(toEnding(ec)))
.compile
.drain
}

val outputStream: Stream[F2, Chunk[O]] =
Stream
.eval(dequeueNextOutput)
.repeat
.collectWhile { case Some(data) => Chunk.vector(data) }

Stream
.bracketCase(enqueueAsync) { case (upstream, exitCase) =>
endDemand(toEnding(exitCase)) *> upstream.cancel
} >> outputStream
.repeatEval(dequeue)
.collectWhile { case os if os.nonEmpty => Chunk.vector(os) }
.concurrently(Stream.eval(enqueue))
}
}
}
}

/** If `this` terminates with `Stream.raiseError(e)`, invoke `h(e)`.
*
Expand Down
3 changes: 3 additions & 0 deletions core/shared/src/test/scala/fs2/Fs2Suite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ abstract class Fs2Suite
expect <- expected.compile.toList
} yield assertEquals(actual.toSet, expect.toSet)

/** Runs `str` for its effects only (`compile.drain`), discarding any output, and applies
* `assert` to the resulting `IO[Unit]`.
*/
def assertCompletes: IO[Unit] =
str.compile.drain.assert

def intercept[T <: Throwable](implicit T: ClassTag[T], loc: Location): IO[T] =
str.compile.drain.intercept[T]
}
Expand Down
151 changes: 149 additions & 2 deletions core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ package fs2

import cats.effect.kernel.Deferred
import cats.effect.kernel.Ref
import cats.effect.std.{Semaphore, Queue}
import cats.effect.std.{Queue, Semaphore}
import cats.effect.testkit.TestControl
import cats.effect.{IO, SyncIO}
import cats.syntax.all._
Expand All @@ -34,6 +34,7 @@ import org.scalacheck.Prop.forAll

import scala.concurrent.duration._
import scala.concurrent.TimeoutException
import scala.util.control.NoStackTrace

class StreamCombinatorsSuite extends Fs2Suite {
override def munitIOTimeout = 1.minute
Expand Down Expand Up @@ -748,7 +749,7 @@ class StreamCombinatorsSuite extends Fs2Suite {
}
}

test("accumulation and splitting".flaky) {
test("accumulation and splitting") {
val t = 200.millis
val size = 5
val sleep = Stream.sleep_[IO](2 * t)
Expand All @@ -775,6 +776,36 @@ class StreamCombinatorsSuite extends Fs2Suite {
source.groupWithin(size, t).map(_.toList).assertEmits(expected)
}

test("accumulation and splitting 2") {
  val timeout = 200.millis
  val maxGroupSize = 5
  val pause = Stream.sleep_[IO](2 * timeout)
  val longPause = pause.repeatN(5)

  // `count` consecutive integers starting at `start`, re-chunked through `chunkAll.unchunks`
  def burst(start: Int, count: Int) =
    Stream.range(start, start + count).chunkAll.unchunks

  // bursts of varying sizes separated by pauses longer than the timeout, meant to give
  // good coverage of the chunk manipulation logic in groupWithin
  val input =
    burst(start = 1, count = 3) ++
      pause ++
      burst(start = 4, count = 1) ++
      longPause ++
      burst(start = 5, count = 11) ++
      burst(start = 16, count = 7)

  val expectedGroups = List(
    List(1, 2, 3),
    List(4),
    List(5, 6, 7, 8, 9),
    List(10, 11, 12, 13, 14),
    List(15, 16, 17, 18, 19),
    List(20, 21, 22)
  )

  input
    .groupWithin(maxGroupSize, timeout)
    .map(_.toList)
    .assertEmits(expectedGroups)
}

test("does not reset timeout if nothing is emitted") {
TestControl
.executeEmbed(
Expand Down Expand Up @@ -834,6 +865,122 @@ class StreamCombinatorsSuite extends Fs2Suite {
)
.assertEquals(0.millis)
}

test("stress test (short execution): all elements are processed") {
  val elementCount = 100000

  // groups the whole range and, after every emitted group, reports the running total of
  // the sizes of the groups seen so far
  def countGrouped(processed: Ref[IO, Int]): Stream[IO, Int] = {
    val grouped = Stream.range(0, elementCount).covary[IO].groupWithin(256, 100.micros)
    grouped.evalTap(group => processed.update(_ + group.size)) *> Stream.eval(processed.get)
  }

  val finalCount: IO[Int] =
    Stream.eval(Ref[IO].of(0)).flatMap(countGrouped(_)).compile.lastOrError

  // the last running total must account for every element of the range
  finalCount.assertEquals(elementCount)
}

// ignoring because it's a (relatively) long running test (around 3/4 minutes), but it's useful
// to assess the validity of permits management and timeout logic over an extended period of time
test("stress test (long execution): all elements are processed".ignore) {
val rangeLength = 10000000

Stream
.eval(Ref[IO].of(0))
.flatMap { counter =>
Stream
.range(0, rangeLength)
.covary[IO]
// each element is delayed by 2 to 11 microseconds, depending on its last digit
.evalTap(d => IO.sleep((d % 10 + 2).micros))
.groupWithin(275, 5.millis)
// the counter accumulates the sizes of all emitted chunks
.evalTap(ch => counter.update(_ + ch.size)) *> Stream.eval(counter.get)
}
.compile
.lastOrError
.assertEquals(rangeLength)
}

test("upstream failures are propagated downstream") {
  // sentinel failure raised by the source when it reaches the number 7
  case object SevenNotAllowed extends NoStackTrace

  val naturals = Stream.unfold(0)(current => Some((current, current + 1))).covary[IO]

  val failingAtSeven = naturals.evalMap {
    case 7       => IO.raiseError[Int](SevenNotAllowed)
    case allowed => IO.pure(allowed)
  }

  // the failure of the source is expected to surface when the grouped stream is run
  failingAtSeven
    .groupWithin(100, 2.seconds)
    .compile
    .lastOrError
    .intercept[SevenNotAllowed.type]
}

test(
"upstream interruption causes immediate downstream termination with all elements being emitted"
) {

val sourceTimeout = 5.5.seconds
val downstreamTimeout = sourceTimeout + 2.seconds

TestControl
.executeEmbed(
Ref[IO]
.of(0.millis)
.flatMap { ref =>
// infinite source emitting one element per second, interrupted after `sourceTimeout`
val source: Stream[IO, Int] =
Stream
.unfold(0)(s => Some((s, s + 1)))
.covary[IO]
.meteredStartImmediately(1.second)
.interruptAfter(sourceTimeout)

// large chunkSize and timeout: no emissions are expected in the window specified
// unless the source ends, either due to interruption or natural termination
// (i.e. it runs out of elements)
val downstream: Stream[IO, Chunk[Int]] =
source.groupWithin(Int.MaxValue, 1.day)

// records the monotonic time at which the last chunk was obtained and pairs it with that chunk
downstream.compile.lastOrError
.map(_.toList)
.timeout(downstreamTimeout)
.flatTap(_ => IO.monotonic.flatMap(ref.set))
.flatMap(emit => ref.get.map(timeLapsed => (timeLapsed, emit)))
}
)
.assertEquals(
// downstream ended immediately (i.e. timeLapsed = sourceTimeout)
// emitting whatever was accumulated at the time of interruption
(sourceTimeout, List(0, 1, 2, 3, 4, 5))
)
}

test(
"Edge case: if the buffer fills up and timeout expires at the same time there won't be a deadlock"
) {

forAllF { (s0: Stream[Pure, Int], b: Byte) =>
TestControl
.executeEmbed {

// n is at least 2 and `s` has at least n elements, preventing empty or singleton
// streams that would bypass the logic being tested
val n = b.max(2).toInt
val s = s0 ++ Stream.range(0, n)

// the buffer will reach its full capacity every
// n seconds exactly when the timeout expires
// (the stream is metered at one second, with a chunk size of n and a timeout of n seconds)
s
.covary[IO]
.metered(1.second)
.groupWithin(n, n.seconds)
.map(_.toList)
.assertCompletes
}
}
}
}

property("head")(forAll((s: Stream[Pure, Int]) => assertEquals(s.head.toList, s.toList.take(1))))
Expand Down