diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java index e7096133d..7c04581e3 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java @@ -21,6 +21,7 @@ import cz.o2.proxima.core.util.ExceptionUtils; import cz.o2.proxima.core.util.Pair; import cz.o2.proxima.internal.com.google.common.annotations.VisibleForTesting; +import cz.o2.proxima.internal.com.google.common.base.MoreObjects; import cz.o2.proxima.internal.com.google.common.base.Preconditions; import cz.o2.proxima.internal.com.google.common.collect.Iterables; import java.lang.annotation.Annotation; @@ -41,6 +42,7 @@ import java.util.Set; import java.util.function.BiFunction; import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.annotation.AnnotationDescription; import net.bytebuddy.description.modifier.FieldManifestation; @@ -112,6 +114,7 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.Instant; +@Slf4j public class ExternalStateExpander { static final String EXPANDER_BUF_STATE_SPEC = "expanderBufStateSpec"; @@ -589,7 +592,7 @@ Builder> addTimerFlushMethod( .withParameters(wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) .intercept( MethodDelegation.to( - new TimerFlushInterceptor<>( + new FlushTimerInterceptor<>( doFn, processElement, expander, keyCoder, stateTag, nextFlushInstantFn))); int i = 0; for (Pair p : wrapperArgs) { @@ -819,7 +822,7 @@ public void intercept( } } - private static class TimerFlushInterceptor { + private static class FlushTimerInterceptor { private final DoFn, ?> doFn; private final LinkedHashMap>> @@ -830,7 +833,7 @@ private static class TimerFlushInterceptor { private final TupleTag stateTag; private final UnaryFunction nextFlushInstantFn; - TimerFlushInterceptor( + FlushTimerInterceptor( DoFn, ?> doFn, Method processElementMethod, FlushTimerParameterExpander expander, @@ -857,20 +860,19 @@ public void intercept(@This DoFn>, ?> doFn, @AllArguments Timer flushTimer = (Timer) args[args.length - 4]; Instant nextFlush = nextFlushInstantFn.apply(now); Instant lastFlush = nextFlushState.read(); - if (nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) { + boolean isNextScheduled = + nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE); + if (isNextScheduled) { flushTimer.set(nextFlush); nextFlushState.write(nextFlush); } + @SuppressWarnings("unchecked") BagState>> bufState = (BagState>>) args[args.length - 2]; - bufState - .read() - .forEach( - kv -> { - Object[] processArgs = expander.getProcessElementArgs(kv.getValue(), args); - ExceptionUtils.unchecked(() -> processElementMethod.invoke(this.doFn, processArgs)); - }); + List>> pushedBackElements = + processBuffer(args, bufState.read(), MoreObjects.firstNonNull(lastFlush, now)); bufState.clear(); + // if we have already processed state data if (lastFlush != null) { MultiOutputReceiver outputReceiver = (MultiOutputReceiver) args[args.length - 1]; OutputReceiver output = outputReceiver.get(stateTag); @@ -879,9 +881,34 @@ public void intercept(@This DoFn>, ?> doFn, @AllArguments int i = 0; for (BiFunction> f : stateReaders.values()) { Object accessor = args[i++]; - f.apply(accessor, keyBytes).forEach(output::output); + Iterable values = f.apply(accessor, keyBytes); + values.forEach(output::output); } } + List>> remaining = + processBuffer( + args, + pushedBackElements, + MoreObjects.firstNonNull(nextFlush, BoundedWindow.TIMESTAMP_MAX_VALUE)); + remaining.forEach(bufState::add); + } + + private List>> processBuffer( + Object[] args, Iterable>> buffer, Instant maxTs) { + + List>> pushedBackElements = new ArrayList<>(); + buffer.forEach( + kv -> { + if (kv.getTimestamp().isBefore(maxTs)) { + Object[] processArgs = expander.getProcessElementArgs(kv.getValue(), args); + ExceptionUtils.unchecked(() -> processElementMethod.invoke(this.doFn, processArgs)); + } else { + // return back to buffer + log.debug("Returning element {} to flush buffer", kv); + pushedBackElements.add(kv); + } + }); + return pushedBackElements; } } diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java index 21e4cf254..92355bf4f 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java @@ -38,7 +38,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; -public interface FlushTimerParameterExpander { +interface FlushTimerParameterExpander { static FlushTimerParameterExpander of( DoFn doFn, diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java index 64042786a..0d5a826df 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java @@ -68,7 +68,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; -public class MethodCallUtils { +class MethodCallUtils { static Object[] fromGenerators( List, Object>> generators, Object[] wrapperArgs) { @@ -94,6 +94,9 @@ static Object[] fromGenerators( TupleTag mainTag, Type outputType) { + // FIXME: interchange @Timestamp and transform OutputReceiver to output correct + // timestamp + List, Object>> res = new ArrayList<>(argsMap.size()); List wrapperParamsIds = wrapperArgList.values().stream() diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java index 6204ad200..4094baa69 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java @@ -34,6 +34,7 @@ import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Predicate; +import lombok.extern.slf4j.Slf4j; import net.bytebuddy.description.annotation.AnnotationDescription; import net.bytebuddy.description.annotation.AnnotationDescription.ForLoadedAnnotation; import net.bytebuddy.description.type.TypeDefinition; @@ -53,7 +54,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; -public interface ProcessElementParameterExpander { +interface ProcessElementParameterExpander { static ProcessElementParameterExpander of( DoFn doFn, @@ -108,44 +109,7 @@ private static UnaryFunction createProcessFn( int elementPos = findParameter(wrapperArgs.keySet(), TypeId::isElement); Preconditions.checkState(elementPos >= 0, "Missing @Element annotation on method %s", method); Map> stateUpdaterMap = getStateUpdaters(doFn); - return args -> { - @SuppressWarnings("unchecked") - KV> elem = (KV>) args[elementPos]; - Instant ts = (Instant) args[args.length - 5]; - Timer flushTimer = (Timer) args[args.length - 4]; - @SuppressWarnings("unchecked") - ValueState finishedState = (ValueState) args[args.length - 3]; - flushTimer.set(stateWriteInstant); - boolean isState = Objects.requireNonNull(elem.getValue(), "elem").isState(); - if (isState) { - StateValue state = elem.getValue().getState(); - String stateName = state.getName(); - // find state accessor - int statePos = findParameter(wrapperArgs.keySet(), a -> a.isState(stateName)); - Preconditions.checkArgument( - statePos < method.getParameterCount(), "Missing state accessor for %s", stateName); - Object stateAccessor = args[statePos]; - // find declaration of state to find coder - BiConsumer updater = stateUpdaterMap.get(stateName); - Preconditions.checkArgument( - updater != null, "Missing updater for state %s in %s", stateName, stateUpdaterMap); - updater.accept(stateAccessor, state); - return false; - } - Instant nextFlush = finishedState.read(); - boolean shouldBuffer = - nextFlush == null /* we have not finished reading state */ - || nextFlush.isBefore(ts) /* the timestamp if after next flush */; - if (shouldBuffer) { - // store to state - @SuppressWarnings("unchecked") - BagState>> buffer = - (BagState>>) args[args.length - 2]; - buffer.add(TimestampedValue.of(KV.of(elem.getKey(), elem.getValue().getInput()), ts)); - return false; - } - return true; - }; + return new ProcessFn(elementPos, stateWriteInstant, wrapperArgs, method, stateUpdaterMap); } private static int findParameter(Collection args, Predicate predicate) { @@ -227,4 +191,70 @@ static Pair> transformProces } return Pair.of(typeId, Pair.of(annotation, parameterType)); } + + @Slf4j + class ProcessFn implements UnaryFunction { + private final int elementPos; + private final Instant stateWriteInstant; + private final LinkedHashMap> wrapperArgs; + private final Method method; + private final Map> stateUpdaterMap; + + public ProcessFn( + int elementPos, + Instant stateWriteInstant, + LinkedHashMap> wrapperArgs, + Method method, + Map> stateUpdaterMap) { + this.elementPos = elementPos; + this.stateWriteInstant = stateWriteInstant; + this.wrapperArgs = wrapperArgs; + this.method = method; + this.stateUpdaterMap = stateUpdaterMap; + } + + @Override + public Boolean apply(Object[] args) { + @SuppressWarnings("unchecked") + KV> elem = (KV>) args[elementPos]; + Instant ts = (Instant) args[args.length - 5]; + Timer flushTimer = (Timer) args[args.length - 4]; + @SuppressWarnings("unchecked") + ValueState finishedState = (ValueState) args[args.length - 3]; + boolean isState = Objects.requireNonNull(elem.getValue(), "elem").isState(); + if (isState) { + StateValue state = elem.getValue().getState(); + String stateName = state.getName(); + // find state accessor + int statePos = findParameter(wrapperArgs.keySet(), a -> a.isState(stateName)); + Preconditions.checkArgument( + statePos < method.getParameterCount(), "Missing state accessor for %s", stateName); + Object stateAccessor = args[statePos]; + // find declaration of state to find coder + BiConsumer updater = stateUpdaterMap.get(stateName); + Preconditions.checkArgument( + updater != null, "Missing updater for state %s in %s", stateName, stateUpdaterMap); + updater.accept(stateAccessor, state); + return false; + } + Instant nextFlush = finishedState.read(); + if (nextFlush == null) { + // set the initial timer + flushTimer.set(stateWriteInstant); + } + boolean shouldBuffer = + nextFlush == null /* we have not finished reading state */ + || nextFlush.isBefore(ts) /* the timestamp if after next flush */; + if (shouldBuffer) { + log.debug("Buffering element {} at {} with nextFlush {}", elem, ts, nextFlush); + // store to state + @SuppressWarnings("unchecked") + BagState>> buffer = + (BagState>>) args[args.length - 2]; + buffer.add(TimestampedValue.of(KV.of(elem.getKey(), elem.getValue().getInput()), ts)); + return false; + } + return true; + } + } } diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java index faad91e66..6c615c122 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java @@ -24,7 +24,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.StateId; -public class TypeId { +class TypeId { private static final TypeId TIMESTAMP_TYPE = TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Timestamp.class).build()); diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java index 67dbd1824..ba7a6f8ed 100644 --- a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java +++ b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java @@ -59,7 +59,6 @@ import org.apache.beam.sdk.values.TypeDescriptors; import org.jetbrains.annotations.NotNull; import org.joda.time.Instant; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -208,16 +207,15 @@ public void testSimpleExpandWithStateStore() { } @Test - @Ignore public void testStateWithElementEarly() throws CoderException { Pipeline pipeline = createPipeline(); Instant now = new Instant(0); PCollection inputs = pipeline.apply( TestStream.create(StringUtf8Coder.of()) - .addElements(TimestampedValue.of("1", now)) - .advanceWatermarkTo(new Instant(0)) - .addElements(TimestampedValue.of("3", now.plus(2))) + // the second timestamped value MUST not be part of the state produced at 1 + .addElements(TimestampedValue.of("1", now), TimestampedValue.of("3", now.plus(2))) + .advanceWatermarkTo(new Instant(1)) .advanceWatermarkToInfinity()); PCollection> withKeys = inputs.apply(