Skip to content

Commit

Permalink
buffering across state flushes
Browse files Browse the repository at this point in the history
  • Loading branch information
je-ik committed Sep 30, 2024
1 parent 788a3f8 commit 1f54869
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import cz.o2.proxima.core.util.ExceptionUtils;
import cz.o2.proxima.core.util.Pair;
import cz.o2.proxima.internal.com.google.common.annotations.VisibleForTesting;
import cz.o2.proxima.internal.com.google.common.base.MoreObjects;
import cz.o2.proxima.internal.com.google.common.base.Preconditions;
import cz.o2.proxima.internal.com.google.common.collect.Iterables;
import java.lang.annotation.Annotation;
Expand All @@ -41,6 +42,7 @@
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.bytebuddy.ByteBuddy;
import net.bytebuddy.description.annotation.AnnotationDescription;
import net.bytebuddy.description.modifier.FieldManifestation;
Expand Down Expand Up @@ -112,6 +114,7 @@
import org.checkerframework.checker.nullness.qual.NonNull;
import org.joda.time.Instant;

@Slf4j
public class ExternalStateExpander {

static final String EXPANDER_BUF_STATE_SPEC = "expanderBufStateSpec";
Expand Down Expand Up @@ -589,7 +592,7 @@ Builder<DoFn<InputT, OutputT>> addTimerFlushMethod(
.withParameters(wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList()))
.intercept(
MethodDelegation.to(
new TimerFlushInterceptor<>(
new FlushTimerInterceptor<>(
doFn, processElement, expander, keyCoder, stateTag, nextFlushInstantFn)));
int i = 0;
for (Pair<AnnotationDescription, TypeDefinition> p : wrapperArgs) {
Expand Down Expand Up @@ -819,7 +822,7 @@ public void intercept(
}
}

private static class TimerFlushInterceptor<K, V> {
private static class FlushTimerInterceptor<K, V> {

private final DoFn<KV<K, V>, ?> doFn;
private final LinkedHashMap<String, BiFunction<Object, byte[], Iterable<StateValue>>>
Expand All @@ -830,7 +833,7 @@ private static class TimerFlushInterceptor<K, V> {
private final TupleTag<StateValue> stateTag;
private final UnaryFunction<Instant, Instant> nextFlushInstantFn;

TimerFlushInterceptor(
FlushTimerInterceptor(
DoFn<KV<K, V>, ?> doFn,
Method processElementMethod,
FlushTimerParameterExpander expander,
Expand All @@ -857,20 +860,19 @@ public void intercept(@This DoFn<KV<V, StateOrInput<V>>, ?> doFn, @AllArguments
Timer flushTimer = (Timer) args[args.length - 4];
Instant nextFlush = nextFlushInstantFn.apply(now);
Instant lastFlush = nextFlushState.read();
if (nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
boolean isNextScheduled =
nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE);
if (isNextScheduled) {
flushTimer.set(nextFlush);
nextFlushState.write(nextFlush);
}
@SuppressWarnings("unchecked")
BagState<TimestampedValue<KV<K, V>>> bufState =
(BagState<TimestampedValue<KV<K, V>>>) args[args.length - 2];
bufState
.read()
.forEach(
kv -> {
Object[] processArgs = expander.getProcessElementArgs(kv.getValue(), args);
ExceptionUtils.unchecked(() -> processElementMethod.invoke(this.doFn, processArgs));
});
List<TimestampedValue<KV<K, V>>> pushedBackElements =
processBuffer(args, bufState.read(), MoreObjects.firstNonNull(lastFlush, now));
bufState.clear();
// if we have already processed state data
if (lastFlush != null) {
MultiOutputReceiver outputReceiver = (MultiOutputReceiver) args[args.length - 1];
OutputReceiver<StateValue> output = outputReceiver.get(stateTag);
Expand All @@ -879,9 +881,34 @@ public void intercept(@This DoFn<KV<V, StateOrInput<V>>, ?> doFn, @AllArguments
int i = 0;
for (BiFunction<Object, byte[], Iterable<StateValue>> f : stateReaders.values()) {
Object accessor = args[i++];
f.apply(accessor, keyBytes).forEach(output::output);
Iterable<StateValue> values = f.apply(accessor, keyBytes);
values.forEach(output::output);
}
}
List<TimestampedValue<KV<K, V>>> remaining =
processBuffer(
args,
pushedBackElements,
MoreObjects.firstNonNull(nextFlush, BoundedWindow.TIMESTAMP_MAX_VALUE));
remaining.forEach(bufState::add);
}

/**
 * Replays buffered elements through the wrapped process-element method.
 *
 * <p>Every element whose timestamp is strictly before {@code maxTs} is passed to {@code
 * processElementMethod} (invoked reflectively on the wrapped {@code doFn}); all other elements
 * are collected and handed back to the caller, which decides whether to re-buffer them.
 *
 * @param args arguments of the flush-timer wrapper call, used by the expander to build the
 *     argument array for the process-element method
 * @param buffer buffered elements to walk through
 * @param maxTs exclusive upper bound on the timestamps of elements that get processed now
 * @return elements that were not processed, in iteration order
 */
private List<TimestampedValue<KV<K, V>>> processBuffer(
    Object[] args, Iterable<TimestampedValue<KV<K, V>>> buffer, Instant maxTs) {

  List<TimestampedValue<KV<K, V>>> retained = new ArrayList<>();
  for (TimestampedValue<KV<K, V>> element : buffer) {
    if (!element.getTimestamp().isBefore(maxTs)) {
      // not yet eligible for processing, hand it back to the caller
      log.debug("Returning element {} to flush buffer", element);
      retained.add(element);
      continue;
    }
    Object[] processArgs = expander.getProcessElementArgs(element.getValue(), args);
    ExceptionUtils.unchecked(() -> processElementMethod.invoke(this.doFn, processArgs));
  }
  return retained;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import org.apache.beam.sdk.values.TupleTag;
import org.joda.time.Instant;

public interface FlushTimerParameterExpander {
interface FlushTimerParameterExpander {

static FlushTimerParameterExpander of(
DoFn<?, ?> doFn,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Instant;

public class MethodCallUtils {
class MethodCallUtils {

static Object[] fromGenerators(
List<BiFunction<Object[], KV<?, ?>, Object>> generators, Object[] wrapperArgs) {
Expand All @@ -94,6 +94,9 @@ static Object[] fromGenerators(
TupleTag<?> mainTag,
Type outputType) {

// FIXME: interchange @Timestamp and transform OutputReceiver to output correct
// timestamp

List<BiFunction<Object[], KV<?, ?>, Object>> res = new ArrayList<>(argsMap.size());
List<TypeDefinition> wrapperParamsIds =
wrapperArgList.values().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import lombok.extern.slf4j.Slf4j;
import net.bytebuddy.description.annotation.AnnotationDescription;
import net.bytebuddy.description.annotation.AnnotationDescription.ForLoadedAnnotation;
import net.bytebuddy.description.type.TypeDefinition;
Expand All @@ -53,7 +54,7 @@
import org.apache.beam.sdk.values.TupleTag;
import org.joda.time.Instant;

public interface ProcessElementParameterExpander {
interface ProcessElementParameterExpander {

static ProcessElementParameterExpander of(
DoFn<?, ?> doFn,
Expand Down Expand Up @@ -108,44 +109,7 @@ private static UnaryFunction<Object[], Boolean> createProcessFn(
int elementPos = findParameter(wrapperArgs.keySet(), TypeId::isElement);
Preconditions.checkState(elementPos >= 0, "Missing @Element annotation on method %s", method);
Map<String, BiConsumer<Object, StateValue>> stateUpdaterMap = getStateUpdaters(doFn);
return args -> {
@SuppressWarnings("unchecked")
KV<?, StateOrInput<?>> elem = (KV<?, StateOrInput<?>>) args[elementPos];
Instant ts = (Instant) args[args.length - 5];
Timer flushTimer = (Timer) args[args.length - 4];
@SuppressWarnings("unchecked")
ValueState<Instant> finishedState = (ValueState<Instant>) args[args.length - 3];
flushTimer.set(stateWriteInstant);
boolean isState = Objects.requireNonNull(elem.getValue(), "elem").isState();
if (isState) {
StateValue state = elem.getValue().getState();
String stateName = state.getName();
// find state accessor
int statePos = findParameter(wrapperArgs.keySet(), a -> a.isState(stateName));
Preconditions.checkArgument(
statePos < method.getParameterCount(), "Missing state accessor for %s", stateName);
Object stateAccessor = args[statePos];
// find declaration of state to find coder
BiConsumer<Object, StateValue> updater = stateUpdaterMap.get(stateName);
Preconditions.checkArgument(
updater != null, "Missing updater for state %s in %s", stateName, stateUpdaterMap);
updater.accept(stateAccessor, state);
return false;
}
Instant nextFlush = finishedState.read();
boolean shouldBuffer =
nextFlush == null /* we have not finished reading state */
|| nextFlush.isBefore(ts) /* the timestamp if after next flush */;
if (shouldBuffer) {
// store to state
@SuppressWarnings("unchecked")
BagState<TimestampedValue<KV<?, ?>>> buffer =
(BagState<TimestampedValue<KV<?, ?>>>) args[args.length - 2];
buffer.add(TimestampedValue.of(KV.of(elem.getKey(), elem.getValue().getInput()), ts));
return false;
}
return true;
};
return new ProcessFn(elementPos, stateWriteInstant, wrapperArgs, method, stateUpdaterMap);
}

private static int findParameter(Collection<TypeId> args, Predicate<TypeId> predicate) {
Expand Down Expand Up @@ -227,4 +191,70 @@ static Pair<TypeId, Pair<AnnotationDescription, TypeDefinition>> transformProces
}
return Pair.of(typeId, Pair.of(annotation, parameterType));
}

/**
 * Pre-processing function applied to the argument array of the expanded process method.
 *
 * <p>For every call it does one of three things:
 *
 * <ul>
 *   <li>the element carries a {@link StateValue}: the value is written into the matching state
 *       accessor and {@code false} is returned,
 *   <li>the element is an input that cannot be processed yet (no flush instant has been stored,
 *       or the element timestamp is after the stored flush instant): the input is appended to
 *       the buffer state and {@code false} is returned,
 *   <li>otherwise {@code true} is returned (presumably telling the caller to run the user's
 *       process method on the element; confirm against the call site).
 * </ul>
 *
 * <p>The positions of the trailing arguments are fixed relative to the end of the array:
 * {@code length - 5} is the element timestamp, {@code length - 4} the flush timer, {@code
 * length - 3} the state holding the next flush instant and {@code length - 2} the buffer state.
 */
@Slf4j
class ProcessFn implements UnaryFunction<Object[], Boolean> {
  private final int elementPos;
  private final Instant stateWriteInstant;
  private final LinkedHashMap<TypeId, Pair<AnnotationDescription, TypeDefinition>> wrapperArgs;
  private final Method method;
  private final Map<String, BiConsumer<Object, StateValue>> stateUpdaterMap;

  /**
   * @param elementPos index of the element argument in the argument array
   * @param stateWriteInstant instant at which the flush timer is set while no flush instant has
   *     been stored yet
   * @param wrapperArgs ordered description of the wrapper method parameters, used to locate
   *     state accessors by state name
   * @param method method whose parameter count bounds valid state accessor positions
   * @param stateUpdaterMap state name to function writing a {@link StateValue} into an accessor
   */
  public ProcessFn(
      int elementPos,
      Instant stateWriteInstant,
      LinkedHashMap<TypeId, Pair<AnnotationDescription, TypeDefinition>> wrapperArgs,
      Method method,
      Map<String, BiConsumer<Object, StateValue>> stateUpdaterMap) {
    this.elementPos = elementPos;
    this.stateWriteInstant = stateWriteInstant;
    this.wrapperArgs = wrapperArgs;
    this.method = method;
    this.stateUpdaterMap = stateUpdaterMap;
  }

  /**
   * @param args arguments of the expanded process method
   * @return {@code false} when the element was consumed here (state update or buffering),
   *     {@code true} otherwise
   * @throws NullPointerException if the element has no value
   * @throws IllegalArgumentException if a state element names a state with no accessor or no
   *     updater
   */
  @Override
  public Boolean apply(Object[] args) {
    @SuppressWarnings("unchecked")
    KV<?, StateOrInput<?>> elem = (KV<?, StateOrInput<?>>) args[elementPos];
    Instant ts = (Instant) args[args.length - 5];
    Timer flushTimer = (Timer) args[args.length - 4];
    @SuppressWarnings("unchecked")
    ValueState<Instant> finishedState = (ValueState<Instant>) args[args.length - 3];
    StateOrInput<?> value = Objects.requireNonNull(elem.getValue(), "elem");
    if (value.isState()) {
      StateValue state = value.getState();
      String stateName = state.getName();
      // find state accessor
      int statePos = findParameter(wrapperArgs.keySet(), a -> a.isState(stateName));
      // NOTE(review): findParameter appears to return a negative index when nothing matches
      // (the @Element lookup is validated with ">= 0"), so the lower bound is checked as well;
      // otherwise an unknown state would fail on the array access below instead of here.
      Preconditions.checkArgument(
          statePos >= 0 && statePos < method.getParameterCount(),
          "Missing state accessor for %s",
          stateName);
      Object stateAccessor = args[statePos];
      // find the updater able to write this state into its accessor
      BiConsumer<Object, StateValue> updater = stateUpdaterMap.get(stateName);
      Preconditions.checkArgument(
          updater != null, "Missing updater for state %s in %s", stateName, stateUpdaterMap);
      updater.accept(stateAccessor, state);
      return false;
    }
    Instant nextFlush = finishedState.read();
    if (nextFlush == null) {
      // set the initial timer
      flushTimer.set(stateWriteInstant);
    }
    boolean shouldBuffer =
        nextFlush == null /* we have not finished reading state */
            || nextFlush.isBefore(ts) /* the timestamp is after next flush */;
    if (shouldBuffer) {
      log.debug("Buffering element {} at {} with nextFlush {}", elem, ts, nextFlush);
      // store to state
      @SuppressWarnings("unchecked")
      BagState<TimestampedValue<KV<?, ?>>> buffer =
          (BagState<TimestampedValue<KV<?, ?>>>) args[args.length - 2];
      buffer.add(TimestampedValue.of(KV.of(elem.getKey(), value.getInput()), ts));
      return false;
    }
    return true;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.StateId;

public class TypeId {
class TypeId {

private static final TypeId TIMESTAMP_TYPE =
TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Timestamp.class).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import org.apache.beam.sdk.values.TypeDescriptors;
import org.jetbrains.annotations.NotNull;
import org.joda.time.Instant;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
Expand Down Expand Up @@ -208,16 +207,15 @@ public void testSimpleExpandWithStateStore() {
}

@Test
@Ignore
public void testStateWithElementEarly() throws CoderException {
Pipeline pipeline = createPipeline();
Instant now = new Instant(0);
PCollection<String> inputs =
pipeline.apply(
TestStream.create(StringUtf8Coder.of())
.addElements(TimestampedValue.of("1", now))
.advanceWatermarkTo(new Instant(0))
.addElements(TimestampedValue.of("3", now.plus(2)))
// the second timestamped value MUST not be part of the state produced at 1
.addElements(TimestampedValue.of("1", now), TimestampedValue.of("3", now.plus(2)))
.advanceWatermarkTo(new Instant(1))
.advanceWatermarkToInfinity());
PCollection<KV<Integer, String>> withKeys =
inputs.apply(
Expand Down

0 comments on commit 1f54869

Please sign in to comment.