diff --git a/amazon-kinesis-client/pom.xml b/amazon-kinesis-client/pom.xml index 551fe0a40..09c0ace38 100644 --- a/amazon-kinesis-client/pom.xml +++ b/amazon-kinesis-client/pom.xml @@ -122,6 +122,13 @@ + + org.awaitility + awaitility + 3.0.0 + test + + junit junit diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java index 42f88b12f..82acbf5e1 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/lifecycle/ShardConsumerTest.java @@ -32,7 +32,9 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -45,6 +47,7 @@ import java.util.Optional; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -53,8 +56,10 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import org.awaitility.Awaitility; import org.junit.After; import org.junit.Before; import org.junit.Ignore; @@ -62,7 +67,9 @@ import org.junit.Test; import org.junit.rules.TestName; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import org.reactivestreams.Subscriber; @@ -148,6 +155,7 @@ public class ShardConsumerTest { @Before public void before() { + MockitoAnnotations.initMocks(this); shardInfo = new ShardInfo(shardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON); ThreadFactory factory = new ThreadFactoryBuilder().setNameFormat("test-" + testName.getMethodName() + "-%04d") .setDaemon(true).build(); @@ -848,6 +856,106 @@ public void testLongRunningTasks() throws Exception { verifyNoMoreInteractions(taskExecutionListener); } + @Test + public void testEmptyShardProcessingRaceCondition() throws Exception { + RecordsPublisher mockPublisher = mock(RecordsPublisher.class); + ExecutorService mockExecutor = mock(ExecutorService.class); + ConsumerState mockState = mock(ConsumerState.class); + ShardConsumer consumer = new ShardConsumer(mockPublisher, mockExecutor, shardInfo, Optional.of(1L), + shardConsumerArgument, mockState, Function.identity(), 1, taskExecutionListener, 0); + + when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS); + when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS); + ConsumerTask mockTask = mock(ConsumerTask.class); + when(mockState.createTask(any(), any(), any())).thenReturn(mockTask); + when(mockTask.call()).thenReturn(new TaskResult(false)); + + // Invoke async processing of blocked on parent task + consumer.executeLifecycle(); + ArgumentCaptor taskToExecute = ArgumentCaptor.forClass(Runnable.class); + verify(mockExecutor, timeout(100)).execute(taskToExecute.capture()); + taskToExecute.getValue().run(); + reset(mockExecutor); + + // move to initializing state and + // Invoke async processing of initialize state + when(mockState.successTransition()).thenReturn(mockState); + when(mockState.state()).thenReturn(ShardConsumerState.INITIALIZING); + when(mockState.taskType()).thenReturn(TaskType.INITIALIZE); + consumer.executeLifecycle(); + verify(mockExecutor, timeout(100)).execute(taskToExecute.capture()); + taskToExecute.getValue().run(); + + // Move to processing state + // and complete initialization future successfully + when(mockState.state()).thenReturn(ShardConsumerState.PROCESSING); + consumer.executeLifecycle(); + + // Simulate the race where + // scheduler invokes executeLifecycle which performs Publisher.subscribe(subscriber) + // on recordProcessor thread + // but before scheduler thread finishes initialization, handleInput is invoked + // on record processor thread. + + // Since ShardConsumer creates its own instance of subscriber that cannot be mocked + // this test sequence will appear a little odd. + // In order to control the order in which execution occurs, lets first invoke + // handleInput, although this will never happen, since there isn't a way + // to control the precise timing of the thread execution, this is the best way + CountDownLatch processTaskLatch = new CountDownLatch(1); + new Thread(() -> { + reset(mockState); + when(mockState.taskType()).thenReturn(TaskType.PROCESS); + ConsumerTask mockProcessTask = mock(ConsumerTask.class); + when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask); + CountDownLatch waitForSubscribeLatch = new CountDownLatch(1); + when(mockProcessTask.call()).then(input -> { + // first we want to wait for subscribe to be called, + // but we cannot control the timing, so wait for 10 seconds + // to let the main thread invoke executeLifecyle which + // will perform subscribe + processTaskLatch.countDown(); + log.info("Waiting for countdown latch"); + waitForSubscribeLatch.await(10, TimeUnit.SECONDS); + log.info("Waiting for countdown latch - DONE"); + // then return shard end result + return new TaskResult(true); + }); + Subscription mockSubscription = mock(Subscription.class); + consumer.handleInput(ProcessRecordsInput.builder().isAtShardEnd(true).build(), mockSubscription); + }).start(); + + processTaskLatch.await(); + + // now invoke lifecycle which should invoke subscribe + // but since we cannot countdown the latch, the latch will timeout + // meanwhile if scheduler tries to acquire the ShardConsumer lock it will + // be blocked during initialization processing. Thereby creating the + // race condition we want. + reset(mockState); + AtomicBoolean successTransitionCalled = new AtomicBoolean(false); + when(mockState.successTransition()).then(input -> { + successTransitionCalled.set(true); + return mockState; + }); + AtomicBoolean shutdownTransitionCalled = new AtomicBoolean(false); + when(mockState.shutdownTransition(any())).then(input -> { + shutdownTransitionCalled.set(true); + return mockState; + }); + when(mockState.state()).then(input -> { + if (successTransitionCalled.get() && shutdownTransitionCalled.get()) { + return ShardConsumerState.SHUTTING_DOWN; + } + return ShardConsumerState.PROCESSING; + }); + consumer.executeLifecycle(); + // initialization should be done by now, make sure shard consumer did not + // perform shutdown processing yet. + verify(mockState, times(0)).shutdownTransition(any()); + } + + private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) { mockSuccessfulShutdown(taskCallBarrier, null); }