diff --git a/src/main/scala/util/Batcher.scala b/src/main/scala/util/Batcher.scala index a06c3914..b12e3a65 100644 --- a/src/main/scala/util/Batcher.scala +++ b/src/main/scala/util/Batcher.scala @@ -19,18 +19,25 @@ final class Batcher[Key, Elem, Batch]( private val buffers = ConcurrentHashMap[Key, Buffer](initialCapacity) def add(key: Key, elem: Elem): Unit = - val newBuffer = buffers.compute( + buffers.compute( key, (_, buffer) => val prev = Option(buffer) if prev.isEmpty then scheduler.scheduleOnce(timeout, () => emitAndRemove(key)) - Buffer( + val newBuffer = Buffer( append(prev.map(_.batch), elem), prev.fold(1)(_.counter + 1) ) + if newBuffer.counter >= maxBatchSize then + emit(key, newBuffer.batch) + null + else newBuffer ) - if newBuffer.counter >= maxBatchSize then emitAndRemove(key) private def emitAndRemove(key: Key): Unit = - Option(buffers.remove(key)).foreach: buffer => - emit(key, buffer.batch) + buffers.computeIfPresent( + key, + (_, buffer) => + emit(key, buffer.batch) + null + )