diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 17c5226fc01e2..27b399ea12875 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -68,7 +68,6 @@ */ public abstract class AbstractWindmillStream implements WindmillStream { - public static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce // per-chunk overhead, and small enough that we can still perform granular flow-control. protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index e0bba8501bdd2..9f6ac107a3b6f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -68,6 +68,8 @@ @ThreadSafe @Internal public class GrpcWindmillStreamFactory implements StatusDataProvider { + + private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; private static final Duration MIN_BACKOFF = Duration.millis(1); private static final Duration DEFAULT_MAX_BACKOFF = Duration.standardSeconds(30); private static final int DEFAULT_LOG_EVERY_N_STREAM_FAILURES = 1; @@ -75,9 +77,6 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider { private static final int DEFAULT_WINDMILL_MESSAGES_BETWEEN_IS_READY_CHECKS = 1; private static final int NO_HEALTH_CHECKS = -1; private static final String NO_BACKEND_WORKER_TOKEN = ""; - private static final long NO_DEADLINE = -1; - private static final long DEFAULT_DEADLINE_SECONDS = - AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2; private static final String DISPATCHER_DEBUG_NAME = "Dispatcher"; private final JobHeader jobHeader; @@ -99,7 +98,8 @@ private GrpcWindmillStreamFactory( int windmillMessagesBetweenIsReadyChecks, boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses, - Supplier maxBackOffSupplier) { + Supplier maxBackOffSupplier, + Set> streamRegistry) { this.jobHeader = jobHeader; this.logEveryNStreamFailures = logEveryNStreamFailures; this.streamingRpcBatchLimit = streamingRpcBatchLimit; @@ -112,7 +112,7 @@ private GrpcWindmillStreamFactory( .withInitialBackoff(MIN_BACKOFF) .withMaxBackoff(maxBackOffSupplier.get()) .backoff()); - this.streamRegistry = ConcurrentHashMap.newKeySet(); + this.streamRegistry = streamRegistry; this.sendKeyedGetDataRequests = sendKeyedGetDataRequests; this.processHeartbeatResponses = processHeartbeatResponses; this.streamIdGenerator = new AtomicLong(); @@ -127,7 +127,8 @@ static GrpcWindmillStreamFactory create( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses, Supplier maxBackOffSupplier, - int healthCheckIntervalMillis) { + int healthCheckIntervalMillis, + Set> streamRegistry) { GrpcWindmillStreamFactory streamFactory = new GrpcWindmillStreamFactory( jobHeader, @@ -136,7 +137,8 @@ static GrpcWindmillStreamFactory create( windmillMessagesBetweenIsReadyChecks, sendKeyedGetDataRequests, processHeartbeatResponses, - maxBackOffSupplier); + maxBackOffSupplier, + streamRegistry); if (healthCheckIntervalMillis >= 0) { // Health checks are run on background daemon thread, which will only be cleaned up on JVM @@ -173,14 +175,14 @@ public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) { .setStreamingRpcBatchLimit(DEFAULT_STREAMING_RPC_BATCH_LIMIT) .setHealthCheckIntervalMillis(NO_HEALTH_CHECKS) .setSendKeyedGetDataRequests(true) - .setProcessHeartbeatResponses(ignored -> {}); + .setProcessHeartbeatResponses(ignored -> {}) + .setStreamRegistry(ConcurrentHashMap.newKeySet()); } private static > T withDefaultDeadline(T stub) { // Deadlines are absolute points in time, so generate a new one everytime this function is // called. - return stub.withDeadlineAfter( - AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS); + return stub.withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS); } private static void printSummaryHtmlForWorker( @@ -206,7 +208,7 @@ public GetWorkStream createGetWorkStream( responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver), request, grpcBackOff.get(), - newStreamObserverFactory(DEFAULT_DEADLINE_SECONDS), + newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, getWorkThrottleTimer, @@ -226,7 +228,7 @@ public GetWorkStream createDirectGetWorkStream( responseObserver -> connection.stub().getWorkStream(responseObserver), request, grpcBackOff.get(), - newStreamObserverFactory(NO_DEADLINE), + newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, getWorkThrottleTimer, @@ -242,7 +244,7 @@ public GetDataStream createGetDataStream( NO_BACKEND_WORKER_TOKEN, responseObserver -> withDefaultDeadline(stub).getDataStream(responseObserver), grpcBackOff.get(), - newStreamObserverFactory(DEFAULT_DEADLINE_SECONDS), + newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, getDataThrottleTimer, @@ -259,7 +261,7 @@ public GetDataStream createDirectGetDataStream( connection.backendWorkerToken(), responseObserver -> connection.stub().getDataStream(responseObserver), grpcBackOff.get(), - newStreamObserverFactory(NO_DEADLINE), + newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, getDataThrottleTimer, @@ -276,7 +278,7 @@ public CommitWorkStream createCommitWorkStream( NO_BACKEND_WORKER_TOKEN, responseObserver -> withDefaultDeadline(stub).commitWorkStream(responseObserver), grpcBackOff.get(), - newStreamObserverFactory(DEFAULT_DEADLINE_SECONDS), + newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, commitWorkThrottleTimer, @@ -291,7 +293,7 @@ public CommitWorkStream createDirectCommitWorkStream( connection.backendWorkerToken(), responseObserver -> connection.stub().commitWorkStream(responseObserver), grpcBackOff.get(), - newStreamObserverFactory(NO_DEADLINE), + newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, commitWorkThrottleTimer, @@ -307,7 +309,7 @@ public GetWorkerMetadataStream createGetWorkerMetadataStream( return GrpcGetWorkerMetadataStream.create( responseObserver -> withDefaultDeadline(stub).getWorkerMetadata(responseObserver), grpcBackOff.get(), - newStreamObserverFactory(DEFAULT_DEADLINE_SECONDS), + newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, jobHeader, @@ -315,8 +317,9 @@ public GetWorkerMetadataStream createGetWorkerMetadataStream( onNewWindmillEndpoints); } - private StreamObserverFactory newStreamObserverFactory(long deadline) { - return StreamObserverFactory.direct(deadline, windmillMessagesBetweenIsReadyChecks); + private StreamObserverFactory newStreamObserverFactory() { + return StreamObserverFactory.direct( + DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, windmillMessagesBetweenIsReadyChecks); } @Override @@ -350,6 +353,8 @@ Builder setProcessHeartbeatResponses( Builder setHealthCheckIntervalMillis(int healthCheckIntervalMillis); + Builder setStreamRegistry(Set> streamRegistry); + GrpcWindmillStreamFactory build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 010d3d81e15ac..4d798e8d18ea0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -39,7 +39,7 @@ @ThreadSafe public final class DirectStreamObserver implements StreamObserver { private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class); - private static final long MAX_WAIT_SECONDS = 600; // 10 minutes. + private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30; private final Phaser isReadyNotifier; @@ -125,13 +125,13 @@ public void onNext(T value) { } totalSecondsWaited += waitSeconds; - if (hasDeadlineExpired(totalSecondsWaited)) { + if (totalSecondsWaited > deadlineSeconds) { String errorMessage = constructStreamCancelledErrorMessage(totalSecondsWaited); LOG.error(errorMessage); throw new StreamObserverCancelledException(errorMessage, e); } - if (totalSecondsWaited > 30) { + if (totalSecondsWaited > OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS) { LOG.info( "Output channel stalled for {}s, outbound thread {}.", totalSecondsWaited, @@ -164,10 +164,6 @@ public void onCompleted() { } } - private boolean hasDeadlineExpired(long totalSecondsWaited) { - return totalSecondsWaited > (deadlineSeconds > 0 ? deadlineSeconds : MAX_WAIT_SECONDS); - } - private String constructStreamCancelledErrorMessage(long totalSecondsWaited) { return deadlineSeconds > 0 ? "Waited " diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 11d208b8ca957..68c29f6fd3956 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -30,15 +29,11 @@ import java.util.HashSet; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; @@ -69,7 +64,6 @@ public class GrpcCommitWorkStreamTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - private final Set> streamRegistry = new HashSet<>(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; @@ -105,19 +99,11 @@ public void cleanUp() { private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamStreamTestStub testStub) { serviceRegistry.addService(testStub); - return GrpcCommitWorkStream.create( - "streamId", - responseObserver -> - CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel) - .commitWorkStream(responseObserver), - FluentBackoff.DEFAULT.backoff(), - StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1), - streamRegistry, - 1, - new ThrottleTimer(), - TEST_JOB_HEADER, - new AtomicLong(), - Integer.MAX_VALUE); + return (GrpcCommitWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createCommitWorkStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), new ThrottleTimer()); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index 992477aaa31c4..ddd243275524d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -18,30 +18,23 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; -import java.util.HashSet; import java.util.List; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; @@ -71,7 +64,6 @@ public class GrpcGetDataStreamTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - private final Set> streamRegistry = new HashSet<>(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; @@ -98,21 +90,12 @@ public void cleanUp() { private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) { serviceRegistry.addService(testStub); - return GrpcGetDataStream.create( - "streamId", - responseObserver -> - CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel) - .getDataStream(responseObserver), - FluentBackoff.DEFAULT.backoff(), - StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1), - streamRegistry, - 1, - new ThrottleTimer(), - TEST_JOB_HEADER, - new AtomicLong(), - Integer.MAX_VALUE, - false, - ignored -> {}); + return (GrpcGetDataStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setSendKeyedGetDataRequests(false) + .build() + .createGetDataStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), new ThrottleTimer()); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 6cfbabe2db8d4..9e6d3eb0d7bf2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.verify; @@ -30,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -39,9 +39,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; @@ -82,26 +80,23 @@ public class GrpcGetWorkerMetadataStreamTest { private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest"; @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - private final Set> streamRegistry = new HashSet<>(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; private GrpcGetWorkerMetadataStream stream; + private Set> streamRegistry; private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream( GetWorkerMetadataTestStub getWorkerMetadataTestStub, Consumer endpointsConsumer) { serviceRegistry.addService(getWorkerMetadataTestStub); - return GrpcGetWorkerMetadataStream.create( - responseObserver -> - CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel) - .getWorkerMetadata(responseObserver), - FluentBackoff.DEFAULT.backoff(), - StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1), - streamRegistry, - 1, // logEveryNStreamFailures - TEST_JOB_HEADER, - new ThrottleTimer(), - endpointsConsumer); + return (GrpcGetWorkerMetadataStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setStreamRegistry(streamRegistry) + .build() + .createGetWorkerMetadataStream( + CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer(), + endpointsConsumer); } @Before @@ -124,6 +119,7 @@ public void setUp() throws IOException { .setDirectEndpoint(IPV6_ADDRESS_1) .setBackendWorkerToken("worker_token") .build()); + streamRegistry = ConcurrentHashMap.newKeySet(); } @After