Skip to content

Commit

Permalink
Fix batching in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
patchwork01 committed Aug 12, 2024
1 parent 80c3276 commit ddde8bb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 35 deletions.
25 changes: 9 additions & 16 deletions java/core/src/main/java/sleeper/core/util/SplitIntoBatches.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,23 @@ public static <T> void reusingListOfSize(int batchSize, Stream<T> items, Consume
}

/**
* Splits a stream into lists of a given batch size. A new list will be created for each batch. Does not support
* parallel streams.
* Splits a stream into batches of a given size and performs an operation on each batch, with batches processed in parallel and in no guaranteed order. A new list will be created for each batch.
*
* @param <T> the item type
* @param batchSize the number of items to process in a batch
* @param items a stream of items to split into batches
* @return a stream of batches of the given size
* @param <T> the item type
* @param batchSize the number of items to process in a batch
* @param items a stream of items to split into batches
* @param operation an operation to perform on a batch of items
*/
public static <T> Stream<List<T>> toListsOfSize(int batchSize, Stream<T> items) {
public static <T> void inParallelBatchesOf(int batchSize, Stream<T> items, Consumer<List<T>> operation) {
if (batchSize < 1) {
throw new IllegalArgumentException("Batch size must be at least 1, found " + batchSize);
}
if (items.isParallel()) {
throw new IllegalArgumentException("Cannot split parallel stream");
}
StreamBatcher<T> batcher = new StreamBatcher<>(batchSize);
Stream<List<T>> fullBatches = items.flatMap(item -> {
items.sequential().flatMap(item -> {
batcher.add(item);
return batcher.takeBatchIfFull().stream();
});
Stream<List<T>> remainingPartialBatch = Stream.generate(
() -> batcher.takeBatchIfNotEmpty())
.limit(1).flatMap(Optional::stream);
return Stream.concat(fullBatches, remainingPartialBatch);
}).parallel().forEach(operation);
batcher.takeBatchIfNotEmpty().ifPresent(operation);
}

/**
Expand Down
41 changes: 25 additions & 16 deletions java/core/src/test/java/sleeper/core/util/SplitIntoBatchesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Stream;

Expand Down Expand Up @@ -131,49 +135,54 @@ private List<List<String>> splitToBatchesOf(int batchSize, Stream<String> stream
}

@Nested
@DisplayName("Split a stream into batches partitioning to multiple lists")
class SplitAStreamPartitioningToLists {
@DisplayName("Process streamed batches in parallel")
class ParallelBatches {

@Test
void shouldSplitIntoTwoFullBatches() {
assertThat(SplitIntoBatches.toListsOfSize(2, Stream.of("A", "B", "C", "D")))
.containsExactly(List.of("A", "B"), List.of("C", "D"));
assertThat(consumeParallelBatchesOf(2, Stream.of("A", "B", "C", "D")))
.containsExactlyInAnyOrder("A", "B", "C", "D");
}

@Test
void shouldSplitIntoOneFullBatchAndOnePartialBatchLeftOver() {
assertThat(SplitIntoBatches.toListsOfSize(2, Stream.of("A", "B", "C")))
.containsExactly(List.of("A", "B"), List.of("C"));
assertThat(consumeParallelBatchesOf(2, Stream.of("A", "B", "C")))
.containsExactlyInAnyOrder("A", "B", "C");
}

@Test
void shouldSplitIntoOneFullBatch() {
assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of("A", "B", "C")))
.containsExactly(List.of("A", "B", "C"));
assertThat(consumeParallelBatchesOf(3, Stream.of("A", "B", "C")))
.containsExactlyInAnyOrder("A", "B", "C");
}

@Test
void shouldSplitIntoOnePartialBatch() {
assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of("A", "B")))
.containsExactly(List.of("A", "B"));
assertThat(consumeParallelBatchesOf(3, Stream.of("A", "B")))
.containsExactlyInAnyOrder("A", "B");
}

@Test
void shouldSplitEmptyStreamToNoBatches() {
assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of()))
assertThat(consumeParallelBatchesOf(3, Stream.of()))
.isEmpty();
}

@Test
void shouldFailWithBatchSizeLowerThanOne() {
assertThatThrownBy(() -> SplitIntoBatches.toListsOfSize(0, Stream.of("A", "B")))
Consumer<List<String>> notInvoked = batch -> {
throw new IllegalStateException("Did not expect operation to be called");
};
assertThatThrownBy(() -> SplitIntoBatches.inParallelBatchesOf(0, Stream.of("A", "B"), notInvoked))
.isInstanceOf(IllegalArgumentException.class);
}

@Test
void shouldFailWithParallelStream() {
assertThatThrownBy(() -> SplitIntoBatches.toListsOfSize(0, Stream.of("A", "B").parallel()))
.isInstanceOf(IllegalArgumentException.class);
private Collection<String> consumeParallelBatchesOf(int batchSize, Stream<String> stream) {
Map<String, String> output = new ConcurrentHashMap<>();
SplitIntoBatches.inParallelBatchesOf(batchSize, stream, batch -> {
batch.forEach(value -> output.put(UUID.randomUUID().toString(), value));
});
return output.values();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ public AwsStateStoreCommitterDriver(SystemTestInstanceContext instance, AmazonSQ

@Override
public void sendCommitMessages(Stream<StateStoreCommitMessage> messages) {
SplitIntoBatches.toListsOfSize(10, messages)
.parallel().forEach(this::sendMessageBatch);
SplitIntoBatches.inParallelBatchesOf(10, messages, this::sendMessageBatch);
}

private void sendMessageBatch(List<StateStoreCommitMessage> batch) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void shouldSendMoreCommitsThanBatchSize(SleeperSystemTest sleeper) {
String tableId = sleeper.tableProperties().get(TABLE_ID);
assertThat(receiveCommitRequestsForBatches(sleeper, 2))
.extracting(this::getMessageGroupId, this::readCommitRequest)
.containsExactlyElementsOf(files.stream().map(file -> tuple(tableId,
.containsExactlyInAnyOrderElementsOf(files.stream().map(file -> tuple(tableId,
StateStoreCommitRequest.forIngestAddFiles(IngestAddFilesCommitRequest.builder()
.tableId(tableId)
.fileReferences(List.of(file))
Expand Down

0 comments on commit ddde8bb

Please sign in to comment.