Skip to content

Commit ddde8bb

Browse files
committed
Fix batching in parallel
1 parent 80c3276 commit ddde8bb

File tree

4 files changed

+36
-35
lines changed

4 files changed

+36
-35
lines changed

java/core/src/main/java/sleeper/core/util/SplitIntoBatches.java

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,30 +76,23 @@ public static <T> void reusingListOfSize(int batchSize, Stream<T> items, Consume
7676
}
7777

7878
/**
79-
* Splits a stream into lists of a given batch size. A new list will be created for each batch. Does not support
80-
* parallel streams.
79+
* Performs an operation in parallel on each batch of a given size. A new list will be created for each batch.
8180
*
82-
* @param <T> the item type
83-
* @param batchSize the number of items to process in a batch
84-
* @param items a stream of items to split into batches
85-
* @return a stream of batches of the given size
81+
* @param <T> the item type
82+
* @param batchSize the number of items to process in a batch
83+
* @param items a stream of items to split into batches
84+
* @param operation an operation to perform on a batch of items
8685
*/
87-
public static <T> Stream<List<T>> toListsOfSize(int batchSize, Stream<T> items) {
86+
public static <T> void inParallelBatchesOf(int batchSize, Stream<T> items, Consumer<List<T>> operation) {
8887
if (batchSize < 1) {
8988
throw new IllegalArgumentException("Batch size must be at least 1, found " + batchSize);
9089
}
91-
if (items.isParallel()) {
92-
throw new IllegalArgumentException("Cannot split parallel stream");
93-
}
9490
StreamBatcher<T> batcher = new StreamBatcher<>(batchSize);
95-
Stream<List<T>> fullBatches = items.flatMap(item -> {
91+
items.sequential().flatMap(item -> {
9692
batcher.add(item);
9793
return batcher.takeBatchIfFull().stream();
98-
});
99-
Stream<List<T>> remainingPartialBatch = Stream.generate(
100-
() -> batcher.takeBatchIfNotEmpty())
101-
.limit(1).flatMap(Optional::stream);
102-
return Stream.concat(fullBatches, remainingPartialBatch);
94+
}).parallel().forEach(operation);
95+
batcher.takeBatchIfNotEmpty().ifPresent(operation);
10396
}
10497

10598
/**

java/core/src/test/java/sleeper/core/util/SplitIntoBatchesTest.java

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
import org.junit.jupiter.api.Test;
2222

2323
import java.util.ArrayList;
24+
import java.util.Collection;
2425
import java.util.List;
26+
import java.util.Map;
27+
import java.util.UUID;
28+
import java.util.concurrent.ConcurrentHashMap;
2529
import java.util.function.Consumer;
2630
import java.util.stream.Stream;
2731

@@ -131,49 +135,54 @@ private List<List<String>> splitToBatchesOf(int batchSize, Stream<String> stream
131135
}
132136

133137
@Nested
134-
@DisplayName("Split a stream into batches partitioning to multiple lists")
135-
class SplitAStreamPartitioningToLists {
138+
@DisplayName("Process streamed batches in parallel")
139+
class ParallelBatches {
136140

137141
@Test
138142
void shouldSplitIntoTwoFullBatches() {
139-
assertThat(SplitIntoBatches.toListsOfSize(2, Stream.of("A", "B", "C", "D")))
140-
.containsExactly(List.of("A", "B"), List.of("C", "D"));
143+
assertThat(consumeParallelBatchesOf(2, Stream.of("A", "B", "C", "D")))
144+
.containsExactlyInAnyOrder("A", "B", "C", "D");
141145
}
142146

143147
@Test
144148
void shouldSplitIntoOneFullBatchAndOnePartialBatchLeftOver() {
145-
assertThat(SplitIntoBatches.toListsOfSize(2, Stream.of("A", "B", "C")))
146-
.containsExactly(List.of("A", "B"), List.of("C"));
149+
assertThat(consumeParallelBatchesOf(2, Stream.of("A", "B", "C")))
150+
.containsExactlyInAnyOrder("A", "B", "C");
147151
}
148152

149153
@Test
150154
void shouldSplitIntoOneFullBatch() {
151-
assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of("A", "B", "C")))
152-
.containsExactly(List.of("A", "B", "C"));
155+
assertThat(consumeParallelBatchesOf(3, Stream.of("A", "B", "C")))
156+
.containsExactlyInAnyOrder("A", "B", "C");
153157
}
154158

155159
@Test
156160
void shouldSplitIntoOnePartialBatch() {
157-
assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of("A", "B")))
158-
.containsExactly(List.of("A", "B"));
161+
assertThat(consumeParallelBatchesOf(3, Stream.of("A", "B")))
162+
.containsExactlyInAnyOrder("A", "B");
159163
}
160164

161165
@Test
162166
void shouldSplitEmptyStreamToNoBatches() {
163-
assertThat(SplitIntoBatches.toListsOfSize(3, Stream.of()))
167+
assertThat(consumeParallelBatchesOf(3, Stream.of()))
164168
.isEmpty();
165169
}
166170

167171
@Test
168172
void shouldFailWithBatchSizeLowerThanOne() {
169-
assertThatThrownBy(() -> SplitIntoBatches.toListsOfSize(0, Stream.of("A", "B")))
173+
Consumer<List<String>> notInvoked = batch -> {
174+
throw new IllegalStateException("Did not expect operation to be called");
175+
};
176+
assertThatThrownBy(() -> SplitIntoBatches.inParallelBatchesOf(0, Stream.of("A", "B"), notInvoked))
170177
.isInstanceOf(IllegalArgumentException.class);
171178
}
172179

173-
@Test
174-
void shouldFailWithParallelStream() {
175-
assertThatThrownBy(() -> SplitIntoBatches.toListsOfSize(0, Stream.of("A", "B").parallel()))
176-
.isInstanceOf(IllegalArgumentException.class);
180+
private Collection<String> consumeParallelBatchesOf(int batchSize, Stream<String> stream) {
181+
Map<String, String> output = new ConcurrentHashMap<>();
182+
SplitIntoBatches.inParallelBatchesOf(batchSize, stream, batch -> {
183+
batch.forEach(value -> output.put(UUID.randomUUID().toString(), value));
184+
});
185+
return output.values();
177186
}
178187
}
179188
}

java/system-test/system-test-drivers/src/main/java/sleeper/systemtest/drivers/statestore/AwsStateStoreCommitterDriver.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ public AwsStateStoreCommitterDriver(SystemTestInstanceContext instance, AmazonSQ
4343

4444
@Override
4545
public void sendCommitMessages(Stream<StateStoreCommitMessage> messages) {
46-
SplitIntoBatches.toListsOfSize(10, messages)
47-
.parallel().forEach(this::sendMessageBatch);
46+
SplitIntoBatches.inParallelBatchesOf(10, messages, this::sendMessageBatch);
4847
}
4948

5049
private void sendMessageBatch(List<StateStoreCommitMessage> batch) {

java/system-test/system-test-drivers/src/test/java/sleeper/systemtest/drivers/statestore/AwsStateStoreCommitterDriverIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ void shouldSendMoreCommitsThanBatchSize(SleeperSystemTest sleeper) {
9393
String tableId = sleeper.tableProperties().get(TABLE_ID);
9494
assertThat(receiveCommitRequestsForBatches(sleeper, 2))
9595
.extracting(this::getMessageGroupId, this::readCommitRequest)
96-
.containsExactlyElementsOf(files.stream().map(file -> tuple(tableId,
96+
.containsExactlyInAnyOrderElementsOf(files.stream().map(file -> tuple(tableId,
9797
StateStoreCommitRequest.forIngestAddFiles(IngestAddFilesCommitRequest.builder()
9898
.tableId(tableId)
9999
.fileReferences(List.of(file))

0 commit comments

Comments
 (0)