Skip to content

Commit

Permalink
address PR comments, remove unused IPv6 WindmillServiceAddress, renam…
Browse files Browse the repository at this point in the history
…e StreamingEngineConnectionsState to StreamingEngineBackends
  • Loading branch information
m-trieu committed Oct 10, 2024
1 parent 6df1adf commit de3b016
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress;
import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer;
import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
Expand All @@ -85,7 +84,6 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCache;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingRemoteStubFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.IsolationChannel;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
Expand All @@ -107,7 +105,6 @@
import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.util.construction.CoderTranslation;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheStats;
Expand Down Expand Up @@ -603,20 +600,12 @@ private static void validateWorkerOptions(DataflowWorkerHarnessOptions options)

private static ChannelCachingStubFactory createStubFactory(
DataflowWorkerHarnessOptions workerOptions) {
Function<WindmillServiceAddress, ManagedChannel> channelFactory =
serviceAddress ->
remoteChannel(
serviceAddress, workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec());
ChannelCache channelCache =
return ChannelCachingRemoteStubFactory.create(
workerOptions.getGcpCredential(),
ChannelCache.create(
serviceAddress ->
// IsolationChannel will create and manage separate RPC channels to the same
// serviceAddress via calling the channelFactory, else just directly return the
// RPC channel.
workerOptions.getUseWindmillIsolatedChannels()
? IsolationChannel.create(() -> channelFactory.apply(serviceAddress))
: channelFactory.apply(serviceAddress));
return ChannelCachingRemoteStubFactory.create(workerOptions.getGcpCredential(), channelCache);
remoteChannel(
serviceAddress, workerOptions.getWindmillServiceRpcChannelAliveTimeoutSec())));
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,18 @@
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet;

import java.util.Collection;
import java.util.List;
import java.io.Closeable;
import java.util.HashSet;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -99,7 +100,7 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker
private final Object metadataLock;

/** Writes are guarded by synchronization, reads are lock free. */
private final AtomicReference<StreamingEngineConnectionState> connections;
private final AtomicReference<StreamingEngineBackends> backends;

@GuardedBy("this")
private long activeMetadataVersion;
Expand Down Expand Up @@ -129,7 +130,7 @@ private FanOutStreamingEngineWorkerHarness(
this.started = false;
this.streamFactory = streamFactory;
this.workItemScheduler = workItemScheduler;
this.connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY);
this.backends = new AtomicReference<>(StreamingEngineBackends.EMPTY);
this.channelCachingStubFactory = channelCachingStubFactory;
this.dispatcherClient = dispatcherClient;
this.getWorkerMetadataThrottleTimer = new ThrottleTimer();
Expand Down Expand Up @@ -212,32 +213,23 @@ public synchronized void start() {
}

public ImmutableSet<HostAndPort> currentWindmillEndpoints() {
return connections.get().windmillStreams().keySet().stream()
return backends.get().windmillStreams().keySet().stream()
.map(Endpoint::directEndpoint)
.filter(Optional::isPresent)
.map(Optional::get)
.filter(
windmillServiceAddress ->
windmillServiceAddress.getKind() != WindmillServiceAddress.Kind.IPV6)
.map(
windmillServiceAddress ->
windmillServiceAddress.getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS
? windmillServiceAddress.gcpServiceAddress()
: windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress())
.map(WindmillServiceAddress::getServiceAddress)
.collect(toImmutableSet());
}

/**
* Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link
* GetDataStream} pointing to dispatcher.
* Fetches the {@link GetDataStream} mapped to globalDataKey, or throws {@link
* NoSuchElementException} if one is not found.
*/
private GetDataStream getGlobalDataStream(String globalDataKey) {
return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey))
.map(Supplier::get)
.orElseGet(
() ->
streamFactory.createGetDataStream(
dispatcherClient.getWindmillServiceStub(), new ThrottleTimer()));
return Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey))
.map(GlobalDataStreamSender::get)
.orElseThrow(
() -> new NoSuchElementException("No endpoint for global data tag: " + globalDataKey));
}

@VisibleForTesting
Expand Down Expand Up @@ -270,105 +262,110 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
}
}

long previousMetadataVersion = activeMetadataVersion;
LOG.debug(
"Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}",
newWindmillEndpoints,
previousMetadataVersion,
activeMetadataVersion);
closeStaleStreams(
newWindmillEndpoints.windmillEndpoints(), connections.get().windmillStreams());
activeMetadataVersion,
newWindmillEndpoints.version());
closeStaleStreams(newWindmillEndpoints);
ImmutableMap<Endpoint, WindmillStreamSender> newStreams =
createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join();
StreamingEngineConnectionState newConnectionsState =
StreamingEngineConnectionState.builder()
StreamingEngineBackends newBackends =
StreamingEngineBackends.builder()
.setWindmillStreams(newStreams)
.setGlobalDataStreams(
createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints()))
.build();
connections.set(newConnectionsState);
backends.set(newBackends);
getWorkBudgetDistributor.distributeBudget(newStreams.values(), totalGetWorkBudget);
activeMetadataVersion = newWindmillEndpoints.version();
}

/** Close the streams that are no longer valid asynchronously. */
@SuppressWarnings("FutureReturnValueIgnored")
private void closeStaleStreams(
Collection<Endpoint> newWindmillConnections,
ImmutableMap<Endpoint, WindmillStreamSender> currentStreams) {
currentStreams.entrySet().stream()
private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) {
StreamingEngineBackends currentBackends = backends.get();
ImmutableMap<Endpoint, WindmillStreamSender> currentWindmillStreams =
currentBackends.windmillStreams();
currentWindmillStreams.entrySet().stream()
.filter(
connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey()))
connectionAndStream ->
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
.forEach(
entry ->
CompletableFuture.runAsync(
() -> {
LOG.debug("Closing streams to {}", entry);
try {
entry.getValue().closeAllStreams();
entry
.getKey()
.directEndpoint()
.ifPresent(channelCachingStubFactory::remove);
LOG.debug("Successfully closed streams to {}", entry);
} catch (Exception e) {
LOG.error("Error closing streams to {}", entry);
}
},
() -> closeStreamSender(entry.getKey(), entry.getValue()),
windmillStreamManager));

Set<Endpoint> newGlobalDataEndpoints =
new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values());
currentBackends.globalDataStreams().values().stream()
.filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint()))
.forEach(
sender ->
CompletableFuture.runAsync(
() -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager));
}

private synchronized CompletableFuture<ImmutableMap<Endpoint, WindmillStreamSender>>
createAndStartNewStreams(Collection<Endpoint> newWindmillConnections) {
ImmutableMap<Endpoint, WindmillStreamSender> currentStreams =
connections.get().windmillStreams();
CompletionStage<List<Pair<Endpoint, WindmillStreamSender>>> connectionAndSenderFuture =
MoreFutures.allAsList(
newWindmillConnections.stream()
.map(
connection ->
MoreFutures.supplyAsync(
() ->
Pair.of(
connection,
Optional.ofNullable(currentStreams.get(connection))
.orElseGet(
() -> createAndStartWindmillStreamSender(connection))),
windmillStreamManager))
.collect(Collectors.toList()));
private void closeStreamSender(Endpoint endpoint, Closeable sender) {
LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender);
try {
sender.close();
endpoint.directEndpoint().ifPresent(channelCachingStubFactory::remove);
LOG.debug("Successfully closed streams to {}", endpoint);
} catch (Exception e) {
LOG.error("Error closing streams to endpoint={}, sender={}", endpoint, sender);
}
}

return connectionAndSenderFuture
private synchronized CompletableFuture<ImmutableMap<Endpoint, WindmillStreamSender>>
createAndStartNewStreams(ImmutableSet<Endpoint> newWindmillEndpoints) {
ImmutableMap<Endpoint, WindmillStreamSender> currentStreams = backends.get().windmillStreams();
return MoreFutures.allAsList(
newWindmillEndpoints.stream()
.map(endpoint -> getOrCreateWindmillStreamSenderFuture(endpoint, currentStreams))
.collect(Collectors.toList()))
.thenApply(
connectionsAndSenders ->
connectionsAndSenders.stream()
.collect(toImmutableMap(Pair::getLeft, Pair::getRight)))
backends -> backends.stream().collect(toImmutableMap(Pair::getLeft, Pair::getRight)))
.toCompletableFuture();
}

private CompletionStage<Pair<Endpoint, WindmillStreamSender>>
getOrCreateWindmillStreamSenderFuture(
Endpoint endpoint, ImmutableMap<Endpoint, WindmillStreamSender> currentStreams) {
return MoreFutures.supplyAsync(
() ->
Pair.of(
endpoint,
Optional.ofNullable(currentStreams.get(endpoint))
.orElseGet(() -> createAndStartWindmillStreamSender(endpoint))),
windmillStreamManager);
}

/** Add up all the throttle times of all streams including GetWorkerMetadataStream. */
@Override
public long getAndResetThrottleTime() {
return connections.get().windmillStreams().values().stream()
return backends.get().windmillStreams().values().stream()
.map(WindmillStreamSender::getAndResetThrottleTime)
.reduce(0L, Long::sum)
+ getWorkerMetadataThrottleTimer.getAndResetThrottleTime();
}

public long currentActiveCommitBytes() {
return connections.get().windmillStreams().values().stream()
return backends.get().windmillStreams().values().stream()
.map(WindmillStreamSender::getCurrentActiveCommitBytes)
.reduce(0L, Long::sum);
}

@VisibleForTesting
StreamingEngineConnectionState getCurrentConnections() {
return connections.get();
StreamingEngineBackends currentBackends() {
return backends.get();
}

private ImmutableMap<String, Supplier<GetDataStream>> createNewGlobalDataStreams(
private ImmutableMap<String, GlobalDataStreamSender> createNewGlobalDataStreams(
ImmutableMap<String, Endpoint> newGlobalDataEndpoints) {
ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams =
connections.get().globalDataStreams();
ImmutableMap<String, GlobalDataStreamSender> currentGlobalDataStreams =
backends.get().globalDataStreams();
return newGlobalDataEndpoints.entrySet().stream()
.collect(
toImmutableMap(
Expand All @@ -377,21 +374,23 @@ private ImmutableMap<String, Supplier<GetDataStream>> createNewGlobalDataStreams
existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams)));
}

private Supplier<GetDataStream> existingOrNewGetDataStreamFor(
private GlobalDataStreamSender existingOrNewGetDataStreamFor(
Entry<String, Endpoint> keyedEndpoint,
ImmutableMap<String, Supplier<GetDataStream>> currentGlobalDataStreams) {
ImmutableMap<String, GlobalDataStreamSender> currentGlobalDataStreams) {
return checkNotNull(
currentGlobalDataStreams.getOrDefault(
keyedEndpoint.getKey(),
() ->
streamFactory.createGetDataStream(
createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer())));
new GlobalDataStreamSender(
() ->
streamFactory.createGetDataStream(
createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()),
keyedEndpoint.getValue())));
}

private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint connection) {
private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) {
WindmillStreamSender windmillStreamSender =
WindmillStreamSender.create(
WindmillConnection.from(connection, this::createWindmillStub),
WindmillConnection.from(endpoint, this::createWindmillStub),
GetWorkRequest.newBuilder()
.setClientId(jobHeader.getClientId())
.setJobId(jobHeader.getJobId())
Expand All @@ -405,7 +404,7 @@ private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint connect
StreamGetDataClient.create(
getDataStream, this::getGlobalDataStream, getDataMetricTracker),
workCommitterFactory);
windmillStreamSender.startStreams();
windmillStreamSender.start();
return windmillStreamSender;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;

import com.google.common.base.Suppliers;
import java.io.Closeable;
import java.util.function.Supplier;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream;
import org.apache.beam.sdk.annotations.Internal;

@Internal
@ThreadSafe
final class GlobalDataStreamSender implements Closeable, Supplier<GetDataStream> {
  private final Endpoint endpoint;
  private final Supplier<GetDataStream> delegate;

  /**
   * Lazily-created stream; non-null once the first {@link #get()} has completed. Creation is
   * guarded by {@code synchronized (this)}; {@code volatile} gives lock-free reads afterwards.
   * Hand-rolled instead of Guava's {@code Suppliers.memoize} so this file only depends on the
   * vendored imports used elsewhere in the worker.
   */
  private volatile GetDataStream stream;

  GlobalDataStreamSender(Supplier<GetDataStream> delegate, Endpoint endpoint) {
    this.delegate = delegate;
    this.endpoint = endpoint;
    this.stream = null;
  }

  /**
   * Returns the underlying {@link GetDataStream}, creating and starting it on first use. All
   * callers observe the same instance.
   */
  @Override
  public GetDataStream get() {
    GetDataStream s = stream;
    if (s == null) {
      synchronized (this) {
        s = stream;
        if (s == null) {
          // First caller pays the cost of creating/starting the stream.
          s = delegate.get();
          stream = s;
        }
      }
    }
    return s;
  }

  /** Shuts the stream down if one was ever created; a never-started sender closes as a no-op. */
  @Override
  public void close() {
    // Creating the stream is expensive, so never construct one just to shut it down.
    GetDataStream s = stream;
    if (s != null) {
      s.shutdown();
    }
  }

  /** The Windmill endpoint this sender's stream points at. */
  Endpoint endpoint() {
    return endpoint;
  }
}
Loading

0 comments on commit de3b016

Please sign in to comment.