diff --git a/amazon-kinesis-client/pom.xml b/amazon-kinesis-client/pom.xml
index 551fe0a40..09c0ace38 100644
--- a/amazon-kinesis-client/pom.xml
+++ b/amazon-kinesis-client/pom.xml
@@ -122,6 +122,13 @@
+
+ org.awaitility
+ awaitility
+ 3.0.0
+ test
+
+
junit
junit
diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
index 42f88b12f..82acbf5e1 100644
--- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
+++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java
@@ -32,7 +32,9 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -45,6 +47,7 @@
import java.util.Optional;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
@@ -53,8 +56,10 @@
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
+import org.awaitility.Awaitility;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
@@ -62,7 +67,9 @@
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.reactivestreams.Subscriber;
@@ -148,6 +155,7 @@ public class ShardConsumerTest {
@Before
public void before() {
+ MockitoAnnotations.initMocks(this);
shardInfo = new ShardInfo(shardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ThreadFactory factory = new ThreadFactoryBuilder().setNameFormat("test-" + testName.getMethodName() + "-%04d")
.setDaemon(true).build();
@@ -848,6 +856,106 @@ public void testLongRunningTasks() throws Exception {
verifyNoMoreInteractions(taskExecutionListener);
}
+ @Test
+ public void testEmptyShardProcessingRaceCondition() throws Exception {
+ RecordsPublisher mockPublisher = mock(RecordsPublisher.class);
+ ExecutorService mockExecutor = mock(ExecutorService.class);
+ ConsumerState mockState = mock(ConsumerState.class);
+ ShardConsumer consumer = new ShardConsumer(mockPublisher, mockExecutor, shardInfo, Optional.of(1L),
+ shardConsumerArgument, mockState, Function.identity(), 1, taskExecutionListener, 0);
+
+ when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
+ when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
+ ConsumerTask mockTask = mock(ConsumerTask.class);
+ when(mockState.createTask(any(), any(), any())).thenReturn(mockTask);
+ when(mockTask.call()).thenReturn(new TaskResult(false));
+
+ // Invoke async processing of blocked on parent task
+ consumer.executeLifecycle();
+ ArgumentCaptor taskToExecute = ArgumentCaptor.forClass(Runnable.class);
+ verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
+ taskToExecute.getValue().run();
+ reset(mockExecutor);
+
+ // move to initializing state and
+ // Invoke async processing of initialize state
+ when(mockState.successTransition()).thenReturn(mockState);
+ when(mockState.state()).thenReturn(ShardConsumerState.INITIALIZING);
+ when(mockState.taskType()).thenReturn(TaskType.INITIALIZE);
+ consumer.executeLifecycle();
+ verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
+ taskToExecute.getValue().run();
+
+ // Move to processing state
+ // and complete initialization future successfully
+ when(mockState.state()).thenReturn(ShardConsumerState.PROCESSING);
+ consumer.executeLifecycle();
+
+ // Simulate the race where
+ // scheduler invokes executeLifecycle which performs Publisher.subscribe(subscriber)
+ // on recordProcessor thread
+ // but before scheduler thread finishes initialization, handleInput is invoked
+ // on record processor thread.
+
+ // Since ShardConsumer creates its own instance of subscriber that cannot be mocked
+ // this test sequence will appear a little odd.
+ // In order to control the order in which execution occurs, lets first invoke
+ // handleInput, although this will never happen, since there isn't a way
+ // to control the precise timing of the thread execution, this is the best way
+ CountDownLatch processTaskLatch = new CountDownLatch(1);
+ new Thread(() -> {
+ reset(mockState);
+ when(mockState.taskType()).thenReturn(TaskType.PROCESS);
+ ConsumerTask mockProcessTask = mock(ConsumerTask.class);
+ when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask);
+ CountDownLatch waitForSubscribeLatch = new CountDownLatch(1);
+ when(mockProcessTask.call()).then(input -> {
+ // first we want to wait for subscribe to be called,
+ // but we cannot control the timing, so wait for 10 seconds
+ // to let the main thread invoke executeLifecyle which
+ // will perform subscribe
+ processTaskLatch.countDown();
+ log.info("Waiting for countdown latch");
+ waitForSubscribeLatch.await(10, TimeUnit.SECONDS);
+ log.info("Waiting for countdown latch - DONE");
+ // then return shard end result
+ return new TaskResult(true);
+ });
+ Subscription mockSubscription = mock(Subscription.class);
+ consumer.handleInput(ProcessRecordsInput.builder().isAtShardEnd(true).build(), mockSubscription);
+ }).start();
+
+ processTaskLatch.await();
+
+ // now invoke lifecycle which should invoke subscribe
+ // but since we cannot countdown the latch, the latch will timeout
+ // meanwhile if scheduler tries to acquire the ShardConsumer lock it will
+ // be blocked during initialization processing. Thereby creating the
+ // race condition we want.
+ reset(mockState);
+ AtomicBoolean successTransitionCalled = new AtomicBoolean(false);
+ when(mockState.successTransition()).then(input -> {
+ successTransitionCalled.set(true);
+ return mockState;
+ });
+ AtomicBoolean shutdownTransitionCalled = new AtomicBoolean(false);
+ when(mockState.shutdownTransition(any())).then(input -> {
+ shutdownTransitionCalled.set(true);
+ return mockState;
+ });
+ when(mockState.state()).then(input -> {
+ if (successTransitionCalled.get() && shutdownTransitionCalled.get()) {
+ return ShardConsumerState.SHUTTING_DOWN;
+ }
+ return ShardConsumerState.PROCESSING;
+ });
+ consumer.executeLifecycle();
+ // initialization should be done by now, make sure shard consumer did not
+ // perform shutdown processing yet.
+ verify(mockState, times(0)).shutdownTransition(any());
+ }
+
+
private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) {
mockSuccessfulShutdown(taskCallBarrier, null);
}