From b8229e051192ce9bd3da306fb42a505546374fcc Mon Sep 17 00:00:00 2001
From: Martin Trieu
Date: Tue, 19 Sep 2023 12:07:02 -0700
Subject: [PATCH] reformat StreamingDataflowWorker, apply CL comments.

---
 .../worker/StreamingDataflowWorker.java       | 304 ++++++++----------
 .../worker/streaming/ComputationState.java    |  14 +-
 .../streaming/WeightedBoundedQueue.java       |   4 +-
 .../streaming/WeightBoundedQueueTest.java     | 130 ++++++++
 4 files changed, 279 insertions(+), 173 deletions(-)
 create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java

diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index e5c0eaf6973df..d6fc08e38ac98 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -40,6 +40,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Random;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
@@ -92,6 +93,7 @@
 import org.apache.beam.runners.dataflow.worker.streaming.StageInfo;
 import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
 import org.apache.beam.runners.dataflow.worker.streaming.Work;
+import org.apache.beam.runners.dataflow.worker.streaming.Work.State;
 import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
 import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
 import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter;
@@ -121,7 +123,6 @@
 import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Optional;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
@@ -148,8 +149,29 @@
 })
 public class StreamingDataflowWorker {
-  private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class);
+  // TODO(https://github.com/apache/beam/issues/19632): Update throttling counters to use generic
+  // throttling-msecs metric.
+  public static final MetricName BIGQUERY_STREAMING_INSERT_THROTTLE_TIME =
+      MetricName.named(
+          "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$DatasetServiceImpl",
+          "throttling-msecs");
+  // Maximum number of threads for processing. Currently each thread processes one key at a time.
+  static final int MAX_PROCESSING_THREADS = 300;
+  static final long THREAD_EXPIRATION_TIME_SEC = 60;
+  static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20;
+  static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
+  static final int NUM_COMMIT_STREAMS = 1;
+  static final int GET_WORK_STREAM_TIMEOUT_MINUTES = 3;
+  static final Duration COMMIT_STREAM_TIMEOUT = Duration.standardMinutes(1);
+  /**
+   * Sinks are marked 'full' in {@link StreamingModeExecutionContext} once the amount of data sinked
+   * (across all the sinks, if there are more than one) reaches this limit. This serves as hint for
+   * readers to stop producing more. This can be disabled with 'disable_limiting_bundle_sink_bytes'
+   * experiment.
+   */
+  static final int MAX_SINK_BYTES = 10_000_000;
+  private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorker.class);
   /** The idGenerator to generate unique id globally. */
   private static final IdGenerator idGenerator = IdGenerators.decrementingLongs();
   /**
@@ -158,7 +180,6 @@ public class StreamingDataflowWorker {
    */
   private static final Function<MapTask, MapTask> fixMultiOutputInfos =
       new FixMultiOutputInfosOnParDoInstructions(idGenerator);
-
   /**
    * Function which converts map tasks to their network representation for execution.
    *
@@ -170,106 +191,35 @@ public class StreamingDataflowWorker {
   private static final Function<MapTask, MutableNetwork<Node, Edge>> mapTaskToBaseNetwork =
       new MapTaskToNetworkFunction(idGenerator);
-
-  private static Random clientIdGenerator = new Random();
-
-  // Maximum number of threads for processing. Currently each thread processes one key at a time.
-  static final int MAX_PROCESSING_THREADS = 300;
-  static final long THREAD_EXPIRATION_TIME_SEC = 60;
-  static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20;
-  static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB
-  static final int NUM_COMMIT_STREAMS = 1;
-  static final int GET_WORK_STREAM_TIMEOUT_MINUTES = 3;
-  static final Duration COMMIT_STREAM_TIMEOUT = Duration.standardMinutes(1);
-
   private static final int DEFAULT_STATUS_PORT = 8081;
-
   // Maximum size of the result of a GetWork request.
   private static final long MAX_GET_WORK_FETCH_BYTES = 64L << 20; // 64m
-
   // Reserved ID for counter updates.
   // Matches kWindmillCounterUpdate in workflow_worker_service_multi_hubs.cc.
   private static final String WINDMILL_COUNTER_UPDATE_WORK_ID = "3";
-
   /** Maximum number of failure stacktraces to report in each update sent to backend. */
   private static final int MAX_FAILURES_TO_REPORT_IN_UPDATE = 1000;
-
-  // TODO(https://github.com/apache/beam/issues/19632): Update throttling counters to use generic
-  // throttling-msecs metric.
-  public static final MetricName BIGQUERY_STREAMING_INSERT_THROTTLE_TIME =
-      MetricName.named(
-          "org.apache.beam.sdk.io.gcp.bigquery.BigQueryServicesImpl$DatasetServiceImpl",
-          "throttling-msecs");
-
   private static final Duration MAX_LOCAL_PROCESSING_RETRY_DURATION = Duration.standardMinutes(5);
-
-  /** Returns whether an exception was caused by a {@link OutOfMemoryError}. */
-  private static boolean isOutOfMemoryError(Throwable t) {
-    while (t != null) {
-      if (t instanceof OutOfMemoryError) {
-        return true;
-      }
-      t = t.getCause();
-    }
-    return false;
-  }
-
-  private static MapTask parseMapTask(String input) throws IOException {
-    return Transport.getJsonFactory().fromString(input, MapTask.class);
-  }
-
-  public static void main(String[] args) throws Exception {
-    JvmInitializers.runOnStartup();
-
-    DataflowWorkerHarnessHelper.initializeLogging(StreamingDataflowWorker.class);
-    DataflowWorkerHarnessOptions options =
-        DataflowWorkerHarnessHelper.initializeGlobalStateAndPipelineOptions(
-            StreamingDataflowWorker.class);
-    DataflowWorkerHarnessHelper.configureLogging(options);
-    checkArgument(
-        options.isStreaming(),
-        "%s instantiated with options indicating batch use",
-        StreamingDataflowWorker.class.getName());
-
-    checkArgument(
-        !DataflowRunner.hasExperiment(options, "beam_fn_api"),
-        "%s cannot be main() class with beam_fn_api enabled",
-        StreamingDataflowWorker.class.getSimpleName());
-
-    StreamingDataflowWorker worker =
-        StreamingDataflowWorker.fromDataflowWorkerHarnessOptions(options);
-
-    // Use the MetricsLogger container which is used by BigQueryIO to periodically log process-wide
-    // metrics.
-    MetricsEnvironment.setProcessWideContainer(new MetricsLogger(null));
-
-    JvmInitializers.runBeforeProcessing(options);
-    worker.startStatusPages();
-    worker.start();
-  }
-
+  private static final Random clientIdGenerator = new Random();
+  final WindmillStateCache stateCache;
   // Maps from computation ids to per-computation state.
   private final ConcurrentMap<String, ComputationState> computationMap = new ConcurrentHashMap<>();
   private final WeightedBoundedQueue<Commit> commitQueue =
-      new WeightedBoundedQueue<>(
+      WeightedBoundedQueue.create(
           MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()));
-
   // Cache of tokens to commit callbacks.
   // Using Cache with time eviction policy helps us to prevent memory leak when callback ids are
   // discarded by Dataflow service and calling commitCallback is best-effort.
   private final Cache<Long, Runnable> commitCallbacks =
       CacheBuilder.newBuilder().expireAfterWrite(5L, TimeUnit.MINUTES).build();
-
   // Map of user state names to system state names.
   // TODO(drieber): obsolete stateNameMap. Use transformUserNameToStateFamily in
   // ComputationState instead.
   private final ConcurrentMap<String, String> stateNameMap = new ConcurrentHashMap<>();
   private final ConcurrentMap<String, String> systemNameToComputationIdMap =
       new ConcurrentHashMap<>();
-
-  final WindmillStateCache stateCache;
-  private final ThreadFactory threadFactory;
-  private DataflowMapTaskExecutorFactory mapTaskExecutorFactory;
   private final BoundedQueueExecutor workUnitExecutor;
   private final WindmillServerStub windmillServer;
   private final Thread dispatchThread;
@@ -280,16 +230,13 @@ public static void main(String[] args) throws Exception {
   private final StreamingDataflowWorkerOptions options;
   private final boolean windmillServiceEnabled;
   private final long clientId;
-
   private final MetricTrackingWindmillServerStub metricTrackingWindmillServer;
   private final CounterSet pendingDeltaCounters = new CounterSet();
   private final CounterSet pendingCumulativeCounters = new CounterSet();
   private final java.util.concurrent.ConcurrentLinkedQueue<CounterUpdate> pendingMonitoringInfos =
       new ConcurrentLinkedQueue<>();
-
   // Map from stage name to StageInfo containing metrics container registry and per stage counters.
   private final ConcurrentMap<String, StageInfo> stageInfoMap = new ConcurrentHashMap();
-
   // Built-in delta counters.
   private final Counter<Long, Long> windmillShuffleBytesRead;
   private final Counter<Long, Long> windmillStateBytesRead;
@@ -301,69 +248,35 @@ public static void main(String[] args) throws Exception {
   private final Counter<Long, Long> timeAtMaxActiveThreads;
   private final Counter<Long, Long> windmillMaxObservedWorkItemCommitBytes;
   private final Counter<Long, Long> memoryThrashing;
-  private ScheduledExecutorService refreshWorkTimer;
-  private ScheduledExecutorService statusPageTimer;
-
   private final boolean publishCounters;
-  private ScheduledExecutorService globalWorkerUpdatesTimer;
-  private int retryLocallyDelayMs = 10000;
-
-  // Periodically fires a global config request to dataflow service. Only used when windmill service
-  // is enabled.
-  private ScheduledExecutorService globalConfigRefreshTimer;
-
   private final MemoryMonitor memoryMonitor;
   private final Thread memoryMonitorThread;
-
   private final WorkerStatusPages statusPages;
-  // Periodic sender of debug information to the debug capture service.
-  private DebugCapture.Manager debugCaptureManager = null;
-
   // Limit on bytes sinked (committed) in a work item.
   private final long maxSinkBytes; // = MAX_SINK_BYTES unless disabled in options.
-  // Possibly overridden by streaming engine config.
-  private int maxWorkItemCommitBytes = Integer.MAX_VALUE;
-
   private final EvictingQueue<String> pendingFailuresToReport =
-      EvictingQueue.create(MAX_FAILURES_TO_REPORT_IN_UPDATE);
-
+      EvictingQueue.create(MAX_FAILURES_TO_REPORT_IN_UPDATE);
   private final ReaderCache readerCache;
-
   private final WorkUnitClient workUnitClient;
   private final CompletableFuture<Void> isDoneFuture;
   private final Function<MapTask, MutableNetwork<Node, Edge>> mapTaskToNetwork;
-
-  /**
-   * Sinks are marked 'full' in {@link StreamingModeExecutionContext} once the amount of data sinked
-   * (across all the sinks, if there are more than one) reaches this limit. This serves as hint for
-   * readers to stop producing more. This can be disabled with 'disable_limiting_bundle_sink_bytes'
-   * experiment.
-   */
-  static final int MAX_SINK_BYTES = 10_000_000;
-
   private final ReaderRegistry readerRegistry = ReaderRegistry.defaultRegistry();
   private final SinkRegistry sinkRegistry = SinkRegistry.defaultRegistry();
-
-  private HotKeyLogger hotKeyLogger;
-
   private final Supplier<Instant> clock;
   private final Function<String, ScheduledExecutorService> executorSupplier;
-
-  public static StreamingDataflowWorker fromDataflowWorkerHarnessOptions(
-      DataflowWorkerHarnessOptions options) throws IOException {
-
-    return new StreamingDataflowWorker(
-        Collections.emptyList(),
-        IntrinsicMapTaskExecutorFactory.defaultFactory(),
-        new DataflowWorkUnitClient(options, LOG),
-        options.as(StreamingDataflowWorkerOptions.class),
-        true,
-        new HotKeyLogger(),
-        Instant::now,
-        (threadName) ->
-            Executors.newSingleThreadScheduledExecutor(
-                new ThreadFactoryBuilder().setNameFormat(threadName).build()));
-  }
+  private final DataflowMapTaskExecutorFactory mapTaskExecutorFactory;
+  private final HotKeyLogger hotKeyLogger;
+  // Periodic sender of debug information to the debug capture service.
+  private final DebugCapture.@Nullable Manager debugCaptureManager;
+  private ScheduledExecutorService refreshWorkTimer;
+  private ScheduledExecutorService statusPageTimer;
+  private ScheduledExecutorService globalWorkerUpdatesTimer;
+  private int retryLocallyDelayMs = 10000;
+  // Periodically fires a global config request to dataflow service. Only used when windmill service
+  // is enabled.
+  private ScheduledExecutorService globalConfigRefreshTimer;
+  // Possibly overridden by streaming engine config.
+  private int maxWorkItemCommitBytes = Integer.MAX_VALUE;
 
   @VisibleForTesting
   StreamingDataflowWorker(
@@ -393,6 +306,8 @@ public static StreamingDataflowWorker fromDataflowWorkerHarnessOptions(
     if (windmillServiceEnabled) {
       this.debugCaptureManager =
           new DebugCapture.Manager(options, statusPages.getDebugCapturePages());
+    } else {
+      this.debugCaptureManager = null;
     }
     this.windmillShuffleBytesRead =
         pendingDeltaCounters.longSum(
@@ -501,6 +416,71 @@ public void run() {
     LOG.debug("maxWorkItemCommitBytes: {}", maxWorkItemCommitBytes);
   }
 
+  /** Returns whether an exception was caused by a {@link OutOfMemoryError}. */
+  private static boolean isOutOfMemoryError(Throwable t) {
+    while (t != null) {
+      if (t instanceof OutOfMemoryError) {
+        return true;
+      }
+      t = t.getCause();
+    }
+    return false;
+  }
+
+  private static MapTask parseMapTask(String input) throws IOException {
+    return Transport.getJsonFactory().fromString(input, MapTask.class);
+  }
+
+  public static void main(String[] args) throws Exception {
+    JvmInitializers.runOnStartup();
+
+    DataflowWorkerHarnessHelper.initializeLogging(StreamingDataflowWorker.class);
+    DataflowWorkerHarnessOptions options =
+        DataflowWorkerHarnessHelper.initializeGlobalStateAndPipelineOptions(
+            StreamingDataflowWorker.class);
+    DataflowWorkerHarnessHelper.configureLogging(options);
+    checkArgument(
+        options.isStreaming(),
+        "%s instantiated with options indicating batch use",
+        StreamingDataflowWorker.class.getName());
+
+    checkArgument(
+        !DataflowRunner.hasExperiment(options, "beam_fn_api"),
+        "%s cannot be main() class with beam_fn_api enabled",
+        StreamingDataflowWorker.class.getSimpleName());
+
+    StreamingDataflowWorker worker =
+        StreamingDataflowWorker.fromDataflowWorkerHarnessOptions(options);
+
+    // Use the MetricsLogger container which is used by BigQueryIO to periodically log process-wide
+    // metrics.
+    MetricsEnvironment.setProcessWideContainer(new MetricsLogger(null));
+
+    JvmInitializers.runBeforeProcessing(options);
+    worker.startStatusPages();
+    worker.start();
+  }
+
+  public static StreamingDataflowWorker fromDataflowWorkerHarnessOptions(
+      DataflowWorkerHarnessOptions options) throws IOException {
+
+    return new StreamingDataflowWorker(
+        Collections.emptyList(),
+        IntrinsicMapTaskExecutorFactory.defaultFactory(),
+        new DataflowWorkUnitClient(options, LOG),
+        options.as(StreamingDataflowWorkerOptions.class),
+        true,
+        new HotKeyLogger(),
+        Instant::now,
+        (threadName) ->
+            Executors.newSingleThreadScheduledExecutor(
+                new ThreadFactoryBuilder().setNameFormat(threadName).build()));
+  }
+
+  private static void sleep(int millis) {
+    Uninterruptibles.sleepUninterruptibly(millis, TimeUnit.MILLISECONDS);
+  }
+
   private int chooseMaximumNumberOfThreads() {
     if (options.getNumberOfWorkerHarnessThreads() != 0) {
       return options.getNumberOfWorkerHarnessThreads();
     }
@@ -610,7 +590,7 @@ public void run() {
                         + options.getWorkerId()
                         + "_"
                         + page.pageName()
-                        + timestamp.toString())
+                        + timestamp)
                     .replaceAll("/", "_"));
             writer = new PrintWriter(outputFile, UTF_8.name());
             page.captureData(writer);
@@ -738,10 +718,6 @@ private synchronized void addComputation(
     }
   }
 
-  private static void sleep(int millis) {
-    Uninterruptibles.sleepUninterruptibly(millis, TimeUnit.MILLISECONDS);
-  }
-
   /**
    * If the computation is not yet known about, configuration for it will be fetched. This can still
    * return null if there is no configuration fetched for the computation.
@@ -796,7 +772,7 @@ private void dispatchLoop() {
               inputDataWatermark,
               synchronizedProcessingTime,
               workItem,
-              /*getWorkStreamLatencies=*/ Collections.emptyList());
+              /* getWorkStreamLatencies= */ Collections.emptyList());
         }
       }
     }
@@ -849,7 +825,6 @@ private void scheduleWorkItem(
         WindmillTimeUtils.windmillToHarnessWatermark(workItem.getOutputDataWatermark());
     Preconditions.checkState(
         outputDataWatermark == null || !outputDataWatermark.isAfter(inputDataWatermark));
-
     Work scheduledWork =
         Work.create(
             workItem,
@@ -862,7 +837,6 @@ private void scheduleWorkItem(
             outputDataWatermark,
             synchronizedProcessingTime,
             work));
-
     computationState.activateWork(
         ShardedKey.create(workItem.getKey(), workItem.getShardingKey()), scheduledWork);
   }
@@ -933,13 +907,13 @@ private void process(
     final Windmill.WorkItem workItem = work.getWorkItem();
     final String computationId = computationState.getComputationId();
     final ByteString key = workItem.getKey();
-    work.setState(Work.State.PROCESSING);
+    work.setState(State.PROCESSING);
     {
-      StringBuilder workIdBuilder = new StringBuilder(33);
-      workIdBuilder.append(Long.toHexString(workItem.getShardingKey()));
-      workIdBuilder.append('-');
-      workIdBuilder.append(Long.toHexString(workItem.getWorkToken()));
-      DataflowWorkerLoggingMDC.setWorkId(workIdBuilder.toString());
+      String workIdBuilder =
+          Long.toHexString(workItem.getShardingKey())
+              + '-'
+              + Long.toHexString(workItem.getWorkToken());
+      DataflowWorkerLoggingMDC.setWorkId(workIdBuilder);
     }
     DataflowWorkerLoggingMDC.setStageName(computationId);
@@ -952,7 +926,7 @@ private void process(
     callFinalizeCallbacks(workItem);
     if (workItem.getSourceState().getOnlyFinalize()) {
       outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
-      work.setState(Work.State.COMMIT_QUEUED);
+      work.setState(State.COMMIT_QUEUED);
       commitQueue.put(Commit.create(outputBuilder.build(), computationState, work));
       return;
     }
@@ -988,11 +962,13 @@ private void process(
     DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker =
         new DataflowExecutionContext.DataflowExecutionStateTracker(
             ExecutionStateSampler.instance(),
-            stageInfo.getExecutionStateRegistry().getState(
-                NameContext.forStage(mapTask.getStageName()),
-                "other",
-                null,
-                ScopedProfiler.INSTANCE.emptyScope()),
+            stageInfo
+                .getExecutionStateRegistry()
+                .getState(
+                    NameContext.forStage(mapTask.getStageName()),
+                    "other",
+                    null,
+                    ScopedProfiler.INSTANCE.emptyScope()),
             stageInfo.getDeltaCounters(),
             options,
             computationId);
@@ -1076,11 +1052,11 @@ private void process(
                 workItem.getShardingKey(),
                 workItem.getWorkToken(),
                 () -> {
-                  work.setState(Work.State.READING);
+                  work.setState(State.READING);
                   return new AutoCloseable() {
                     @Override
                     public void close() {
-                      work.setState(Work.State.PROCESSING);
+                      work.setState(State.PROCESSING);
                     }
                   };
                 });
@@ -1092,9 +1068,8 @@ public void close() {
       //
      // The coder type that will be present is:
      //     WindowedValueCoder(TimerOrElementCoder(KvCoder))
-      java.util.Optional<Coder<?>> keyCoder = executionState.keyCoder();
-      @Nullable
-      Object executionKey =
+      Optional<Coder<?>> keyCoder = executionState.keyCoder();
+      @Nullable Object executionKey =
           !keyCoder.isPresent() ? null : keyCoder.get().decode(key.newInput(), Coder.Context.OUTER);
 
       if (workItem.hasHotKeyInfo()) {
@@ -1153,7 +1128,7 @@ public void close() {
       executionState = null;
 
       // Add the output to the commit queue.
-      work.setState(Work.State.COMMIT_QUEUED);
+      work.setState(State.COMMIT_QUEUED);
       outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions());
 
       WorkItemCommitRequest commitRequest = outputBuilder.build();
@@ -1393,7 +1368,7 @@ private Commit batchCommitsToStream(CommitWorkStream commitStream) {
       Commit commit;
       try {
         if (commits < 5) {
-          commit = commitQueue.poll(10 - 2 * commits, TimeUnit.MILLISECONDS);
+          commit = commitQueue.poll(10 - 2L * commits, TimeUnit.MILLISECONDS);
         } else {
           commit = commitQueue.poll();
         }
@@ -1480,7 +1455,8 @@ private void getConfigFromWindmill(String computation) {
           addComputation(
               computationId,
               mapTask,
-              transformUserNameToStateFamilyByComputationId.get(computationId));
+              transformUserNameToStateFamilyByComputationId.getOrDefault(
+                  computationId, ImmutableMap.of()));
         } catch (IOException e) {
           LOG.warn("Parsing MapTask failed: {}", serializedMapTask);
           LOG.warn("Error: ", e);
@@ -1498,13 +1474,12 @@ private void getConfigFromWindmill(String computation) {
    * @throws IOException if the RPC fails.
    */
   private void getConfigFromDataflowService(@Nullable String computation) throws IOException {
-    Optional<WorkItem> workItem;
-    if (computation != null) {
-      workItem = workUnitClient.getStreamingConfigWorkItem(computation);
-    } else {
-      workItem = workUnitClient.getGlobalStreamingConfigWorkItem();
-    }
-    if (workItem == null || !workItem.isPresent() || workItem.get() == null) {
+    Optional<WorkItem> workItem =
+        computation != null
+            ? workUnitClient.getStreamingConfigWorkItem(computation)
+            : workUnitClient.getGlobalStreamingConfigWorkItem();
+
+    if (!workItem.isPresent()) {
       return;
     }
     StreamingConfigTask config = workItem.get().getStreamingConfigTask();
@@ -1531,7 +1506,8 @@ private void getConfigFromDataflowService(@Nullable String computation) throws IOException {
       addComputation(
           computationConfig.getComputationId(),
           mapTask,
-          computationConfig.getTransformUserNameToStateFamily());
+          Optional.ofNullable(computationConfig.getTransformUserNameToStateFamily())
+              .orElseGet(ImmutableMap::of));
     }
   }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
index a81d7273bfe4d..74d54b026b9ba 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
@@ -68,16 +68,13 @@ public ComputationState(
       BoundedQueueExecutor executor,
       Map<String, String> transformUserNameToStateFamily,
       WindmillStateCache.ForComputation computationStateCache) {
+    Preconditions.checkNotNull(mapTask.getStageName());
+    Preconditions.checkNotNull(mapTask.getSystemName());
     this.computationId = computationId;
    this.mapTask = mapTask;
     this.executor = executor;
-    this.transformUserNameToStateFamily =
-        transformUserNameToStateFamily != null
-            ? ImmutableMap.copyOf(transformUserNameToStateFamily)
-            : ImmutableMap.of();
+    this.transformUserNameToStateFamily = ImmutableMap.copyOf(transformUserNameToStateFamily);
     this.computationStateCache = computationStateCache;
-    Preconditions.checkNotNull(mapTask.getStageName());
-    Preconditions.checkNotNull(mapTask.getSystemName());
   }
 
   public String getComputationId() {
@@ -96,7 +93,10 @@ public ConcurrentLinkedQueue<ExecutionState> getExecutionStateQueue() {
     return executionStateQueue;
   }
 
-  /** Mark the given shardedKey and work as active. */
+  /**
+   * Marks the given {@link ShardedKey} and {@link Work} as active, and schedules execution of the
+   * {@link Work} if no other {@link Work} for the {@link ShardedKey} is already being processed.
+   */
   public boolean activateWork(ShardedKey shardedKey, Work work) {
     synchronized (activeWork) {
       Deque<Work> queue = activeWork.get(shardedKey);
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
index 2ce3b18ac28c8..f2893f3e71914 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java
@@ -24,7 +24,7 @@ import org.checkerframework.checker.nullness.qual.Nullable;
 
 /** Bounded set of queues, with a maximum total weight. */
-public class WeightedBoundedQueue<V> {
+public final class WeightedBoundedQueue<V> {
 
   private final LinkedBlockingQueue<V> queue;
   private final int maxWeight;
@@ -91,7 +91,7 @@ public void put(V value) {
   }
 
   /** Returns the current weight of the queue. */
-  public int weight() {
+  public int queuedElementsWeight() {
     return maxWeight - limit.availablePermits();
   }
 
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
new file mode 100644
index 0000000000000..eca87526ea00d
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.dataflow.worker.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class WeightBoundedQueueTest {
+  private static final int MAX_WEIGHT = 10;
+
+  @Test
+  public void testPut_hasCapacity() {
+    WeightedBoundedQueue<Integer> queue =
+        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+
+    int insertedValue = 1;
+
+    queue.put(insertedValue);
+
+    assertEquals(insertedValue, queue.queuedElementsWeight());
+    assertEquals(1, queue.size());
+    assertEquals(insertedValue, (int) queue.poll());
+  }
+
+  @Test
+  public void testPut_noCapacity() throws InterruptedException {
+    WeightedBoundedQueue<Integer> queue =
+        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+
+    // Insert a value that consumes all of the queue's capacity.
+    Thread thread1 = new Thread(() -> queue.put(MAX_WEIGHT));
+    thread1.start();
+    thread1.join();
+
+    // Try to insert another value into the queue. This will block since there is no capacity in
+    // the queue.
+    Thread thread2 = new Thread(() -> queue.put(MAX_WEIGHT));
+    thread2.start();
+
+    // Should only see the first value in the queue, since the queue is at capacity. thread2
+    // should be blocked.
+    assertEquals(MAX_WEIGHT, queue.queuedElementsWeight());
+    assertEquals(1, queue.size());
+
+    // Have another thread poll the queue, pulling off the only value inside and freeing up the
+    // capacity in the queue.
+    Thread thread3 = new Thread(queue::poll);
+    thread3.start();
+    thread3.join();
+
+    // Wait for thread2, which was previously blocked because the queue was at capacity.
+    thread2.join();
+
+    assertEquals(MAX_WEIGHT, queue.queuedElementsWeight());
+    assertEquals(1, queue.size());
+  }
+
+  @Test
+  public void testPoll() {
+    WeightedBoundedQueue<Integer> queue =
+        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+
+    int insertedValue1 = 1;
+    int insertedValue2 = 2;
+
+    queue.put(insertedValue1);
+    queue.put(insertedValue2);
+
+    assertEquals(insertedValue1 + insertedValue2, queue.queuedElementsWeight());
+    assertEquals(2, queue.size());
+    assertEquals(insertedValue1, (int) queue.poll());
+    assertEquals(1, queue.size());
+  }
+
+  @Test
+  public void testPoll_emptyQueue() {
+    WeightedBoundedQueue<Integer> queue =
+        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+
+    assertNull(queue.poll());
+  }
+
+  @Test
+  public void testTake() throws InterruptedException {
+    WeightedBoundedQueue<Integer> queue =
+        WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i));
+
+    AtomicInteger value = new AtomicInteger();
+
+    // Should block until a value is available.
+    Thread takeThread =
+        new Thread(
+            () -> {
+              try {
+                value.set(queue.take());
+              } catch (InterruptedException e) {
+                throw new RuntimeException(e);
+              }
+            });
+    takeThread.start();
+
+    Thread putThread = new Thread(() -> queue.put(MAX_WEIGHT));
+    putThread.start();
+    putThread.join();
+
+    takeThread.join();
+
+    assertEquals(MAX_WEIGHT, value.get());
+  }
+}
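Reviewer note (not part of the patch): the mechanism that WeightBoundedQueueTest exercises is easy to miss in the diff, since only the create(...) factory and the weight() -> queuedElementsWeight() rename appear above. The sketch below restates the idea as a self-contained class. It is illustrative only: the class name WeightedQueueSketch and its main() driver are invented for this note; the create(...) shape, the blocking put/take semantics, and the maxWeight - limit.availablePermits() formula are the parts taken from the diff, and the real Beam class additionally offers poll(), a timed poll(), and size().

  import java.util.concurrent.LinkedBlockingQueue;
  import java.util.concurrent.Semaphore;
  import java.util.function.Function;

  /** Sketch of a weight-bounded queue: a Semaphore meters the total queued weight. */
  final class WeightedQueueSketch<V> {
    private final LinkedBlockingQueue<V> queue = new LinkedBlockingQueue<>();
    private final int maxWeight;
    private final Semaphore limit;
    private final Function<V, Integer> weigher;

    private WeightedQueueSketch(int maxWeight, Function<V, Integer> weigher) {
      this.maxWeight = maxWeight;
      this.limit = new Semaphore(maxWeight);
      this.weigher = weigher;
    }

    static <V> WeightedQueueSketch<V> create(int maxWeight, Function<V, Integer> weigher) {
      return new WeightedQueueSketch<>(maxWeight, weigher);
    }

    /** Blocks until enough weight is free, then enqueues the element. */
    void put(V value) {
      limit.acquireUninterruptibly(weigher.apply(value));
      queue.add(value);
    }

    /** Blocks until an element is available; releases its weight after dequeueing. */
    V take() throws InterruptedException {
      V result = queue.take();
      limit.release(weigher.apply(result));
      return result;
    }

    /** Outstanding permits equal the total weight currently enqueued. */
    int queuedElementsWeight() {
      return maxWeight - limit.availablePermits();
    }

    public static void main(String[] args) throws InterruptedException {
      WeightedQueueSketch<Integer> q = WeightedQueueSketch.create(10, i -> Math.min(10, i));
      q.put(4);
      q.put(6); // The weight budget (4 + 6 == 10) is now exhausted; another put would block.
      System.out.println(q.queuedElementsWeight()); // 10
      System.out.println(q.take()); // 4; its 4 permits are released again
    }
  }

This is also why the commitQueue weigher above clamps each element with Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize()): an element may never weigh more than the semaphore's total permits, or its put() could never succeed even on an empty queue.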