Skip to content

Commit

Permalink
Add SplitIntoBatches.toListsOfSize
Browse files Browse the repository at this point in the history
  • Loading branch information
patchwork01 committed Aug 12, 2024
1 parent b50a97d commit 09698a9
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 2 deletions.
65 changes: 65 additions & 0 deletions java/core/src/main/java/sleeper/core/util/SplitIntoBatches.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -73,4 +74,68 @@ public static <T> void reusingListOfSize(int batchSize, Stream<T> items, Consume
operation.accept(batch);
}
}

/**
 * Groups the items of a sequential stream into consecutive batches, allocating a fresh list for every batch.
 * Each batch holds exactly the requested number of items, apart from a final batch that holds whatever is left
 * over once the input has run out. An empty input produces no batches. Parallel streams are rejected.
 *
 * @param  <T>                      the item type
 * @param  batchSize                the number of items to put in each batch, at least 1
 * @param  items                    the sequential stream of items to group into batches
 * @return                          a stream of the batches
 * @throws IllegalArgumentException if the batch size is lower than 1, or if the stream is parallel
 */
public static <T> Stream<List<T>> toListsOfSize(int batchSize, Stream<T> items) {
    if (batchSize < 1) {
        throw new IllegalArgumentException("Batch size must be at least 1, found " + batchSize);
    }
    if (items.isParallel()) {
        throw new IllegalArgumentException("Cannot split parallel stream");
    }
    // The accumulator is mutated while the returned stream is traversed, which is why the input must be sequential.
    StreamBatcher<T> accumulator = new StreamBatcher<>(batchSize);
    // Each item is added to the accumulator, and a batch is emitted only when that item completed it.
    Stream<List<T>> completed = items
            .map(item -> {
                accumulator.add(item);
                return accumulator.takeBatchIfFull();
            })
            .flatMap(Optional::stream);
    // Evaluated lazily, so the accumulator is only queried for leftovers after the input has been traversed.
    Stream<List<T>> leftover = Stream.generate(accumulator::takeBatchIfNotEmpty)
            .limit(1)
            .flatMap(Optional::stream);
    return Stream.concat(completed, leftover);
}

/**
 * Accumulates items one at a time and hands them back as lists of a fixed size. Holds unsynchronised mutable
 * state, so it is intended for use with a stream that is traversed sequentially.
 *
 * @param <T> the item type
 */
private static class StreamBatcher<T> {
    private final int capacity;
    private List<T> pending;

    private StreamBatcher(int batchSize) {
        capacity = batchSize;
        pending = new ArrayList<>(batchSize);
    }

    /**
     * Appends an item to the batch currently being filled.
     *
     * @param item the item to append
     */
    void add(T item) {
        pending.add(item);
    }

    /**
     * Hands over the current batch if it has reached the batch size, and starts a new list for later items.
     *
     * @return the completed batch, or an empty optional if the current batch is not yet full
     */
    Optional<List<T>> takeBatchIfFull() {
        if (pending.size() != capacity) {
            return Optional.empty();
        }
        List<T> completed = pending;
        pending = new ArrayList<>(capacity);
        return Optional.of(completed);
    }

    /**
     * Hands over the current batch if it holds any items. The list is returned as it is and is not replaced.
     *
     * @return the current batch, or an empty optional if it holds no items
     */
    Optional<List<T>> takeBatchIfNotEmpty() {
        return pending.isEmpty() ? Optional.empty() : Optional.of(pending);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ void shouldFailWithBatchSizeLowerThanOne() {
}

@Nested
@DisplayName("Split a stream into batches")
class SplitAStream {
@DisplayName("Split a stream into batches reusing a list")
class SplitAStreamReusingList {

@Test
void shouldSplitIntoOneFullBatchAndOnePartialBatchLeftOver() {
Expand Down Expand Up @@ -129,4 +129,51 @@ private List<List<String>> splitToBatchesOf(int batchSize, Stream<String> stream
return batches;
}
}

@Nested
@DisplayName("Split a stream into batches partitioning to multiple lists")
class SplitAStreamPartitioningToLists {

    @Test
    void shouldSplitIntoTwoFullBatches() {
        assertThat(SplitIntoBatches.toListsOfSize(2, Stream.of("A", "B", "C", "D")))
                .containsExactly(List.of("A", "B"), List.of("C", "D"));
    }

    @Test
    void shouldSplitIntoOneFullBatchAndOnePartialBatchLeftOver() {
        assertThat(SplitIntoBatches.toListsOfSize(2, Stream.of("A", "B", "C")))
                .containsExactly(List.of("A", "B"), List.of("C"));
    }

    @Test
    void shouldSplitIntoOneFullBatch() {
        assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of("A", "B", "C")))
                .containsExactly(List.of("A", "B", "C"));
    }

    @Test
    void shouldSplitIntoOnePartialBatch() {
        assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of("A", "B")))
                .containsExactly(List.of("A", "B"));
    }

    @Test
    void shouldSplitEmptyStreamToNoBatches() {
        assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of()))
                .isEmpty();
    }

    @Test
    void shouldFailWithBatchSizeLowerThanOne() {
        assertThatThrownBy(() -> SplitIntoBatches.toListsOfSize(0, Stream.of("A", "B")))
                .isInstanceOf(IllegalArgumentException.class);
    }

    @Test
    void shouldFailWithParallelStream() {
        // A valid batch size is used so that the batch size check passes and the failure can only come from the
        // parallel stream check. The message is asserted to pin down which check rejected the call.
        assertThatThrownBy(() -> SplitIntoBatches.toListsOfSize(2, Stream.of("A", "B").parallel()))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("Cannot split parallel stream");
    }
}
}

0 comments on commit 09698a9

Please sign in to comment.