Skip to content

Commit

Permalink
remove MAX_WAIT_SECONDS and just use the deadline passed in, its alre…
Browse files Browse the repository at this point in the history
…ady 600 seconds anyway. Move DEFAULT_STREAM_RPC_DEADLINE_SECONDS to where it is being used and remove references in tests
  • Loading branch information
m-trieu committed Oct 10, 2024
1 parent 06836da commit 6df1adf
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
*/
public abstract class AbstractWindmillStream<RequestT, ResponseT> implements WindmillStream {

public static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300;
// Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce
// per-chunk overhead, and small enough that we can still perform granular flow-control.
protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,15 @@
@ThreadSafe
@Internal
public class GrpcWindmillStreamFactory implements StatusDataProvider {

private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300;
private static final Duration MIN_BACKOFF = Duration.millis(1);
private static final Duration DEFAULT_MAX_BACKOFF = Duration.standardSeconds(30);
private static final int DEFAULT_LOG_EVERY_N_STREAM_FAILURES = 1;
private static final int DEFAULT_STREAMING_RPC_BATCH_LIMIT = Integer.MAX_VALUE;
private static final int DEFAULT_WINDMILL_MESSAGES_BETWEEN_IS_READY_CHECKS = 1;
private static final int NO_HEALTH_CHECKS = -1;
private static final String NO_BACKEND_WORKER_TOKEN = "";
private static final long NO_DEADLINE = -1;
private static final long DEFAULT_DEADLINE_SECONDS =
AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2;
private static final String DISPATCHER_DEBUG_NAME = "Dispatcher";

private final JobHeader jobHeader;
Expand All @@ -99,7 +98,8 @@ private GrpcWindmillStreamFactory(
int windmillMessagesBetweenIsReadyChecks,
boolean sendKeyedGetDataRequests,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
Supplier<Duration> maxBackOffSupplier) {
Supplier<Duration> maxBackOffSupplier,
Set<AbstractWindmillStream<?, ?>> streamRegistry) {
this.jobHeader = jobHeader;
this.logEveryNStreamFailures = logEveryNStreamFailures;
this.streamingRpcBatchLimit = streamingRpcBatchLimit;
Expand All @@ -112,7 +112,7 @@ private GrpcWindmillStreamFactory(
.withInitialBackoff(MIN_BACKOFF)
.withMaxBackoff(maxBackOffSupplier.get())
.backoff());
this.streamRegistry = ConcurrentHashMap.newKeySet();
this.streamRegistry = streamRegistry;
this.sendKeyedGetDataRequests = sendKeyedGetDataRequests;
this.processHeartbeatResponses = processHeartbeatResponses;
this.streamIdGenerator = new AtomicLong();
Expand All @@ -127,7 +127,8 @@ static GrpcWindmillStreamFactory create(
boolean sendKeyedGetDataRequests,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
Supplier<Duration> maxBackOffSupplier,
int healthCheckIntervalMillis) {
int healthCheckIntervalMillis,
Set<AbstractWindmillStream<?, ?>> streamRegistry) {
GrpcWindmillStreamFactory streamFactory =
new GrpcWindmillStreamFactory(
jobHeader,
Expand All @@ -136,7 +137,8 @@ static GrpcWindmillStreamFactory create(
windmillMessagesBetweenIsReadyChecks,
sendKeyedGetDataRequests,
processHeartbeatResponses,
maxBackOffSupplier);
maxBackOffSupplier,
streamRegistry);

if (healthCheckIntervalMillis >= 0) {
// Health checks are run on background daemon thread, which will only be cleaned up on JVM
Expand Down Expand Up @@ -173,14 +175,14 @@ public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) {
.setStreamingRpcBatchLimit(DEFAULT_STREAMING_RPC_BATCH_LIMIT)
.setHealthCheckIntervalMillis(NO_HEALTH_CHECKS)
.setSendKeyedGetDataRequests(true)
.setProcessHeartbeatResponses(ignored -> {});
.setProcessHeartbeatResponses(ignored -> {})
.setStreamRegistry(ConcurrentHashMap.newKeySet());
}

private static <T extends AbstractStub<T>> T withDefaultDeadline(T stub) {
// Deadlines are absolute points in time, so generate a new one everytime this function is
// called.
return stub.withDeadlineAfter(
AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS);
return stub.withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS);
}

private static void printSummaryHtmlForWorker(
Expand All @@ -206,7 +208,7 @@ public GetWorkStream createGetWorkStream(
responseObserver -> withDefaultDeadline(stub).getWorkStream(responseObserver),
request,
grpcBackOff.get(),
newStreamObserverFactory(DEFAULT_DEADLINE_SECONDS),
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
getWorkThrottleTimer,
Expand All @@ -226,7 +228,7 @@ public GetWorkStream createDirectGetWorkStream(
responseObserver -> connection.stub().getWorkStream(responseObserver),
request,
grpcBackOff.get(),
newStreamObserverFactory(NO_DEADLINE),
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
getWorkThrottleTimer,
Expand All @@ -242,7 +244,7 @@ public GetDataStream createGetDataStream(
NO_BACKEND_WORKER_TOKEN,
responseObserver -> withDefaultDeadline(stub).getDataStream(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(DEFAULT_DEADLINE_SECONDS),
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
getDataThrottleTimer,
Expand All @@ -259,7 +261,7 @@ public GetDataStream createDirectGetDataStream(
connection.backendWorkerToken(),
responseObserver -> connection.stub().getDataStream(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(NO_DEADLINE),
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
getDataThrottleTimer,
Expand All @@ -276,7 +278,7 @@ public CommitWorkStream createCommitWorkStream(
NO_BACKEND_WORKER_TOKEN,
responseObserver -> withDefaultDeadline(stub).commitWorkStream(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(DEFAULT_DEADLINE_SECONDS),
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
commitWorkThrottleTimer,
Expand All @@ -291,7 +293,7 @@ public CommitWorkStream createDirectCommitWorkStream(
connection.backendWorkerToken(),
responseObserver -> connection.stub().commitWorkStream(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(NO_DEADLINE),
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
commitWorkThrottleTimer,
Expand All @@ -307,16 +309,17 @@ public GetWorkerMetadataStream createGetWorkerMetadataStream(
return GrpcGetWorkerMetadataStream.create(
responseObserver -> withDefaultDeadline(stub).getWorkerMetadata(responseObserver),
grpcBackOff.get(),
newStreamObserverFactory(DEFAULT_DEADLINE_SECONDS),
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
jobHeader,
getWorkerMetadataThrottleTimer,
onNewWindmillEndpoints);
}

private StreamObserverFactory newStreamObserverFactory(long deadline) {
return StreamObserverFactory.direct(deadline, windmillMessagesBetweenIsReadyChecks);
private StreamObserverFactory newStreamObserverFactory() {
return StreamObserverFactory.direct(
DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, windmillMessagesBetweenIsReadyChecks);
}

@Override
Expand Down Expand Up @@ -350,6 +353,8 @@ Builder setProcessHeartbeatResponses(

Builder setHealthCheckIntervalMillis(int healthCheckIntervalMillis);

Builder setStreamRegistry(Set<AbstractWindmillStream<?, ?>> streamRegistry);

GrpcWindmillStreamFactory build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
@ThreadSafe
public final class DirectStreamObserver<T> implements StreamObserver<T> {
private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class);
private static final long MAX_WAIT_SECONDS = 600; // 10 minutes.
private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30;

private final Phaser isReadyNotifier;

Expand Down Expand Up @@ -125,13 +125,13 @@ public void onNext(T value) {
}

totalSecondsWaited += waitSeconds;
if (hasDeadlineExpired(totalSecondsWaited)) {
if (totalSecondsWaited > deadlineSeconds) {
String errorMessage = constructStreamCancelledErrorMessage(totalSecondsWaited);
LOG.error(errorMessage);
throw new StreamObserverCancelledException(errorMessage, e);
}

if (totalSecondsWaited > 30) {
if (totalSecondsWaited > OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS) {
LOG.info(
"Output channel stalled for {}s, outbound thread {}.",
totalSecondsWaited,
Expand Down Expand Up @@ -164,10 +164,6 @@ public void onCompleted() {
}
}

private boolean hasDeadlineExpired(long totalSecondsWaited) {
return totalSecondsWaited > (deadlineSeconds > 0 ? deadlineSeconds : MAX_WAIT_SECONDS);
}

private String constructStreamCancelledErrorMessage(long totalSecondsWaited) {
return deadlineSeconds > 0
? "Waited "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
Expand All @@ -30,15 +29,11 @@
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
Expand Down Expand Up @@ -69,7 +64,6 @@ public class GrpcCommitWorkStreamTest {

@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
private final Set<AbstractWindmillStream<?, ?>> streamRegistry = new HashSet<>();
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private ManagedChannel inProcessChannel;

Expand Down Expand Up @@ -105,19 +99,11 @@ public void cleanUp() {

private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamStreamTestStub testStub) {
serviceRegistry.addService(testStub);
return GrpcCommitWorkStream.create(
"streamId",
responseObserver ->
CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)
.commitWorkStream(responseObserver),
FluentBackoff.DEFAULT.backoff(),
StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1),
streamRegistry,
1,
new ThrottleTimer(),
TEST_JOB_HEADER,
new AtomicLong(),
Integer.MAX_VALUE);
return (GrpcCommitWorkStream)
GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
.build()
.createCommitWorkStream(
CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), new ThrottleTimer());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,23 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException;
import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
Expand Down Expand Up @@ -71,7 +64,6 @@ public class GrpcGetDataStreamTest {

@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
private final Set<AbstractWindmillStream<?, ?>> streamRegistry = new HashSet<>();
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private ManagedChannel inProcessChannel;

Expand All @@ -98,21 +90,12 @@ public void cleanUp() {

private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) {
serviceRegistry.addService(testStub);
return GrpcGetDataStream.create(
"streamId",
responseObserver ->
CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)
.getDataStream(responseObserver),
FluentBackoff.DEFAULT.backoff(),
StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1),
streamRegistry,
1,
new ThrottleTimer(),
TEST_JOB_HEADER,
new AtomicLong(),
Integer.MAX_VALUE,
false,
ignored -> {});
return (GrpcGetDataStream)
GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
.setSendKeyedGetDataRequests(false)
.build()
.createGetDataStream(
CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), new ThrottleTimer());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.verify;
Expand All @@ -30,6 +29,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
Expand All @@ -39,9 +39,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory;
import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
Expand Down Expand Up @@ -82,26 +80,23 @@ public class GrpcGetWorkerMetadataStreamTest {
private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest";
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
private final Set<AbstractWindmillStream<?, ?>> streamRegistry = new HashSet<>();
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
private ManagedChannel inProcessChannel;
private GrpcGetWorkerMetadataStream stream;
private Set<AbstractWindmillStream<?, ?>> streamRegistry;

private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream(
GetWorkerMetadataTestStub getWorkerMetadataTestStub,
Consumer<WindmillEndpoints> endpointsConsumer) {
serviceRegistry.addService(getWorkerMetadataTestStub);
return GrpcGetWorkerMetadataStream.create(
responseObserver ->
CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel)
.getWorkerMetadata(responseObserver),
FluentBackoff.DEFAULT.backoff(),
StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1),
streamRegistry,
1, // logEveryNStreamFailures
TEST_JOB_HEADER,
new ThrottleTimer(),
endpointsConsumer);
return (GrpcGetWorkerMetadataStream)
GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
.setStreamRegistry(streamRegistry)
.build()
.createGetWorkerMetadataStream(
CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel),
new ThrottleTimer(),
endpointsConsumer);
}

@Before
Expand All @@ -124,6 +119,7 @@ public void setUp() throws IOException {
.setDirectEndpoint(IPV6_ADDRESS_1)
.setBackendWorkerToken("worker_token")
.build());
streamRegistry = ConcurrentHashMap.newKeySet();
}

@After
Expand Down

0 comments on commit 6df1adf

Please sign in to comment.