Skip to content

Commit

Permalink
break out Work from StreamingDataflowWorker
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Sep 19, 2023
1 parent 9127fb8 commit 18018c3
Show file tree
Hide file tree
Showing 3 changed files with 833 additions and 772 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
Expand Down Expand Up @@ -65,7 +64,6 @@
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.CloudObjects;
import org.apache.beam.runners.dataflow.worker.DataflowSystemMetrics.StreamingSystemCounterNames;
import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker.Work.State;
import org.apache.beam.runners.dataflow.worker.apiary.FixMultiOutputInfosOnParDoInstructions;
import org.apache.beam.runners.dataflow.worker.counters.Counter;
import org.apache.beam.runners.dataflow.worker.counters.CounterSet;
Expand Down Expand Up @@ -93,6 +91,7 @@
import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
import org.apache.beam.runners.dataflow.worker.streaming.StageInfo;
import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue;
import org.apache.beam.runners.dataflow.worker.streaming.Work;
import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ElementCounter;
Expand Down Expand Up @@ -850,117 +849,22 @@ private void scheduleWorkItem(
WindmillTimeUtils.windmillToHarnessWatermark(workItem.getOutputDataWatermark());
Preconditions.checkState(
outputDataWatermark == null || !outputDataWatermark.isAfter(inputDataWatermark));
Work work =
new Work(workItem, clock, getWorkStreamLatencies) {
@Override
public void run() {
process(
computationState,
inputDataWatermark,
outputDataWatermark,
synchronizedProcessingTime,
this);
}
};
computationState.activateWork(
ShardedKey.create(workItem.getKey(), workItem.getShardingKey()), work);
}

abstract static class Work implements Runnable {

// Lifecycle states of a Work item inside this worker, each mapped to the coarser
// Windmill.LatencyAttribution.State used when reporting latency back to the service.
enum State {
// Waiting in the work queue before processing starts.
QUEUED(Windmill.LatencyAttribution.State.QUEUED),
// Actively executing user processing; Windmill calls this ACTIVE.
PROCESSING(Windmill.LatencyAttribution.State.ACTIVE),
// Blocked on a state/read from Windmill during processing.
READING(Windmill.LatencyAttribution.State.READING),
// NOTE: COMMIT_QUEUED and COMMITTING intentionally map to the same
// LatencyAttribution state (COMMITTING) — their durations are merged on report.
COMMIT_QUEUED(Windmill.LatencyAttribution.State.COMMITTING),
COMMITTING(Windmill.LatencyAttribution.State.COMMITTING),
// Time attributed to fetching the work item, before it reached this worker.
GET_WORK_IN_WINDMILL_WORKER(Windmill.LatencyAttribution.State.GET_WORK_IN_WINDMILL_WORKER),
GET_WORK_IN_TRANSIT_TO_DISPATCHER(
Windmill.LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_DISPATCHER),
GET_WORK_IN_TRANSIT_TO_USER_WORKER(
Windmill.LatencyAttribution.State.GET_WORK_IN_TRANSIT_TO_USER_WORKER);

// The attribution state reported to Windmill for time spent in this state.
private final Windmill.LatencyAttribution.State latencyAttributionState;

private State(Windmill.LatencyAttribution.State latencyAttributionState) {
this.latencyAttributionState = latencyAttributionState;
}

// Returns the Windmill-facing attribution state for this internal state.
Windmill.LatencyAttribution.State toLatencyAttributionState() {
return latencyAttributionState;
}
}

// The Windmill work item being tracked; fixed for the lifetime of this Work.
private final Windmill.WorkItem workItem;
// Source of "now" for all timing below (presumably injectable for tests — TODO confirm).
private final Supplier<Instant> clock;
// When this Work object was created (initial queue time).
private final Instant startTime;
// When the current state was entered; advanced on every setState call.
private Instant stateStartTime;
// Current lifecycle state; starts as QUEUED (see constructor).
private State state;
// Accumulated time per Windmill attribution state. Keyed by the coarser
// LatencyAttribution.State, so e.g. COMMIT_QUEUED and COMMITTING share one entry.
private final Map<Windmill.LatencyAttribution.State, Duration> totalDurationPerState =
new EnumMap<>(Windmill.LatencyAttribution.State.class);

/**
 * Creates a Work wrapper for {@code workItem}.
 *
 * <p>Both {@code startTime} and {@code stateStartTime} are initialized to the same instant and
 * the state begins as {@link State#QUEUED}. Latencies accumulated before this worker received
 * the item (get-work stream latencies) are pre-recorded into the per-state totals.
 */
public Work(
Windmill.WorkItem workItem,
Supplier<Instant> clock,
Collection<LatencyAttribution> getWorkStreamLatencies) {
this.workItem = workItem;
this.clock = clock;
// One clock read seeds both timestamps so queue time starts accruing immediately.
this.startTime = this.stateStartTime = clock.get();
this.state = State.QUEUED;
recordGetWorkStreamLatencies(getWorkStreamLatencies);
}

/** Returns the underlying Windmill work item. */
public Windmill.WorkItem getWorkItem() {
return workItem;
}

/** Returns the instant this Work was created (when it was first queued). */
public Instant getStartTime() {
return startTime;
}

/** Returns the current lifecycle state. */
public State getState() {
return state;
}

/**
 * Transitions to {@code state}, first charging the time spent in the outgoing state to its
 * Windmill attribution bucket in {@code totalDurationPerState}.
 */
public void setState(State state) {
Instant transitionTime = clock.get();
// Add the elapsed interval onto any previously accumulated duration for the
// outgoing state's attribution bucket (absent => just the interval itself).
totalDurationPerState.merge(
this.state.toLatencyAttributionState(),
new Duration(this.stateStartTime, transitionTime),
Duration::plus);
this.state = state;
this.stateStartTime = transitionTime;
}

/** Returns the instant at which the current state was entered. */
public Instant getStateStartTime() {
return stateStartTime;
}

/**
 * Seeds {@code totalDurationPerState} with latencies accumulated before this worker received
 * the item (e.g. time in the Windmill worker or in transit on the get-work stream).
 */
private void recordGetWorkStreamLatencies(
Collection<LatencyAttribution> getWorkStreamLatencies) {
getWorkStreamLatencies.forEach(
latency ->
totalDurationPerState.put(
latency.getState(), Duration.millis(latency.getTotalDurationMillis())));
}
Work scheduledWork =
Work.create(
workItem,
clock,
getWorkStreamLatencies,
work ->
process(
computationState,
inputDataWatermark,
outputDataWatermark,
synchronizedProcessingTime,
work));

/**
 * Builds the per-state latency attributions to report to Windmill.
 *
 * <p>For the attribution bucket matching the current state, the still-running interval since
 * {@code stateStartTime} is added on top of the recorded total. Buckets with zero total
 * duration are omitted from the result.
 */
public Collection<Windmill.LatencyAttribution> getLatencyAttributions() {
List<Windmill.LatencyAttribution> attributions = new ArrayList<>();
for (Windmill.LatencyAttribution.State attributionState :
Windmill.LatencyAttribution.State.values()) {
Duration total = totalDurationPerState.getOrDefault(attributionState, Duration.ZERO);
if (attributionState == this.state.toLatencyAttributionState()) {
// The current state is still accruing time; include the open interval.
total = total.plus(new Duration(this.stateStartTime, clock.get()));
}
if (!total.equals(Duration.ZERO)) {
attributions.add(
Windmill.LatencyAttribution.newBuilder()
.setState(attributionState)
.setTotalDurationMillis(total.getMillis())
.build());
}
}
return attributions;
}
computationState.activateWork(
ShardedKey.create(workItem.getKey(), workItem.getShardingKey()), scheduledWork);
}

/**
Expand Down Expand Up @@ -1029,7 +933,7 @@ private void process(
final Windmill.WorkItem workItem = work.getWorkItem();
final String computationId = computationState.getComputationId();
final ByteString key = workItem.getKey();
work.setState(State.PROCESSING);
work.setState(Work.State.PROCESSING);
{
StringBuilder workIdBuilder = new StringBuilder(33);
workIdBuilder.append(Long.toHexString(workItem.getShardingKey()));
Expand All @@ -1048,7 +952,7 @@ private void process(
callFinalizeCallbacks(workItem);
if (workItem.getSourceState().getOnlyFinalize()) {
outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
work.setState(State.COMMIT_QUEUED);
work.setState(Work.State.COMMIT_QUEUED);
commitQueue.put(Commit.create(outputBuilder.build(), computationState, work));
return;
}
Expand Down Expand Up @@ -1172,11 +1076,11 @@ private void process(
workItem.getShardingKey(),
workItem.getWorkToken(),
() -> {
work.setState(State.READING);
work.setState(Work.State.READING);
return new AutoCloseable() {
@Override
public void close() {
work.setState(State.PROCESSING);
work.setState(Work.State.PROCESSING);
}
};
});
Expand Down Expand Up @@ -1249,7 +1153,7 @@ public void close() {
executionState = null;

// Add the output to the commit queue.
work.setState(State.COMMIT_QUEUED);
work.setState(Work.State.COMMIT_QUEUED);
outputBuilder.addAllPerWorkItemLatencyAttributions(work.getLatencyAttributions());

WorkItemCommitRequest commitRequest = outputBuilder.build();
Expand Down Expand Up @@ -1411,7 +1315,7 @@ private void commitLoop() {
}
while (commit != null) {
ComputationState computationState = commit.computationState();
commit.work().setState(State.COMMITTING);
commit.work().setState(Work.State.COMMITTING);
Windmill.ComputationCommitWorkRequest.Builder computationRequestBuilder =
computationRequestMap.get(computationState);
if (computationRequestBuilder == null) {
Expand Down Expand Up @@ -1451,18 +1355,18 @@ private boolean addCommitToStream(Commit commit, CommitWorkStream commitStream)
final ComputationState state = commit.computationState();
final Windmill.WorkItemCommitRequest request = commit.request();
final int size = commit.getSize();
commit.work().setState(State.COMMITTING);
commit.work().setState(Work.State.COMMITTING);
activeCommitBytes.addAndGet(size);
if (commitStream.commitWorkItem(
state.computationId,
state.getComputationId(),
request,
(Windmill.CommitStatus status) -> {
if (status != Windmill.CommitStatus.OK) {
readerCache.invalidateReader(
WindmillComputationKey.create(
state.computationId, request.getKey(), request.getShardingKey()));
state.getComputationId(), request.getKey(), request.getShardingKey()));
stateCache
.forComputation(state.computationId)
.forComputation(state.getComputationId())
.invalidate(request.getKey(), request.getShardingKey());
}
activeCommitBytes.addAndGet(-size);
Expand All @@ -1475,7 +1379,7 @@ private boolean addCommitToStream(Commit commit, CommitWorkStream commitStream)
return true;
} else {
// Back out the stats changes since the commit wasn't consumed.
commit.work().setState(State.COMMIT_QUEUED);
commit.work().setState(Work.State.COMMIT_QUEUED);
activeCommitBytes.addAndGet(-size);
return false;
}
Expand Down
Loading

0 comments on commit 18018c3

Please sign in to comment.