Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

339 external state expander #928

Merged
merged 44 commits into from
Oct 4, 2024
Merged
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c0f7304
wip
je-ik Aug 27, 2024
098805f
Passing first test
je-ik Aug 27, 2024
fba229e
wip
je-ik Aug 28, 2024
35b5404
wip, test working
je-ik Aug 28, 2024
cbe3399
wip: passing second test
je-ik Aug 28, 2024
8cec61e
wip: reading from the state input
je-ik Aug 28, 2024
30019de
remapping working with reflection
je-ik Aug 28, 2024
ded3e6f
wip
je-ik Aug 30, 2024
69a716e
added expansion for singleoutput
je-ik Aug 30, 2024
36a64ab
state init test flaky
je-ik Aug 30, 2024
2e464e9
avoid cast to valuestate
je-ik Aug 30, 2024
39a44e6
test with flink
je-ik Aug 30, 2024
394b3ff
wip
je-ik Sep 4, 2024
120e5c2
wip
je-ik Sep 4, 2024
e47d72d
wip
je-ik Sep 4, 2024
ea946b0
tests still failing
je-ik Sep 4, 2024
f43aa4c
tests passing
je-ik Sep 4, 2024
fc442e3
before adding state flush to on window expiration
je-ik Sep 18, 2024
8daec2e
wip
je-ik Sep 19, 2024
1d40204
wip: some tests passing
je-ik Sep 20, 2024
54361aa
more tests passing
je-ik Sep 20, 2024
8826861
added state flush timer
je-ik Sep 20, 2024
1d7d786
wip
je-ik Sep 20, 2024
05ef15b
fixed some tests
je-ik Sep 23, 2024
682da17
wip
je-ik Sep 24, 2024
099f8f9
wip
je-ik Sep 24, 2024
115f34f
wip
je-ik Sep 24, 2024
66417e9
wip
je-ik Sep 24, 2024
e40e0e1
wip
je-ik Sep 24, 2024
a8a1c82
bag state to timestampedvalue
je-ik Sep 24, 2024
b0eab68
after adding FlushTimerParameterExpander
je-ik Sep 24, 2024
ace470d
working tests
je-ik Sep 25, 2024
6595bf1
wip
je-ik Sep 26, 2024
fba3416
coverage
je-ik Sep 30, 2024
788a3f8
sonar, missing one test
je-ik Sep 30, 2024
1f54869
buffering across state flushes
je-ik Sep 30, 2024
fe1e650
enable debug logging
je-ik Sep 30, 2024
3158b61
inject timestamp
je-ik Oct 1, 2024
e7b4aa0
fix class naming
je-ik Oct 1, 2024
25aa6d9
refactor
je-ik Oct 1, 2024
6677b22
add MethodInvoker
je-ik Oct 3, 2024
1dbdbf0
failing construction of subclass
je-ik Oct 3, 2024
c72e07a
use VoidMethodInvoker for delegating calls
je-ik Oct 4, 2024
6f6b0ad
sonar
je-ik Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip
je-ik committed Sep 24, 2024
commit 099f8f9efaff45bdacef2ca217f60f1fee1edd23
Original file line number Diff line number Diff line change
@@ -108,6 +108,7 @@
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.reflect.TypeToken;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.joda.time.Instant;

public class ExternalStateExpander {

@@ -125,11 +126,15 @@ public class ExternalStateExpander {
*
* @param pipeline the Pipeline to expand
* @param inputs transform to read inputs
* @param stateWriteInstant the instant at which the last state write occurred
* @param nextFlushInstantFn function that, given the current time, returns the instant of the next flush
* @param stateSink transform to store outputs
*/
public static void expand(
Pipeline pipeline,
PTransform<PBegin, PCollection<KV<String, StateValue>>> inputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {

validatePipeline(pipeline);
@@ -138,7 +143,9 @@ public static void expand(

// replace all MultiParDos
pipeline.replaceAll(
Collections.singletonList(statefulParMultiDoOverride(inputsMaterialized, stateSink)));
Collections.singletonList(
statefulParMultiDoOverride(
inputsMaterialized, stateWriteInstant, nextFlushInstantFn, stateSink)));
}

private static void validatePipeline(Pipeline pipeline) {
@@ -173,21 +180,26 @@ public void leavePipeline(Pipeline pipeline) {}

private static PTransformOverride statefulParMultiDoOverride(
PCollection<KV<String, StateValue>> inputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {

return PTransformOverride.of(
application -> application.getTransform() instanceof ParDo.MultiOutput,
parMultiDoReplacementFactory(inputs, stateSink));
parMultiDoReplacementFactory(inputs, stateWriteInstant, nextFlushInstantFn, stateSink));
}

private static PTransformOverrideFactory parMultiDoReplacementFactory(
PCollection<KV<String, StateValue>> inputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {

return new PTransformOverrideFactory() {
@Override
public PTransformReplacement getReplacementTransform(AppliedPTransform transform) {
return replaceParMultiDo(transform, inputs, stateSink);
return replaceParMultiDo(
transform, inputs, stateWriteInstant, nextFlushInstantFn, stateSink);
}

@SuppressWarnings("unchecked")
@@ -202,6 +214,8 @@ public Map<PCollection<?>, ReplacementOutput> mapOutputs(Map outputs, POutput ne
private static PTransformReplacement<PInput, POutput> replaceParMultiDo(
AppliedPTransform<PInput, POutput, ?> transform,
PCollection<KV<String, StateValue>> inputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {

ParDo.MultiOutput<PInput, POutput> rawTransform =
@@ -228,6 +242,8 @@ private static PTransformReplacement<PInput, POutput> replaceParMultiDo(
transform.getOutputs().keySet().stream()
.filter(t -> !t.equals(mainOutputTag))
.collect(Collectors.toList())),
stateWriteInstant,
nextFlushInstantFn,
stateSink));
}

@@ -239,6 +255,8 @@ PTransform<PCollection<InputT>, PCollectionTuple> transformedParDo(
DoFn<KV<K, V>, OutputT> doFn,
TupleTag<OutputT> mainOutputTag,
TupleTagList otherOutputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {

return new PTransform<>() {
@@ -276,7 +294,13 @@ public PCollectionTuple expand(PCollection<InputT> input) {
PCollectionList.of(state).and(inputs).apply(Flatten.pCollections());
PCollectionTuple tuple =
flattened.apply(
ParDo.of(transformedDoFn(doFn, (KvCoder<K, V>) input.getCoder(), mainOutputTag))
ParDo.of(
transformedDoFn(
doFn,
(KvCoder<K, V>) input.getCoder(),
mainOutputTag,
stateWriteInstant,
nextFlushInstantFn))
.withOutputTags(mainOutputTag, otherOutputs.and(stateValueTupleTag)));
PCollection<StateValue> stateValuePCollection = tuple.get(stateValueTupleTag);
stateValuePCollection.apply(WithKeys.of(transformName)).apply(stateSink);
@@ -295,7 +319,11 @@ public PCollectionTuple expand(PCollection<InputT> input) {
@VisibleForTesting
static <K, V, InputT extends KV<K, StateOrInput<V>>, OutputT>
DoFn<InputT, OutputT> transformedDoFn(
DoFn<KV<K, V>, OutputT> doFn, KvCoder<K, V> inputCoder, TupleTag<OutputT> mainTag) {
DoFn<KV<K, V>, OutputT> doFn,
KvCoder<K, V> inputCoder,
TupleTag<OutputT> mainTag,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn) {

@SuppressWarnings("unchecked")
Class<? extends DoFn<KV<K, V>, OutputT>> doFnClass =
@@ -357,7 +385,14 @@ DoFn<InputT, OutputT> transformedDoFn(

builder =
addProcessingMethods(
doFn, inputType, inputCoder.getKeyCoder(), mainTag, outputType, builder);
doFn,
inputType,
inputCoder.getKeyCoder(),
mainTag,
outputType,
stateWriteInstant,
nextFlushInstantFn,
builder);
Unloaded<DoFn<InputT, OutputT>> dynamicClass = builder.make();
// FIXME
ExceptionUtils.unchecked(() -> dynamicClass.saveIn(new File("/tmp/dynamic-debug")));
@@ -419,11 +454,14 @@ Builder<DoFn<InputT, OutputT>> addProcessingMethods(
Coder<K> keyCoder,
TupleTag<OutputT> mainTag,
Type outputType,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
Builder<DoFn<InputT, OutputT>> builder) {

builder = addProcessingMethod(doFn, DoFn.Setup.class, builder);
builder = addProcessingMethod(doFn, DoFn.StartBundle.class, builder);
builder = addProcessElementMethod(doFn, inputType, mainTag, outputType, builder);
builder =
addProcessElementMethod(doFn, inputType, mainTag, outputType, stateWriteInstant, builder);
builder = addProcessingMethod(doFn, DoFn.FinishBundle.class, builder);
builder = addProcessingMethod(doFn, DoFn.Teardown.class, builder);
builder = addOnWindowExpirationMethod(doFn, inputType, mainTag, builder);
@@ -435,9 +473,9 @@ Builder<DoFn<InputT, OutputT>> addProcessingMethods(
builder = addProcessingMethod(doFn, DoFn.NewWatermarkEstimator.class, builder);
builder = addProcessingMethod(doFn, DoFn.NewTracker.class, builder);
builder = addProcessingMethod(doFn, DoFn.OnTimer.class, builder);
builder = addTimerFlushMethod(doFn, inputType, keyCoder, stateValueTupleTag, builder);

// FIXME: timer callbacks
builder =
addTimerFlushMethod(
doFn, inputType, keyCoder, stateValueTupleTag, nextFlushInstantFn, builder);
return builder;
}

@@ -447,13 +485,15 @@ Builder<DoFn<InputT, OutputT>> addProcessElementMethod(
ParameterizedType inputType,
TupleTag<OutputT> mainTag,
Type outputType,
Instant stateWriteInstant,
Builder<DoFn<InputT, OutputT>> builder) {

Class<? extends Annotation> annotation = ProcessElement.class;
Method method = findMethod(doFn, annotation);
if (method != null) {
ProcessElementParameterExpander expander =
ProcessElementParameterExpander.of(doFn, method, inputType, mainTag, outputType);
ProcessElementParameterExpander.of(
doFn, method, inputType, mainTag, outputType, stateWriteInstant);
List<Pair<AnnotationDescription, TypeDefinition>> wrapperArgs = expander.getWrapperArgs();
MethodDefinition<DoFn<InputT, OutputT>> methodDefinition =
builder
@@ -522,7 +562,8 @@ Builder<DoFn<InputT, OutputT>> addTimerFlushMethod(
DoFn<KV<K, V>, OutputT> doFn,
ParameterizedType inputType,
Coder<K> keyCoder,
TupleTag stateTag,
TupleTag<StateValue> stateTag,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
Builder<DoFn<InputT, OutputT>> builder) {

List<Pair<Annotation, Type>> states =
@@ -535,17 +576,25 @@ Builder<DoFn<InputT, OutputT>> addTimerFlushMethod(
.filter(p -> p.getFirst() != null)
.collect(Collectors.toList());

List<Type> types = states.stream().map(Pair::getSecond).collect(Collectors.toList());
List<TypeDefinition> types =
states.stream()
.map(Pair::getSecond)
.map(t -> TypeDescription.Generic.Builder.of(t).build())
.collect(Collectors.toList());
// add parameter for key
types.add(inputType.getActualTypeArguments()[0]);
types.add(Timer.class);
types.add(DoFn.MultiOutputReceiver.class);
types.add(TypeDescription.Generic.Builder.of(inputType.getActualTypeArguments()[0]).build());
types.add(TypeDescription.ForLoadedType.of(Timer.class));
types.add(
TypeDescription.Generic.Builder.parameterizedType(ValueState.class, Boolean.class).build());
types.add(TypeDescription.ForLoadedType.of(DoFn.MultiOutputReceiver.class));

MethodDefinition<DoFn<InputT, OutputT>> methodDefinition =
builder
.defineMethod("expanderFlushTimer", void.class, Visibility.PUBLIC)
.withParameters(types)
.intercept(MethodDelegation.to(new TimerFlushInterceptor<>(doFn, keyCoder, stateTag)));
.intercept(
MethodDelegation.to(
new TimerFlushInterceptor<>(doFn, keyCoder, stateTag, nextFlushInstantFn)));

// retrieve parameter annotations and apply them
for (int i = 0; i < states.size(); i++) {
@@ -563,6 +612,12 @@ Builder<DoFn<InputT, OutputT>> addTimerFlushMethod(
AnnotationDescription.Builder.ofType(DoFn.TimerId.class)
.define("value", EXPANDER_TIMER_NAME)
.build());
methodDefinition =
methodDefinition.annotateParameter(
states.size() + 2,
AnnotationDescription.Builder.ofType(DoFn.StateId.class)
.define("value", EXPANDER_FINISHED_STATE_NAME)
.build());
return methodDefinition.annotateMethod(
AnnotationDescription.Builder.ofType(DoFn.OnTimer.class)
.define("value", EXPANDER_TIMER_NAME)
@@ -793,27 +848,40 @@ private static class TimerFlushInterceptor<K, V> {
stateReaders;
private final Coder<K> keyCoder;
private final TupleTag<StateValue> stateTag;
private final UnaryFunction<Instant, Instant> nextFlushInstantFn;

TimerFlushInterceptor(
DoFn<KV<K, V>, ?> doFn, Coder<K> keyCoder, TupleTag<StateValue> stateTag) {
DoFn<KV<K, V>, ?> doFn,
Coder<K> keyCoder,
TupleTag<StateValue> stateTag,
UnaryFunction<Instant, Instant> nextFlushInstantFn) {

this.stateReaders = getStateReaders(doFn);
this.keyCoder = keyCoder;
this.stateTag = stateTag;
this.nextFlushInstantFn = nextFlushInstantFn;
}

@RuntimeType
public void intercept(@This DoFn<KV<V, StateOrInput<V>>, ?> doFn, @AllArguments Object[] args) {
@SuppressWarnings("unchecked")
K key = (K) args[args.length - 3];
MultiOutputReceiver outputReceiver = (MultiOutputReceiver) args[args.length - 1];
OutputReceiver<StateValue> output = outputReceiver.get(stateTag);
byte[] keyBytes =
ExceptionUtils.uncheckedFactory(() -> CoderUtils.encodeToByteArray(keyCoder, key));
int i = 0;
for (BiFunction<Object, byte[], Iterable<StateValue>> f : stateReaders.values()) {
Object accessor = args[i++];
System.err.println(" *** reading from " + accessor);
f.apply(accessor, keyBytes).forEach(output::output);
K key = (K) args[args.length - 4];
@SuppressWarnings("unchecked")
ValueState<Boolean> finishedState = (ValueState<Boolean>) args[args.length - 2];
if (finishedState.read() == null) {
finishedState.write(true);
} else {
Timer flushTimer = (Timer) args[args.length - 3];
flushTimer.set(nextFlushInstantFn.apply(flushTimer.getCurrentRelativeTime()));
MultiOutputReceiver outputReceiver = (MultiOutputReceiver) args[args.length - 1];
OutputReceiver<StateValue> output = outputReceiver.get(stateTag);
byte[] keyBytes =
ExceptionUtils.uncheckedFactory(() -> CoderUtils.encodeToByteArray(keyCoder, key));
int i = 0;
for (BiFunction<Object, byte[], Iterable<StateValue>> f : stateReaders.values()) {
Object accessor = args[i++];
f.apply(accessor, keyBytes).forEach(output::output);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
import cz.o2.proxima.core.functional.BiConsumer;
import cz.o2.proxima.core.functional.UnaryFunction;
import cz.o2.proxima.core.util.Pair;
import cz.o2.proxima.internal.com.google.common.base.MoreObjects;
import cz.o2.proxima.internal.com.google.common.base.Preconditions;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
@@ -59,7 +60,8 @@ static ProcessElementParameterExpander of(
Method processElement,
ParameterizedType inputType,
TupleTag<?> mainTag,
Type outputType) {
Type outputType,
Instant stateWriteInstant) {

final LinkedHashMap<TypeId, Pair<Annotation, Type>> processArgs = extractArgs(processElement);
final LinkedHashMap<TypeId, Pair<AnnotationDescription, TypeDefinition>> wrapperArgs =
@@ -80,7 +82,7 @@ public Object[] getProcessElementArgs(Object[] wrapperArgs) {

@Override
public UnaryFunction<Object[], Boolean> getProcessFn() {
return createProcessFn(wrapperArgs, doFn, processElement);
return createProcessFn(wrapperArgs, doFn, processElement, stateWriteInstant);
}
};
}
@@ -100,7 +102,8 @@ public UnaryFunction<Object[], Boolean> getProcessFn() {
private static UnaryFunction<Object[], Boolean> createProcessFn(
LinkedHashMap<TypeId, Pair<AnnotationDescription, TypeDefinition>> wrapperArgs,
DoFn<?, ?> doFn,
Method method) {
Method method,
Instant stateWriteInstant) {

int elementPos = findParameter(wrapperArgs.keySet(), TypeId::isElement);
Preconditions.checkState(elementPos >= 0, "Missing @Element annotation on method %s", method);
@@ -109,9 +112,9 @@ private static UnaryFunction<Object[], Boolean> createProcessFn(
@SuppressWarnings("unchecked")
KV<?, StateOrInput<?>> elem = (KV<?, StateOrInput<?>>) args[elementPos];
Timer flushTimer = (Timer) args[args.length - 4];
// FIXME: set for particular timestamp
System.err.println(" *** " + elem + ", " + flushTimer.getCurrentRelativeTime());
flushTimer.set(new Instant(0));
@SuppressWarnings("unchecked")
ValueState<Boolean> finishedState = (ValueState<Boolean>) args[args.length - 3];
flushTimer.set(stateWriteInstant);
boolean isState = Objects.requireNonNull(elem.getValue(), "elem").isState();
if (isState) {
StateValue state = elem.getValue().getState();
@@ -128,9 +131,7 @@ private static UnaryFunction<Object[], Boolean> createProcessFn(
updater.accept(stateAccessor, state);
return false;
}
// FIXME: read this from state
// FIXME: set to 'true' for most tests to work now
boolean shouldBuffer = false;
boolean shouldBuffer = !MoreObjects.firstNonNull(finishedState.read(), false);
if (shouldBuffer) {
// store to state
@SuppressWarnings("unchecked")
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
@@ -53,6 +54,7 @@
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.jetbrains.annotations.NotNull;
import org.joda.time.Instant;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -80,7 +82,11 @@ public void testSimpleExpand() {
PCollection<Long> count = withKeys.apply(ParDo.of(getSumFn()));
PAssert.that(count).containsInAnyOrder(2L, 4L);
ExternalStateExpander.expand(
pipeline, Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), dummy());
pipeline,
Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())),
new Instant(0),
ign -> BoundedWindow.TIMESTAMP_MAX_VALUE,
dummy());
pipeline.run();
}

@@ -99,7 +105,11 @@ public void testSimpleExpandMultiOutput() {
.get(mainTag);
PAssert.that(count).containsInAnyOrder(2L, 4L);
ExternalStateExpander.expand(
pipeline, Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), dummy());
pipeline,
Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())),
new Instant(0),
ign -> BoundedWindow.TIMESTAMP_MAX_VALUE,
dummy());
pipeline.run();
}

@@ -121,7 +131,11 @@ public PCollection<Long> expand(PCollection<String> input) {
PCollection<Long> count = inputs.apply(transform);
PAssert.that(count).containsInAnyOrder(2L, 4L);
ExternalStateExpander.expand(
pipeline, Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), dummy());
pipeline,
Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())),
new Instant(0),
ign -> BoundedWindow.TIMESTAMP_MAX_VALUE,
dummy());
pipeline.run();
}

@@ -153,6 +167,8 @@ public void testSimpleExpandWithInitialState() throws CoderException {
"sum",
CoderUtils.encodeToByteArray(longCoder, 1L))))
.withCoder(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())),
new Instant(0),
ign -> BoundedWindow.TIMESTAMP_MAX_VALUE,
dummy());
pipeline.run();
}
@@ -171,6 +187,9 @@ public void testSimpleExpandWithStateStore() throws CoderException {
ExternalStateExpander.expand(
pipeline,
Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())),
// FIXME
new Instant(0),
ign -> BoundedWindow.TIMESTAMP_MAX_VALUE,
collectStates(states));
pipeline.run();
assertEquals(1, states.size());
@@ -214,7 +233,6 @@ public void process(
@Element KV<Integer, String> element,
@StateId("sum") ValueState<Long> sum) {

System.err.println(" *** " + ignored + ", " + element);
Preconditions.checkArgument(ignored instanceof OutputReceiver);
long current = MoreObjects.firstNonNull(sum.read(), 0L);
sum.write(current + Integer.parseInt(element.getValue()));