Skip to content

Commit

Permalink
remove thread polling for new metadata; propogate metadata version to…
Browse files Browse the repository at this point in the history
… WindmillEndpoints and don't process any version that is older than the current version in FanOutStreamingEngineWorkerHarness
  • Loading branch information
m-trieu committed Sep 25, 2024
1 parent 08a9c1e commit 4464453
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;

import java.util.Collection;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
Expand All @@ -34,6 +34,7 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.CheckReturnValue;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair;
import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub;
Expand Down Expand Up @@ -61,12 +62,11 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.checkerframework.checker.initialization.qual.UnderInitialization;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -81,10 +81,7 @@
public final class FanOutStreamingEngineWorkerHarness implements StreamingWorkerHarness {
private static final Logger LOG =
LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class);
private static final String PUBLISH_NEW_WORKER_METADATA_THREAD_NAME =
"PublishNewWorkerMetadataThread";
private static final String CONSUME_NEW_WORKER_METADATA_THREAD_NAME =
"ConsumeNewWorkerMetadataThread";
private static final String WORKER_METADATA_CONSUMER_THREAD_NAME = "WorkerMetadataConsumerThread";
private static final String STREAM_MANAGER_THREAD_NAME = "WindmillStreamManager-%d";

private final JobHeader jobHeader;
Expand All @@ -96,19 +93,21 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker
private final GetWorkBudget totalGetWorkBudget;
private final ThrottleTimer getWorkerMetadataThrottleTimer;
private final Supplier<GetWorkerMetadataStream> getWorkerMetadataStream;
private final Queue<WindmillEndpoints> newWindmillEndpoints;
private final Function<WindmillStream.CommitWorkStream, WorkCommitter> workCommitterFactory;
private final ThrottlingGetDataMetricTracker getDataMetricTracker;
private final ExecutorService windmillStreamManager;
private final ExecutorService newWorkerMetadataPublisher;
private final ExecutorService newWorkerMetadataConsumer;
private final ExecutorService workerMetadataConsumer;
private final Object metadataLock;

/** Writes are guarded by synchronization, reads are lock free. */
private final AtomicReference<StreamingEngineConnectionState> connections;

private volatile boolean started;
@GuardedBy("metadataLock")
private long metadataVersion;

@GuardedBy("this")
private boolean started;

@SuppressWarnings("FutureReturnValueIgnored")
private FanOutStreamingEngineWorkerHarness(
JobHeader jobHeader,
GetWorkBudget totalGetWorkBudget,
Expand All @@ -131,30 +130,15 @@ private FanOutStreamingEngineWorkerHarness(
this.windmillStreamManager =
Executors.newCachedThreadPool(
new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build());
this.newWorkerMetadataPublisher =
Executors.newSingleThreadScheduledExecutor(
new ThreadFactoryBuilder()
.setNameFormat(PUBLISH_NEW_WORKER_METADATA_THREAD_NAME)
.build());
this.newWorkerMetadataConsumer =
this.workerMetadataConsumer =
Executors.newSingleThreadScheduledExecutor(
new ThreadFactoryBuilder()
.setNameFormat(CONSUME_NEW_WORKER_METADATA_THREAD_NAME)
.build());
this.newWindmillEndpoints = Queues.synchronizedQueue(EvictingQueue.create(1));
new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build());
this.getWorkBudgetDistributor = getWorkBudgetDistributor;
this.totalGetWorkBudget = totalGetWorkBudget;
this.getWorkerMetadataStream =
Suppliers.memoize(
() ->
streamFactory.createGetWorkerMetadataStream(
dispatcherClient.getWindmillMetadataServiceStubBlocking(),
getWorkerMetadataThrottleTimer,
endpoints ->
// Run this on a separate thread than the grpc stream thread.
newWorkerMetadataPublisher.submit(
() -> newWindmillEndpoints.add(endpoints))));
this.metadataVersion = Long.MIN_VALUE;
this.getWorkerMetadataStream = Suppliers.memoize(createGetWorkerMetadataStream()::get);
this.workCommitterFactory = workCommitterFactory;
this.metadataLock = new Object();
}

/**
Expand Down Expand Up @@ -216,7 +200,6 @@ public synchronized void start() {
Preconditions.checkState(!started, "StreamingEngineClient cannot start twice.");
// Starts the stream, this value is memoized.
getWorkerMetadataStream.get();
startWorkerMetadataConsumer();
started = true;
}

Expand Down Expand Up @@ -249,27 +232,51 @@ private GetDataStream getGlobalDataStream(String globalDataKey) {
dispatcherClient.getWindmillServiceStub(), new ThrottleTimer()));
}

private void startWorkerMetadataConsumer() {
newWorkerMetadataConsumer.execute(
() -> {
while (true) {
Optional.ofNullable(newWindmillEndpoints.poll())
.ifPresent(this::consumeWindmillWorkerEndpoints);
}
});
}

@VisibleForTesting
@Override
public synchronized void shutdown() {
Preconditions.checkState(started, "StreamingEngineClient never started.");
getWorkerMetadataStream.get().halfClose();
newWorkerMetadataPublisher.shutdownNow();
newWorkerMetadataConsumer.shutdownNow();
workerMetadataConsumer.shutdownNow();
channelCachingStubFactory.shutdown();
}

@SuppressWarnings("methodref.receiver.bound")
private Supplier<GetWorkerMetadataStream> createGetWorkerMetadataStream(
@UnderInitialization FanOutStreamingEngineWorkerHarness this) {
// Checker Framework complains about reference to "this" in the constructor since the instance
// is "UnderInitialization" here, which we pass as a lambda to GetWorkerMetadataStream for
// processing new worker metadata. Supplier.get() is only called in start(), after we have
// constructed the FanOutStreamingEngineWorkerHarness.
return () ->
checkNotNull(streamFactory)
.createGetWorkerMetadataStream(
checkNotNull(dispatcherClient).getWindmillMetadataServiceStubBlocking(),
checkNotNull(getWorkerMetadataThrottleTimer),
this::consumeWorkerMetadata);
}

private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) {
synchronized (metadataLock) {
// Only process versions greater than what we currently have to prevent double processing of
// metadata.
if (windmillEndpoints.version() > metadataVersion) {
metadataVersion = windmillEndpoints.version();
workerMetadataConsumer.execute(() -> consumeWindmillWorkerEndpoints(windmillEndpoints));
}
}
}

private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) {
// Since this is run on a single threaded executor, multiple versions of the metadata maybe
// queued up while a previous version of the windmillEndpoints were being consumed. Only consume
// the endpoints if they are the most current version.
synchronized (metadataLock) {
if (newWindmillEndpoints.version() < metadataVersion) {
return;
}
}

LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints);
ImmutableMap<Endpoint, WindmillConnection> newWindmillConnections =
createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints());
Expand Down Expand Up @@ -400,7 +407,7 @@ private ImmutableMap<String, Supplier<GetDataStream>> createNewGlobalDataStreams
private Supplier<GetDataStream> existingOrNewGetDataStreamFor(
Entry<String, Endpoint> keyedEndpoint,
ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams) {
return Preconditions.checkNotNull(
return checkNotNull(
currentGlobalDataStreams.getOrDefault(
keyedEndpoint.getKey(),
() ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public static WindmillEndpoints from(
.collect(toImmutableList());

return WindmillEndpoints.builder()
.setVersion(workerMetadataResponseProto.getMetadataVersion())
.setGlobalDataEndpoints(globalDataServers)
.setWindmillEndpoints(windmillServers)
.build();
Expand Down Expand Up @@ -123,6 +124,8 @@ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
directEndpointAddress.getHostAddress(), (int) endpointProto.getPort()));
}

public abstract long version();

/**
* Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key
* is a global data tag and the value is the endpoint where the data associated with the global
Expand Down Expand Up @@ -204,6 +207,8 @@ public abstract static class Builder {

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder setVersion(long version);

public abstract Builder setGlobalDataEndpoints(
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ public final class GrpcGetWorkerMetadataStream
private final Consumer<WindmillEndpoints> serverMappingConsumer;
private final Object metadataLock;

@GuardedBy("metadataLock")
private long metadataVersion;

@GuardedBy("metadataLock")
private WorkerMetadataResponse latestResponse;

Expand All @@ -61,7 +58,6 @@ private GrpcGetWorkerMetadataStream(
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
JobHeader jobHeader,
long metadataVersion,
ThrottleTimer getWorkerMetadataThrottleTimer,
Consumer<WindmillEndpoints> serverMappingConsumer) {
super(
Expand All @@ -74,7 +70,6 @@ private GrpcGetWorkerMetadataStream(
logEveryNStreamFailures,
"");
this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build();
this.metadataVersion = metadataVersion;
this.getWorkerMetadataThrottleTimer = getWorkerMetadataThrottleTimer;
this.serverMappingConsumer = serverMappingConsumer;
this.latestResponse = WorkerMetadataResponse.getDefaultInstance();
Expand All @@ -89,7 +84,6 @@ public static GrpcGetWorkerMetadataStream create(
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
JobHeader jobHeader,
int metadataVersion,
ThrottleTimer getWorkerMetadataThrottleTimer,
Consumer<WindmillEndpoints> serverMappingUpdater) {
GrpcGetWorkerMetadataStream getWorkerMetadataStream =
Expand All @@ -100,7 +94,6 @@ public static GrpcGetWorkerMetadataStream create(
streamRegistry,
logEveryNStreamFailures,
jobHeader,
metadataVersion,
getWorkerMetadataThrottleTimer,
serverMappingUpdater);
getWorkerMetadataStream.startStream();
Expand All @@ -119,14 +112,13 @@ protected void onResponse(WorkerMetadataResponse response) {

/**
* Acquires the {@link #metadataLock} Returns {@link Optional<WindmillEndpoints>} if the
* metadataVersion in the response is not stale (older or equal to {@link #metadataVersion}), else
* returns empty {@link Optional}.
* metadataVersion in the response is not stale (older or equal to current {@link
* WorkerMetadataResponse#getMetadataVersion()}), else returns empty {@link Optional}.
*/
private Optional<WindmillEndpoints> extractWindmillEndpointsFrom(
WorkerMetadataResponse response) {
synchronized (metadataLock) {
if (response.getMetadataVersion() > metadataVersion) {
this.metadataVersion = response.getMetadataVersion();
if (response.getMetadataVersion() > latestResponse.getMetadataVersion()) {
this.latestResponse = response;
return Optional.of(WindmillEndpoints.from(response));
} else {
Expand All @@ -136,7 +128,7 @@ private Optional<WindmillEndpoints> extractWindmillEndpointsFrom(
"Received metadata version={}; Current metadata version={}. "
+ "Skipping update because received stale metadata",
response.getMetadataVersion(),
this.metadataVersion);
latestResponse.getMetadataVersion());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ public GetWorkerMetadataStream createGetWorkerMetadataStream(
streamRegistry,
logEveryNStreamFailures,
jobHeader,
0,
getWorkerMetadataThrottleTimer,
onNewWindmillEndpoints);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ public class GrpcGetWorkerMetadataStreamTest {

private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream(
GetWorkerMetadataTestStub getWorkerMetadataTestStub,
int metadataVersion,
Consumer<WindmillEndpoints> endpointsConsumer) {
serviceRegistry.addService(getWorkerMetadataTestStub);
return GrpcGetWorkerMetadataStream.create(
Expand All @@ -101,7 +100,6 @@ private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream(
streamRegistry,
1, // logEveryNStreamFailures
TEST_JOB_HEADER,
metadataVersion,
new ThrottleTimer(),
endpointsConsumer);
}
Expand Down Expand Up @@ -146,8 +144,7 @@ public void testGetWorkerMetadata() {
new TestWindmillEndpointsConsumer();
GetWorkerMetadataTestStub testStub =
new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
int metadataVersion = -1;
stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer);
testStub.injectWorkerMetadata(mockResponse);

assertThat(testWindmillEndpointsConsumer.globalDataEndpoints.keySet())
Expand Down Expand Up @@ -175,8 +172,7 @@ public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() {

GetWorkerMetadataTestStub testStub =
new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
int metadataVersion = 0;
stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer);
testStub.injectWorkerMetadata(initialResponse);

List<WorkerMetadataResponse.Endpoint> newDirectPathEndpoints =
Expand Down Expand Up @@ -222,8 +218,7 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() {
Mockito.spy(new TestWindmillEndpointsConsumer());
GetWorkerMetadataTestStub testStub =
new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
int metadataVersion = 0;
stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer);
testStub.injectWorkerMetadata(freshEndpoints);

List<WorkerMetadataResponse.Endpoint> staleDirectPathEndpoints =
Expand Down Expand Up @@ -252,7 +247,7 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() {
public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() {
GetWorkerMetadataTestStub testStub =
new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer());
stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer());
testStub.injectWorkerMetadata(
WorkerMetadataResponse.newBuilder()
.setMetadataVersion(1)
Expand All @@ -270,7 +265,7 @@ public void testSendHealthCheck() {
TestGetWorkMetadataRequestObserver requestObserver =
Mockito.spy(new TestGetWorkMetadataRequestObserver());
GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver);
stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer());
stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer());
stream.sendHealthCheck();

verify(requestObserver).onNext(WorkerMetadataRequest.getDefaultInstance());
Expand Down

0 comments on commit 4464453

Please sign in to comment.