diff --git a/.github/workflows/gradle.sh b/.github/workflows/gradle.sh index 1a8215a63..62e6e0c14 100755 --- a/.github/workflows/gradle.sh +++ b/.github/workflows/gradle.sh @@ -27,10 +27,15 @@ fi ./gradlew publishToMavenLocal -Pvendor -PnoSigning +GRADLE_BUILD_ARGS="" +if [[ ! -z $RUNNER_DEBUG ]]; then + GRADLE_BUILD_ARGS="--info" +fi + if [[ "${IS_PR}" != "false" ]] || [[ "${BRANCH}" == "master" ]]; then ./gradlew spotlessCheck \ - && ./gradlew build -x test \ - && ./gradlew test -Pwith-coverage \ + && ./gradlew build -x test ${GRADLE_BUILD_ARGS} \ + && ./gradlew test -Pwith-coverage ${GRADLE_BUILD_ARGS} \ && JAVA_HOME=${JAVA_HOME_17_X64} ./gradlew sonar --no-parallel exit $? fi diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java new file mode 100644 index 000000000..5f80db317 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java @@ -0,0 +1,963 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import static cz.o2.proxima.beam.util.state.MethodCallUtils.*; + +import cz.o2.proxima.core.functional.UnaryFunction; +import cz.o2.proxima.core.util.ExceptionUtils; +import cz.o2.proxima.core.util.Pair; +import cz.o2.proxima.internal.com.google.common.annotations.VisibleForTesting; +import cz.o2.proxima.internal.com.google.common.base.MoreObjects; +import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import cz.o2.proxima.internal.com.google.common.collect.Iterables; +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.annotation.AnnotationDescription; +import net.bytebuddy.description.modifier.FieldManifestation; +import net.bytebuddy.description.modifier.Visibility; +import net.bytebuddy.description.type.TypeDefinition; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.description.type.TypeDescription.Generic; +import net.bytebuddy.dynamic.DynamicType.Builder; +import net.bytebuddy.dynamic.DynamicType.Builder.MethodDefinition; +import net.bytebuddy.dynamic.DynamicType.Unloaded; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.FieldAccessor; +import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.implementation.Implementation.Composable; +import net.bytebuddy.implementation.MethodCall; +import net.bytebuddy.implementation.MethodDelegation; +import net.bytebuddy.implementation.bind.annotation.AllArguments; +import net.bytebuddy.implementation.bind.annotation.RuntimeType; +import net.bytebuddy.implementation.bind.annotation.This; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.PTransformOverride; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement; +import org.apache.beam.sdk.runners.TransformHierarchy; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.ProcessElement; +import org.apache.beam.sdk.transforms.DoFn.StateId; +import org.apache.beam.sdk.transforms.DoFn.TimerId; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.ByteBuddyUtils; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.construction.ReplacementOutputs; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TimestampedValue.TimestampedValueCoder; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.reflect.TypeToken; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; + +@Slf4j +public class ExternalStateExpander { + + static final String EXPANDER_BUF_STATE_SPEC = "expanderBufStateSpec"; + static final String EXPANDER_BUF_STATE_NAME = "_expanderBuf"; + static final String EXPANDER_FLUSH_STATE_SPEC = "expanderFlushStateSpec"; + static final String EXPANDER_FLUSH_STATE_NAME = "_expanderFlush"; + static final String EXPANDER_TIMER_SPEC = "expanderTimerSpec"; + static final String EXPANDER_TIMER_NAME = "_expanderTimer"; + static final String DELEGATE_FIELD_NAME = "delegate"; + + static final TupleTag STATE_TUPLE_TAG = new StateTupleTag() {}; + + /** + * Expand the given @{link Pipeline} to support external state store and restore + * + * @param pipeline the Pipeline to expand + * @param inputs transform to read inputs + * @param stateWriteInstant the instant at which write of the last state occurred + * @param nextFlushInstantFn function that returns instant of next flush from current time + * @param stateSink transform to store outputs + */ + public static Pipeline expand( + Pipeline pipeline, + PTransform>> inputs, + Instant stateWriteInstant, + UnaryFunction nextFlushInstantFn, + PTransform>, PDone> stateSink) { + + validatePipeline(pipeline); + pipeline.getCoderRegistry().registerCoderForClass(StateValue.class, StateValue.coder()); + PCollection> inputsMaterialized = pipeline.apply(inputs); + // replace all MultiParDos + pipeline.replaceAll( + Collections.singletonList( + statefulParMultiDoOverride(inputsMaterialized, stateWriteInstant, nextFlushInstantFn))); + // collect all StateValues + List>> stateValues = new ArrayList<>(); + pipeline.traverseTopologically( + new PipelineVisitor.Defaults() { + @Override + public void visitPrimitiveTransform(TransformHierarchy.Node node) { + if (node.getTransform() instanceof ParDo.MultiOutput) { + node.getOutputs().entrySet().stream() + .filter(e -> e.getKey() instanceof StateTupleTag) + .map(Entry::getValue) + .findAny() + .ifPresent( + p -> + stateValues.add( + Pair.of(node.getFullName(), (PCollection) p))); + } + } + }); + if (!stateValues.isEmpty()) { + PCollectionList> list = PCollectionList.empty(pipeline); + for (Pair> p : stateValues) { + PCollection> mapped = p.getSecond().apply(WithKeys.of(p.getFirst())); + list = list.and(mapped); + } + list.apply(Flatten.pCollections()).apply(stateSink); + } + + return pipeline; + } + + private static void validatePipeline(Pipeline pipeline) { + // check that all nodes have unique names + Set names = new HashSet<>(); + pipeline.traverseTopologically( + new PipelineVisitor.Defaults() { + + @Override + public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) { + Preconditions.checkState(names.add(node.getFullName())); + return CompositeBehavior.ENTER_TRANSFORM; + } + + @Override + public void visitPrimitiveTransform(TransformHierarchy.Node node) { + Preconditions.checkState(names.add(node.getFullName())); + } + }); + } + + private static PTransformOverride statefulParMultiDoOverride( + PCollection> inputs, + Instant stateWriteInstant, + UnaryFunction nextFlushInstantFn) { + + return PTransformOverride.of( + application -> application.getTransform() instanceof ParDo.MultiOutput, + parMultiDoReplacementFactory(inputs, stateWriteInstant, nextFlushInstantFn)); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static PTransformOverrideFactory parMultiDoReplacementFactory( + PCollection> inputs, + Instant stateWriteInstant, + UnaryFunction nextFlushInstantFn) { + + return new PTransformOverrideFactory<>() { + @Override + public PTransformReplacement getReplacementTransform(AppliedPTransform transform) { + return replaceParMultiDo(transform, inputs, stateWriteInstant, nextFlushInstantFn); + } + + @SuppressWarnings("unchecked") + @Override + public Map, ReplacementOutput> mapOutputs(Map outputs, POutput newOutput) { + return ReplacementOutputs.tagged(outputs, newOutput); + } + }; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static PTransformReplacement replaceParMultiDo( + AppliedPTransform transform, + PCollection> inputs, + Instant stateWriteInstant, + UnaryFunction nextFlushInstantFn) { + + ParDo.MultiOutput rawTransform = + (ParDo.MultiOutput) (PTransform) transform.getTransform(); + DoFn, ?> doFn = (DoFn) rawTransform.getFn(); + PInput pMainInput = getMainInput(transform); + if (!DoFnSignatures.isStateful(doFn)) { + return PTransformReplacement.of(pMainInput, (PTransform) transform.getTransform()); + } + String transformName = transform.getFullName(); + PCollection transformInputs = + inputs + .apply(Filter.by(kv -> kv.getKey().equals(transformName))) + .apply(MapElements.into(TypeDescriptor.of(StateValue.class)).via(KV::getValue)); + TupleTag mainOutputTag = rawTransform.getMainOutputTag(); + return PTransformReplacement.of( + pMainInput, + transformedParDo( + transformInputs, + (DoFn) doFn, + mainOutputTag, + TupleTagList.of( + transform.getOutputs().keySet().stream() + .filter(t -> !t.equals(mainOutputTag)) + .collect(Collectors.toList())), + stateWriteInstant, + nextFlushInstantFn)); + } + + @SuppressWarnings("unchecked") + private static , OutputT> + PTransform, PCollectionTuple> transformedParDo( + PCollection transformInputs, + DoFn, OutputT> doFn, + TupleTag mainOutputTag, + TupleTagList otherOutputs, + Instant stateWriteInstant, + UnaryFunction nextFlushInstantFn) { + + return new PTransform<>() { + @Override + public PCollectionTuple expand(PCollection input) { + @SuppressWarnings("unchecked") + KvCoder coder = (KvCoder) input.getCoder(); + Coder keyCoder = coder.getKeyCoder(); + Coder valueCoder = coder.getValueCoder(); + TypeDescriptor> valueDescriptor = + new TypeDescriptor<>(new TypeToken>() {}) {}; + PCollection>> state = + transformInputs + .apply( + MapElements.into( + TypeDescriptors.kvs( + keyCoder.getEncodedTypeDescriptor(), valueDescriptor)) + .via( + e -> + ExceptionUtils.uncheckedFactory( + () -> + KV.of( + CoderUtils.decodeFromByteArray(keyCoder, e.getKey()), + StateOrInput.state(e))))) + .setCoder(KvCoder.of(keyCoder, StateOrInput.coder(valueCoder))); + PCollection>> inputs = + input + .apply( + MapElements.into( + TypeDescriptors.kvs( + keyCoder.getEncodedTypeDescriptor(), valueDescriptor)) + .via(e -> KV.of(e.getKey(), StateOrInput.input(e.getValue())))) + .setCoder(KvCoder.of(keyCoder, StateOrInput.coder(valueCoder))); + PCollection>> flattened = + PCollectionList.of(state).and(inputs).apply(Flatten.pCollections()); + PCollectionTuple tuple = + flattened.apply( + ParDo.of( + transformedDoFn( + doFn, + (KvCoder) input.getCoder(), + mainOutputTag, + stateWriteInstant, + nextFlushInstantFn)) + .withOutputTags(mainOutputTag, otherOutputs.and(STATE_TUPLE_TAG))); + PCollectionTuple res = PCollectionTuple.empty(input.getPipeline()); + for (Entry, PCollection> e : + (Set, PCollection>>) (Set) tuple.getAll().entrySet()) { + if (!e.getKey().equals(STATE_TUPLE_TAG)) { + res = res.and(e.getKey(), e.getValue()); + } + } + return res; + } + }; + } + + @VisibleForTesting + static >, OutputT> + DoFn transformedDoFn( + DoFn, OutputT> doFn, + KvCoder inputCoder, + TupleTag mainTag, + Instant stateWriteInstant, + UnaryFunction nextFlushInstantFn) { + + @SuppressWarnings("unchecked") + Class, OutputT>> doFnClass = + (Class, OutputT>>) doFn.getClass(); + + ClassLoadingStrategy strategy = ByteBuddyUtils.getClassLoadingStrategy(doFnClass); + final String className = + doFnClass.getName() + + "$Expanded" + + (Objects.hash(stateWriteInstant, nextFlushInstantFn) & Integer.MAX_VALUE); + final ClassLoader classLoader = ExternalStateExpander.class.getClassLoader(); + try { + @SuppressWarnings("unchecked") + Class> aClass = + (Class>) classLoader.loadClass(className); + // class found, return instance + return ExceptionUtils.uncheckedFactory( + () -> aClass.getConstructor(doFnClass).newInstance(doFn)); + } catch (ClassNotFoundException e) { + // class not found, create it + } + + ByteBuddy buddy = new ByteBuddy(); + @SuppressWarnings("unchecked") + ParameterizedType parameterizedSuperClass = + getParameterizedDoFn((Class, OutputT>>) doFn.getClass()); + ParameterizedType inputType = + (ParameterizedType) parameterizedSuperClass.getActualTypeArguments()[0]; + Preconditions.checkArgument( + inputType.getRawType().equals(KV.class), + "Input type to stateful DoFn must be KV, go %s", + inputType); + + Type outputType = parameterizedSuperClass.getActualTypeArguments()[1]; + Generic wrapperInput = getWrapperInputType(inputType); + + Generic doFnGeneric = + Generic.Builder.parameterizedType( + TypeDescription.ForLoadedType.of(DoFn.class), + wrapperInput, + TypeDescription.Generic.Builder.of(outputType).build()) + .build(); + @SuppressWarnings("unchecked") + Builder> builder = + (Builder>) + buddy + .subclass(doFnGeneric) + .name(className) + .defineField(DELEGATE_FIELD_NAME, doFnClass, Visibility.PRIVATE); + builder = addStateAndTimers(doFnClass, inputType, builder); + builder = + builder + .defineConstructor(Visibility.PUBLIC) + .withParameters(doFnClass) + .intercept( + addStateAndTimerValues( + doFn, + inputCoder, + MethodCall.invoke( + ExceptionUtils.uncheckedFactory(() -> DoFn.class.getConstructor())) + .andThen(FieldAccessor.ofField(DELEGATE_FIELD_NAME).setsArgumentAt(0)))); + + builder = + addProcessingMethods( + doFn, + inputType, + inputCoder.getKeyCoder(), + mainTag, + outputType, + stateWriteInstant, + nextFlushInstantFn, + buddy, + builder); + Unloaded> dynamicClass = builder.make(); + return ExceptionUtils.uncheckedFactory( + () -> + dynamicClass + .load(null, strategy) + .getLoaded() + .getDeclaredConstructor(doFnClass) + .newInstance(doFn)); + } + + private static , OutputT> Implementation addStateAndTimerValues( + DoFn doFn, Coder> inputCoder, Composable delegate) { + + List> acceptable = Arrays.asList(StateId.class, TimerId.class); + @SuppressWarnings("unchecked") + Class> doFnClass = + (Class>) doFn.getClass(); + for (Field f : doFnClass.getDeclaredFields()) { + if (!Modifier.isStatic(f.getModifiers()) + && acceptable.stream().anyMatch(a -> f.getAnnotation(a) != null)) { + f.setAccessible(true); + Object value = ExceptionUtils.uncheckedFactory(() -> f.get(doFn)); + delegate = delegate.andThen(FieldAccessor.ofField(f.getName()).setsValue(value)); + } + } + delegate = + delegate + .andThen( + FieldAccessor.ofField(EXPANDER_BUF_STATE_SPEC) + .setsValue(StateSpecs.bag(TimestampedValueCoder.of(inputCoder)))) + .andThen(FieldAccessor.ofField(EXPANDER_FLUSH_STATE_SPEC).setsValue(StateSpecs.value())) + .andThen( + FieldAccessor.ofField(EXPANDER_TIMER_SPEC) + .setsValue(TimerSpecs.timer(TimeDomain.EVENT_TIME))); + return delegate; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static ParameterizedType getParameterizedDoFn( + Class> doFnClass) { + + Type type = doFnClass.getGenericSuperclass(); + if (type instanceof ParameterizedType) { + return (ParameterizedType) type; + } + if (doFnClass.getSuperclass().isAssignableFrom(DoFn.class)) { + return getParameterizedDoFn((Class) doFnClass.getGenericSuperclass()); + } + throw new IllegalStateException("Cannot get parameterized type of " + doFnClass); + } + + private static >, OutputT> + Builder> addProcessingMethods( + DoFn, OutputT> doFn, + ParameterizedType inputType, + Coder keyCoder, + TupleTag mainTag, + Type outputType, + Instant stateWriteInstant, + UnaryFunction nextFlushInstantFn, + ByteBuddy buddy, + Builder> builder) { + + builder = addProcessingMethod(doFn, DoFn.Setup.class, builder); + builder = addProcessingMethod(doFn, DoFn.StartBundle.class, builder); + builder = + addProcessElementMethod( + doFn, inputType, mainTag, outputType, stateWriteInstant, buddy, builder); + builder = addProcessingMethod(doFn, DoFn.FinishBundle.class, builder); + builder = addProcessingMethod(doFn, DoFn.Teardown.class, builder); + builder = addOnWindowExpirationMethod(doFn, inputType, mainTag, buddy, builder); + builder = addProcessingMethod(doFn, DoFn.GetInitialRestriction.class, builder); + builder = addProcessingMethod(doFn, DoFn.SplitRestriction.class, builder); + builder = addProcessingMethod(doFn, DoFn.GetRestrictionCoder.class, builder); + builder = addProcessingMethod(doFn, DoFn.GetWatermarkEstimatorStateCoder.class, builder); + builder = addProcessingMethod(doFn, DoFn.GetInitialWatermarkEstimatorState.class, builder); + builder = addProcessingMethod(doFn, DoFn.NewWatermarkEstimator.class, builder); + builder = addProcessingMethod(doFn, DoFn.NewTracker.class, builder); + builder = addProcessingMethod(doFn, DoFn.OnTimer.class, builder); + builder = + addTimerFlushMethod( + doFn, + inputType, + keyCoder, + mainTag, + STATE_TUPLE_TAG, + outputType, + nextFlushInstantFn, + buddy, + builder); + return builder; + } + + private static >, OutputT> + Builder> addProcessElementMethod( + DoFn, OutputT> doFn, + ParameterizedType inputType, + TupleTag mainTag, + Type outputType, + Instant stateWriteInstant, + ByteBuddy buddy, + Builder> builder) { + + Class annotation = ProcessElement.class; + Method method = findMethod(doFn, annotation); + if (method != null) { + ProcessElementParameterExpander expander = + ProcessElementParameterExpander.of( + doFn, method, inputType, mainTag, outputType, stateWriteInstant); + List> wrapperArgs = expander.getWrapperArgs(); + MethodDefinition> methodDefinition = + builder + .defineMethod(method.getName(), method.getReturnType(), Visibility.PUBLIC) + .withParameters( + wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) + .intercept( + MethodDelegation.to( + new ProcessElementInterceptor<>(doFn, expander, method, buddy))); + + for (int i = 0; i < wrapperArgs.size(); i++) { + Pair arg = wrapperArgs.get(i); + if (arg.getFirst() != null) { + methodDefinition = methodDefinition.annotateParameter(i, arg.getFirst()); + } + } + return methodDefinition.annotateMethod( + AnnotationDescription.Builder.ofType(annotation).build()); + } + return builder; + } + + private static >, OutputT> + Builder> addOnWindowExpirationMethod( + DoFn, OutputT> doFn, + ParameterizedType inputType, + TupleTag mainTag, + ByteBuddy buddy, + Builder> builder) { + + Class annotation = DoFn.OnWindowExpiration.class; + @Nullable Method onWindowExpirationMethod = findMethod(doFn, annotation); + Method processElementMethod = findMethod(doFn, DoFn.ProcessElement.class); + Type outputType = doFn.getOutputTypeDescriptor().getType(); + if (processElementMethod != null) { + OnWindowParameterExpander expander = + OnWindowParameterExpander.of( + inputType, processElementMethod, onWindowExpirationMethod, mainTag, outputType); + List> wrapperArgs = expander.getWrapperArgs(); + MethodDefinition> methodDefinition = + builder + .defineMethod( + onWindowExpirationMethod == null + ? "_onWindowExpiration" + : onWindowExpirationMethod.getName(), + void.class, + Visibility.PUBLIC) + .withParameters( + wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) + .intercept( + MethodDelegation.to( + new OnWindowExpirationInterceptor<>( + doFn, processElementMethod, onWindowExpirationMethod, expander, buddy))); + + // retrieve parameter annotations and apply them + for (int i = 0; i < wrapperArgs.size(); i++) { + AnnotationDescription ann = wrapperArgs.get(i).getFirst(); + if (ann != null) { + methodDefinition = methodDefinition.annotateParameter(i, ann); + } + } + return methodDefinition.annotateMethod( + AnnotationDescription.Builder.ofType(annotation).build()); + } + return builder; + } + + private static >> + Builder> addTimerFlushMethod( + DoFn, OutputT> doFn, + ParameterizedType inputType, + Coder keyCoder, + TupleTag mainTag, + TupleTag stateTag, + Type outputType, + UnaryFunction nextFlushInstantFn, + ByteBuddy buddy, + Builder> builder) { + + Method processElement = findMethod(doFn, ProcessElement.class); + FlushTimerParameterExpander expander = + FlushTimerParameterExpander.of(doFn, inputType, processElement, mainTag, outputType); + List> wrapperArgs = expander.getWrapperArgs(); + MethodDefinition> methodDefinition = + builder + .defineMethod("expanderFlushTimer", void.class, Visibility.PUBLIC) + .withParameters(wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) + .intercept( + MethodDelegation.to( + new FlushTimerInterceptor<>( + doFn, + processElement, + expander, + keyCoder, + stateTag, + nextFlushInstantFn, + buddy))); + int i = 0; + for (Pair p : wrapperArgs) { + if (p.getFirst() != null) { + methodDefinition = methodDefinition.annotateParameter(i, p.getFirst()); + } + i++; + } + return methodDefinition.annotateMethod( + AnnotationDescription.Builder.ofType(DoFn.OnTimer.class) + .define("value", EXPANDER_TIMER_NAME) + .build()); + } + + private static Method findMethod( + DoFn, OutputT> doFn, Class annotation) { + + return Iterables.getOnlyElement( + Arrays.stream(doFn.getClass().getMethods()) + .filter(m -> m.getAnnotation(annotation) != null) + .collect(Collectors.toList()), + null); + } + + private static >, OutputT, T extends Annotation> + Builder> addProcessingMethod( + DoFn, OutputT> doFn, + Class annotation, + Builder> builder) { + + Method method = findMethod(doFn, annotation); + if (method != null) { + MethodDefinition> methodDefinition = + builder + .defineMethod(method.getName(), method.getReturnType(), Visibility.PUBLIC) + .withParameters(method.getGenericParameterTypes()) + .intercept(MethodCall.invoke(method).onField(DELEGATE_FIELD_NAME).withAllArguments()); + + // retrieve parameter annotations and apply them + Annotation[][] parameterAnnotations = method.getParameterAnnotations(); + for (int i = 0; i < parameterAnnotations.length; i++) { + for (Annotation paramAnnotation : parameterAnnotations[i]) { + methodDefinition = methodDefinition.annotateParameter(i, paramAnnotation); + } + } + return methodDefinition.annotateMethod( + AnnotationDescription.Builder.ofType(annotation).build()); + } + return builder; + } + + private static >, OutputT> + Builder> addStateAndTimers( + Class, OutputT>> doFnClass, + ParameterizedType inputType, + Builder> builder) { + + builder = cloneFields(doFnClass, StateId.class, builder); + builder = cloneFields(doFnClass, TimerId.class, builder); + builder = addBufferingStatesAndTimer(inputType, builder); + return builder; + } + + /** Add state that buffers inputs until we process all state updates. */ + private static >, OutputT> + Builder> addBufferingStatesAndTimer( + ParameterizedType inputType, Builder> builder) { + + // type: StateSpec>> + Generic bufStateSpecFieldType = + Generic.Builder.parameterizedType( + TypeDescription.ForLoadedType.of(StateSpec.class), bagStateFromInputType(inputType)) + .build(); + // type: StateSpec> + Generic finishedStateSpecFieldType = + Generic.Builder.parameterizedType( + TypeDescription.ForLoadedType.of(StateSpec.class), + Generic.Builder.parameterizedType(ValueState.class, Instant.class).build()) + .build(); + + Generic timerSpecFieldType = Generic.Builder.of(TimerSpec.class).build(); + + builder = + defineStateField( + builder, + bufStateSpecFieldType, + DoFn.StateId.class, + EXPANDER_BUF_STATE_SPEC, + EXPANDER_BUF_STATE_NAME); + builder = + defineStateField( + builder, + finishedStateSpecFieldType, + DoFn.StateId.class, + EXPANDER_FLUSH_STATE_SPEC, + EXPANDER_FLUSH_STATE_NAME); + builder = + defineStateField( + builder, + timerSpecFieldType, + DoFn.TimerId.class, + EXPANDER_TIMER_SPEC, + EXPANDER_TIMER_NAME); + + return builder; + } + + private static >, OutputT> + Builder> defineStateField( + Builder> builder, + Generic stateSpecFieldType, + Class annotation, + String fieldName, + String name) { + + return builder + .defineField( + fieldName, + stateSpecFieldType, + Visibility.PUBLIC.getMask() + FieldManifestation.FINAL.getMask()) + .annotateField( + AnnotationDescription.Builder.ofType(annotation).define("value", name).build()); + } + + private static >, OutputT, T extends Annotation> + Builder> cloneFields( + Class, OutputT>> doFnClass, + Class annotationClass, + Builder> builder) { + + for (Field f : doFnClass.getDeclaredFields()) { + if (!Modifier.isStatic(f.getModifiers()) && f.getAnnotation(annotationClass) != null) { + builder = + builder + .defineField(f.getName(), f.getGenericType(), f.getModifiers()) + .annotateField(f.getDeclaredAnnotations()); + } + } + return builder; + } + + private static PInput getMainInput(AppliedPTransform transform) { + Map, PCollection> mainInputs = transform.getMainInputs(); + if (mainInputs.size() == 1) { + return Iterables.getOnlyElement(mainInputs.values()); + } + return asTuple(mainInputs); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static @NonNull PInput asTuple(Map, PCollection> mainInputs) { + Preconditions.checkArgument(!mainInputs.isEmpty()); + PCollectionTuple res = null; + for (Map.Entry, PCollection> e : mainInputs.entrySet()) { + if (res == null) { + res = PCollectionTuple.of((TupleTag) e.getKey(), e.getValue()); + } else { + res = res.and((TupleTag) e.getKey(), e.getValue()); + } + } + return Objects.requireNonNull(res); + } + + private static class ProcessElementInterceptor { + + private final DoFn, ?> doFn; + private final ProcessElementParameterExpander expander; + private final UnaryFunction processFn; + private final VoidMethodInvoker invoker; + + ProcessElementInterceptor( + DoFn, ?> doFn, + ProcessElementParameterExpander expander, + Method process, + ByteBuddy buddy) { + + this.doFn = doFn; + this.expander = expander; + this.processFn = expander.getProcessFn(); + this.invoker = ExceptionUtils.uncheckedFactory(() -> VoidMethodInvoker.of(process, buddy)); + } + + @RuntimeType + public void intercept( + @This DoFn>, ?> proxy, @AllArguments Object[] allArgs) { + + if (processFn.apply(allArgs)) { + Object[] methodArgs = expander.getProcessElementArgs(allArgs); + ExceptionUtils.unchecked(() -> invoker.invoke(doFn, methodArgs)); + } + } + } + + private static class OnWindowExpirationInterceptor { + private final DoFn, ?> doFn; + private final VoidMethodInvoker, ?>> processElement; + private final @Nullable VoidMethodInvoker, ?>> onWindowExpiration; + private final OnWindowParameterExpander expander; + + public OnWindowExpirationInterceptor( + DoFn, ?> doFn, + Method processElementMethod, + @Nullable Method onWindowExpirationMethod, + OnWindowParameterExpander expander, + ByteBuddy buddy) { + + this.doFn = doFn; + this.processElement = + ExceptionUtils.uncheckedFactory(() -> VoidMethodInvoker.of(processElementMethod, buddy)); + this.onWindowExpiration = + onWindowExpirationMethod == null + ? null + : ExceptionUtils.uncheckedFactory( + () -> VoidMethodInvoker.of(onWindowExpirationMethod, buddy)); + this.expander = expander; + } + + @RuntimeType + public void intercept( + @This DoFn>, ?> proxy, @AllArguments Object[] allArgs) { + + @SuppressWarnings("unchecked") + BagState>> buf = + (BagState>>) allArgs[allArgs.length - 1]; + Iterable>> buffered = buf.read(); + // feed all data to @ProcessElement + for (TimestampedValue> kv : buffered) { + ExceptionUtils.unchecked( + () -> processElement.invoke(doFn, expander.getProcessElementArgs(kv, allArgs))); + } + // invoke onWindowExpiration + if (onWindowExpiration != null) { + ExceptionUtils.unchecked( + () -> onWindowExpiration.invoke(doFn, expander.getOnWindowExpirationArgs(allArgs))); + } + } + } + + private static class FlushTimerInterceptor { + + private final DoFn, ?> doFn; + private final LinkedHashMap>> + stateReaders; + private final VoidMethodInvoker processElementMethod; + private final FlushTimerParameterExpander expander; + private final Coder keyCoder; + private final TupleTag stateTag; + private final UnaryFunction nextFlushInstantFn; + + FlushTimerInterceptor( + DoFn, ?> doFn, + Method processElementMethod, + FlushTimerParameterExpander expander, + Coder keyCoder, + TupleTag stateTag, + UnaryFunction nextFlushInstantFn, + ByteBuddy buddy) { + + this.doFn = doFn; + this.stateReaders = getStateReaders(doFn); + this.processElementMethod = + ExceptionUtils.uncheckedFactory(() -> VoidMethodInvoker.of(processElementMethod, buddy)); + this.expander = expander; + this.keyCoder = keyCoder; + this.stateTag = stateTag; + this.nextFlushInstantFn = nextFlushInstantFn; + } + + @RuntimeType + public void intercept(@This DoFn>, ?> doFn, @AllArguments Object[] args) { + Instant now = (Instant) args[args.length - 6]; + @SuppressWarnings("unchecked") + K key = (K) args[args.length - 5]; + @SuppressWarnings("unchecked") + ValueState nextFlushState = (ValueState) args[args.length - 3]; + Timer flushTimer = (Timer) args[args.length - 4]; + Instant nextFlush = nextFlushInstantFn.apply(now); + Instant lastFlush = nextFlushState.read(); + boolean isNextScheduled = + nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE); + if (isNextScheduled) { + flushTimer.set(nextFlush); + nextFlushState.write(nextFlush); + } + @SuppressWarnings("unchecked") + BagState>> bufState = + (BagState>>) args[args.length - 2]; + @SuppressWarnings({"unchecked", "rawtypes"}) + List>> pushedBackElements = + processBuffer(args, (Iterable) bufState.read(), MoreObjects.firstNonNull(lastFlush, now)); + bufState.clear(); + // if we have already processed state data + if (lastFlush != null) { + MultiOutputReceiver outputReceiver = (MultiOutputReceiver) args[args.length - 1]; + OutputReceiver output = outputReceiver.get(stateTag); + byte[] keyBytes = + ExceptionUtils.uncheckedFactory(() -> CoderUtils.encodeToByteArray(keyCoder, key)); + int i = 0; + for (BiFunction> f : stateReaders.values()) { + Object accessor = args[i++]; + Iterable values = f.apply(accessor, keyBytes); + values.forEach(output::output); + } + } + @SuppressWarnings({"unchecked", "rawtypes"}) + List>> remaining = + processBuffer( + args, + (List) pushedBackElements, + MoreObjects.firstNonNull(nextFlush, BoundedWindow.TIMESTAMP_MAX_VALUE)); + remaining.forEach(bufState::add); + } + + private List>> processBuffer( + Object[] args, Iterable>> buffer, Instant maxTs) { + + List>> pushedBackElements = new ArrayList<>(); + buffer.forEach( + kv -> { + if (kv.getTimestamp().isBefore(maxTs)) { + Object[] processArgs = expander.getProcessElementArgs(kv, args); + ExceptionUtils.unchecked(() -> processElementMethod.invoke(this.doFn, processArgs)); + } else { + // return back to buffer + log.debug("Returning element {} to flush buffer", kv); + pushedBackElements.add(kv); + } + }); + @SuppressWarnings({"unchecked", "rawtypes"}) + List>> res = (List) pushedBackElements; + return res; + } + } + + static Generic bagStateFromInputType(ParameterizedType inputType) { + return Generic.Builder.parameterizedType( + TypeDescription.ForLoadedType.of(BagState.class), + Generic.Builder.parameterizedType(TimestampedValue.class, inputType).build()) + .build(); + } + + private static class StateTupleTag extends TupleTag {} + + // do not construct + private ExternalStateExpander() {} +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java new file mode 100644 index 000000000..3d81293cd --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java @@ -0,0 +1,150 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import static cz.o2.proxima.beam.util.state.ExternalStateExpander.*; +import static cz.o2.proxima.beam.util.state.MethodCallUtils.*; + +import cz.o2.proxima.core.util.Pair; +import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.stream.Collectors; +import net.bytebuddy.description.annotation.AnnotationDescription; +import net.bytebuddy.description.type.TypeDefinition; +import net.bytebuddy.description.type.TypeDescription; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.joda.time.Instant; + +interface FlushTimerParameterExpander { + + static FlushTimerParameterExpander of( + DoFn doFn, + ParameterizedType inputType, + Method processElement, + TupleTag mainTag, + Type outputType) { + + final LinkedHashMap> processArgs = extractArgs(processElement); + final LinkedHashMap> wrapperArgs = + createWrapperArgs(doFn, inputType); + final List>, Object>> + processArgsGenerators = projectArgs(wrapperArgs, processArgs, mainTag, outputType); + + return new FlushTimerParameterExpander() { + @Override + public List> getWrapperArgs() { + return new ArrayList<>(wrapperArgs.values()); + } + + @Override + public Object[] getProcessElementArgs( + TimestampedValue> input, Object[] wrapperArgs) { + return fromGenerators(input, processArgsGenerators, wrapperArgs); + } + }; + } + + private static LinkedHashMap> + createWrapperArgs(DoFn doFn, ParameterizedType inputType) { + + List> states = + Arrays.stream(doFn.getClass().getDeclaredFields()) + .filter(f -> f.getAnnotation(DoFn.StateId.class) != null) + .map( + f -> { + Preconditions.checkArgument( + f.getGenericType() instanceof ParameterizedType, + "Field %s has invalid type %s", + f.getName(), + f.getGenericType()); + return Pair.of( + (Annotation) f.getAnnotation(DoFn.StateId.class), + ((ParameterizedType) f.getGenericType()).getActualTypeArguments()[0]); + }) + .collect(Collectors.toList()); + + List> types = + states.stream() + .map( + p -> + Pair.of( + (AnnotationDescription) + AnnotationDescription.ForLoadedAnnotation.of(p.getFirst()), + (TypeDefinition) TypeDescription.Generic.Builder.of(p.getSecond()).build())) + .collect(Collectors.toList()); + // add parameter for timestamp, key, timer, state and output + types.add( + Pair.of( + AnnotationDescription.Builder.ofType(DoFn.Timestamp.class).build(), + TypeDescription.ForLoadedType.of(Instant.class))); + types.add( + Pair.of( + AnnotationDescription.Builder.ofType(DoFn.Key.class).build(), + TypeDescription.Generic.Builder.of(inputType.getActualTypeArguments()[0]).build())); + types.add( + Pair.of( + AnnotationDescription.Builder.ofType(DoFn.TimerId.class) + .define("value", EXPANDER_TIMER_NAME) + .build(), + TypeDescription.ForLoadedType.of(Timer.class))); + types.add( + Pair.of( + AnnotationDescription.Builder.ofType(DoFn.StateId.class) + .define("value", EXPANDER_FLUSH_STATE_NAME) + .build(), + TypeDescription.Generic.Builder.parameterizedType(ValueState.class, Instant.class) + .build())); + types.add( + Pair.of( + AnnotationDescription.Builder.ofType(DoFn.StateId.class) + .define("value", EXPANDER_BUF_STATE_NAME) + .build(), + bagStateFromInputType(inputType))); + types.add(Pair.of(null, TypeDescription.ForLoadedType.of(DoFn.MultiOutputReceiver.class))); + + LinkedHashMap> res = new LinkedHashMap<>(); + types.forEach( + p -> { + TypeId id = p.getFirst() == null ? TypeId.of(p.getSecond()) : TypeId.of(p.getFirst()); + res.put(id, p); + }); + return res; + } + + /** + * Get arguments that must be declared by wrapper's call for both {@code @}ProcessElement and + * {@code @}OnWindowExpiration be callable. + */ + List> getWrapperArgs(); + + /** + * Get parameters that should be passed to {@code @}ProcessElement from wrapper's + * {@code @}OnWindowExpiration + */ + Object[] getProcessElementArgs(TimestampedValue> input, Object[] wrapperArgs); +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java new file mode 100644 index 000000000..26daf5c57 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java @@ -0,0 +1,840 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import cz.o2.proxima.core.functional.BiConsumer; +import cz.o2.proxima.core.util.ExceptionUtils; +import cz.o2.proxima.core.util.Pair; +import cz.o2.proxima.internal.com.google.common.annotations.VisibleForTesting; +import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import cz.o2.proxima.internal.com.google.common.collect.Iterables; +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.annotation.AnnotationDescription; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.method.ParameterDescription; +import net.bytebuddy.description.modifier.Visibility; +import net.bytebuddy.description.type.TypeDefinition; +import net.bytebuddy.description.type.TypeDescription.ForLoadedType; +import net.bytebuddy.description.type.TypeDescription.Generic; +import net.bytebuddy.description.type.TypeDescription.Generic.Builder; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.dynamic.scaffold.InstrumentedType; +import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.implementation.MethodCall; +import net.bytebuddy.implementation.MethodCall.ArgumentLoader; +import net.bytebuddy.implementation.MethodCall.ArgumentLoader.ArgumentProvider; +import net.bytebuddy.implementation.MethodCall.ArgumentLoader.Factory; +import net.bytebuddy.implementation.bytecode.StackManipulation; +import net.bytebuddy.implementation.bytecode.assign.Assigner; +import net.bytebuddy.implementation.bytecode.assign.Assigner.Typing; +import net.bytebuddy.implementation.bytecode.collection.ArrayAccess; +import net.bytebuddy.implementation.bytecode.constant.IntegerConstant; +import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.MultimapState; +import org.apache.beam.sdk.state.OrderedListState; +import org.apache.beam.sdk.state.SetState; +import org.apache.beam.sdk.state.StateBinder; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.ByteBuddyUtils; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; + +@Slf4j +class MethodCallUtils { + + static Object[] fromGenerators( + List>, Object>> generators, + Object[] wrapperArgs) { + + return fromGenerators(null, generators, wrapperArgs); + } + + static Object[] fromGenerators( + @Nullable TimestampedValue> elem, + List>, Object>> generators, + Object[] wrapperArgs) { + + Object[] res = new Object[generators.size()]; + for (int i = 0; i < generators.size(); i++) { + res[i] = generators.get(i).apply(wrapperArgs, elem); + } + return res; + } + + static List>, Object>> projectArgs( + LinkedHashMap> wrapperArgList, + LinkedHashMap> argsMap, + TupleTag mainTag, + Type outputType) { + + List>, Object>> res = + new ArrayList<>(argsMap.size()); + List wrapperParamsIds = + wrapperArgList.values().stream() + .map(p -> p.getFirst() != null ? p.getFirst().getAnnotationType() : p.getSecond()) + .collect(Collectors.toList()); + for (Map.Entry> e : argsMap.entrySet()) { + int wrapperArg = findArgIndex(wrapperArgList.keySet(), e.getKey()); + if (wrapperArg < 0) { + addUnknownWrapperArgument( + mainTag, outputType, e.getKey(), wrapperParamsIds, wrapperArgList.keySet(), res); + } else { + remapKnownArgument(outputType, e.getKey(), wrapperArg, res); + } + } + return res; + } + + private static void remapKnownArgument( + Type outputType, + TypeId typeId, + int wrapperArg, + List>, Object>> res) { + + if (typeId.isElement()) { + // this applies to @ProcessElement only, the input element holds KV + res.add( + (args, elem) -> + extractValue((TimestampedValue>>) args[wrapperArg])); + } else if (typeId.isTimestamp()) { + res.add((args, elem) -> elem == null ? args[wrapperArg] : elem.getTimestamp()); + } else if (typeId.isOutput(outputType)) { + remapOutputs(typeId, wrapperArg, res); + } else { + res.add((args, elem) -> args[wrapperArg]); + } + } + + private static void remapOutputs( + TypeId typeId, + int wrapperArg, + List>, Object>> res) { + if (typeId.isMultiOutput()) { + res.add( + (args, elem) -> + elem == null + ? args[wrapperArg] + : remapTimestampIfNeeded((MultiOutputReceiver) args[wrapperArg], elem)); + } else { + res.add( + (args, elem) -> { + if (elem == null) { + return args[wrapperArg]; + } + OutputReceiver parent = (OutputReceiver) args[wrapperArg]; + return new TimestampedOutputReceiver<>(parent, elem.getTimestamp()); + }); + } + } + + private static void addUnknownWrapperArgument( + TupleTag mainTag, + Type outputType, + TypeId typeId, + List wrapperParamsIds, + Iterable wrapperIds, + List>, Object>> res) { + + // the wrapper does not have the required argument + if (typeId.isElement()) { + // wrapper does not have @Element, we need to provide it from input element + res.add((args, elem) -> Objects.requireNonNull(elem.getValue())); + } else if (typeId.isTimestamp()) { + // wrapper does not have timestamp + res.add((args, elem) -> elem.getTimestamp()); + } else if (typeId.isOutput(outputType)) { + int wrapperPos = wrapperParamsIds.indexOf(ForLoadedType.of(MultiOutputReceiver.class)); + if (typeId.isMultiOutput()) { + // inject timestamp + res.add( + (args, elem) -> remapTimestampIfNeeded((MultiOutputReceiver) args[wrapperPos], elem)); + } else { + // remap MultiOutputReceiver to OutputReceiver + Preconditions.checkState(wrapperPos >= 0); + res.add( + (args, elem) -> singleOutput((MultiOutputReceiver) args[wrapperPos], elem, mainTag)); + } + } else { + throw new IllegalStateException( + String.format( + "Missing argument %s in wrapper. Available options are %s", typeId, wrapperIds)); + } + } + + private static DoFn.MultiOutputReceiver remapTimestampIfNeeded( + MultiOutputReceiver parent, @Nullable TimestampedValue> elem) { + + if (elem == null) { + return parent; + } + return new MultiOutputReceiver() { + @Override + public OutputReceiver get(TupleTag tag) { + OutputReceiver parentReceiver = parent.get(tag); + return new TimestampedOutputReceiver<>(parentReceiver, elem.getTimestamp()); + } + + @Override + public OutputReceiver getRowReceiver(TupleTag tag) { + OutputReceiver parentReceiver = parent.getRowReceiver(tag); + return new TimestampedOutputReceiver<>(parentReceiver, elem.getTimestamp()); + } + }; + } + + private static KV extractValue(TimestampedValue>> arg) { + Preconditions.checkArgument(!arg.getValue().getValue().isState()); + return KV.of(arg.getValue().getKey(), arg.getValue().getValue().getInput()); + } + + private static int findArgIndex(Collection collection, TypeId key) { + int i = 0; + for (TypeId t : collection) { + if (key.equals(t)) { + return i; + } + i++; + } + return -1; + } + + static LinkedHashMap> extractArgs(Method method) { + LinkedHashMap> res = new LinkedHashMap<>(); + if (method != null) { + for (int i = 0; i < method.getParameterCount(); i++) { + Type parameterType = method.getGenericParameterTypes()[i]; + verifyArg(parameterType); + Annotation[] annotations = method.getParameterAnnotations()[i]; + TypeId paramId = + annotations.length > 0 + ? TypeId.of(getSingleAnnotation(annotations)) + : TypeId.of(parameterType); + res.put(paramId, Pair.of(annotations.length == 0 ? null : annotations[0], parameterType)); + } + } + return res; + } + + private static void verifyArg(Type parameterType) { + Preconditions.checkArgument( + !(parameterType instanceof DoFn.ProcessContext), + "ProcessContext is not supported. Please use the new-style @Element, @Timestamp, etc."); + } + + static Annotation getSingleAnnotation(Annotation[] annotations) { + Preconditions.checkArgument(annotations.length == 1, Arrays.toString(annotations)); + return annotations[0]; + } + + static Generic getInputKvType(ParameterizedType inputType) { + Type keyType = inputType.getActualTypeArguments()[0]; + Type valueType = inputType.getActualTypeArguments()[1]; + + // generic type: KV + return Generic.Builder.parameterizedType(KV.class, keyType, valueType).build(); + } + + private static OutputReceiver singleOutput( + MultiOutputReceiver multiOutput, + @Nullable TimestampedValue> elem, + TupleTag mainTag) { + + return new OutputReceiver() { + @Override + public void output(T output) { + if (elem == null) { + multiOutput.get(mainTag).output(output); + } else { + multiOutput.get(mainTag).outputWithTimestamp(output, elem.getTimestamp()); + } + } + + @Override + public void outputWithTimestamp(T output, Instant timestamp) { + multiOutput.get(mainTag).outputWithTimestamp(output, timestamp); + } + }; + } + + static Generic getWrapperInputType(ParameterizedType inputType) { + Type kType = inputType.getActualTypeArguments()[0]; + Type vType = inputType.getActualTypeArguments()[1]; + return Generic.Builder.parameterizedType( + ForLoadedType.of(KV.class), + Generic.Builder.of(kType).build(), + Generic.Builder.parameterizedType(StateOrInput.class, vType).build()) + .build(); + } + + static Map> getStateUpdaters(DoFn doFn) { + Field[] fields = doFn.getClass().getDeclaredFields(); + return Arrays.stream(fields) + .map(f -> Pair.of(f, f.getAnnotation(DoFn.StateId.class))) + .filter(p -> p.getSecond() != null) + .map( + p -> { + p.getFirst().setAccessible(true); + return p; + }) + .map( + p -> + Pair.of( + p.getSecond().value(), + createUpdater( + ((StateSpec) + ExceptionUtils.uncheckedFactory(() -> p.getFirst().get(doFn)))))) + .filter(p -> p.getSecond() != null) + .collect(Collectors.toMap(Pair::getFirst, Pair::getSecond)); + } + + @SuppressWarnings("unchecked") + private static @Nullable BiConsumer createUpdater(StateSpec stateSpec) { + AtomicReference> consumer = new AtomicReference<>(); + stateSpec.bind("dummy", createUpdaterBinder(consumer)); + return consumer.get(); + } + + static LinkedHashMap>> getStateReaders( + DoFn doFn) { + + Field[] fields = doFn.getClass().getDeclaredFields(); + LinkedHashMap>> res = + new LinkedHashMap<>(); + Arrays.stream(fields) + .map(f -> Pair.of(f, f.getAnnotation(DoFn.StateId.class))) + .filter(p -> p.getSecond() != null) + .map( + p -> { + p.getFirst().setAccessible(true); + return p; + }) + .map( + p -> + Pair.of( + p.getSecond().value(), + createReader( + ((StateSpec) + ExceptionUtils.uncheckedFactory(() -> p.getFirst().get(doFn)))))) + .filter(p -> p.getSecond() != null) + .forEachOrdered(p -> res.put(p.getFirst(), p.getSecond())); + return res; + } + + @SuppressWarnings("unchecked") + private static @Nullable BiFunction> createReader( + StateSpec stateSpec) { + AtomicReference>> res = new AtomicReference<>(); + stateSpec.bind("dummy", createStateReaderBinder(res)); + return res.get(); + } + + @VisibleForTesting + static StateBinder createUpdaterBinder(AtomicReference> consumer) { + return new StateBinder() { + @Override + public @Nullable ValueState bindValue( + String id, StateSpec> spec, Coder coder) { + consumer.set( + (accessor, value) -> + ((ValueState) accessor) + .write( + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.decodeFromByteArray(coder, value.getValue())))); + return null; + } + + @Override + public @Nullable BagState bindBag( + String id, StateSpec> spec, Coder elemCoder) { + consumer.set( + (accessor, value) -> + ((BagState) accessor) + .add( + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.decodeFromByteArray(elemCoder, value.getValue())))); + return null; + } + + @Override + public @Nullable SetState bindSet( + String id, StateSpec> spec, Coder elemCoder) { + consumer.set( + (accessor, value) -> + ((SetState) accessor) + .add( + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.decodeFromByteArray(elemCoder, value.getValue())))); + return null; + } + + @Override + public @Nullable MapState bindMap( + String id, + StateSpec> spec, + Coder mapKeyCoder, + Coder mapValueCoder) { + KvCoder coder = KvCoder.of(mapKeyCoder, mapValueCoder); + consumer.set( + (accessor, value) -> { + KV decoded = + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.decodeFromByteArray(coder, value.getValue())); + ((MapState) accessor).put(decoded.getKey(), decoded.getValue()); + }); + return null; + } + + @Override + public @Nullable OrderedListState bindOrderedList( + String id, StateSpec> spec, Coder elemCoder) { + KvCoder coder = KvCoder.of(elemCoder, InstantCoder.of()); + consumer.set( + (accessor, value) -> { + KV decoded = + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.decodeFromByteArray(coder, value.getValue())); + ((OrderedListState) accessor) + .add(TimestampedValue.of(decoded.getKey(), decoded.getValue())); + }); + return null; + } + + @Override + public @Nullable MultimapState bindMultimap( + String id, + StateSpec> spec, + Coder keyCoder, + Coder valueCoder) { + KvCoder coder = KvCoder.of(keyCoder, valueCoder); + consumer.set( + (accessor, value) -> { + KV decoded = + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.decodeFromByteArray(coder, value.getValue())); + ((MapState) accessor).put(decoded.getKey(), decoded.getValue()); + }); + return null; + } + + @Override + public @Nullable + CombiningState bindCombining( + String id, + StateSpec> spec, + Coder accumCoder, + CombineFn combineFn) { + consumer.set( + (accessor, value) -> + ((CombiningState) accessor) + .addAccum( + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.decodeFromByteArray(accumCoder, value.getValue())))); + return null; + } + + @Override + public @Nullable + CombiningState bindCombiningWithContext( + String id, + StateSpec> spec, + Coder accumCoder, + CombineFnWithContext combineFn) { + consumer.set( + (accessor, value) -> + ((CombiningState) accessor) + .addAccum( + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.decodeFromByteArray(accumCoder, value.getValue())))); + return null; + } + + @Override + public @Nullable WatermarkHoldState bindWatermark( + String id, StateSpec spec, TimestampCombiner timestampCombiner) { + return null; + } + }; + } + + @VisibleForTesting + static StateBinder createStateReaderBinder( + AtomicReference>> res) { + + return new StateBinder() { + @Override + public @Nullable ValueState bindValue( + String id, StateSpec> spec, Coder coder) { + res.set( + (accessor, key) -> { + T value = ((ValueState) accessor).read(); + if (value != null) { + byte[] bytes = + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.encodeToByteArray(coder, value)); + return Collections.singletonList(new StateValue(key, id, bytes)); + } + return Collections.emptyList(); + }); + return null; + } + + @Override + public @Nullable BagState bindBag( + String id, StateSpec> spec, Coder elemCoder) { + res.set( + (accessor, key) -> + Iterables.transform( + ((BagState) accessor).read(), + v -> + new StateValue( + key, + id, + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.encodeToByteArray(elemCoder, v))))); + return null; + } + + @Override + public @Nullable SetState bindSet( + String id, StateSpec> spec, Coder elemCoder) { + res.set( + (accessor, key) -> + Iterables.transform( + ((SetState) accessor).read(), + v -> + new StateValue( + key, + id, + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.encodeToByteArray(elemCoder, v))))); + return null; + } + + @Override + public @Nullable MapState bindMap( + String id, + StateSpec> spec, + Coder mapKeyCoder, + Coder mapValueCoder) { + KvCoder coder = KvCoder.of(mapKeyCoder, mapValueCoder); + res.set( + (accessor, key) -> + Iterables.transform( + ((MapState) accessor).entries().read(), + v -> + new StateValue( + key, + id, + ExceptionUtils.uncheckedFactory( + () -> + CoderUtils.encodeToByteArray( + coder, KV.of(v.getKey(), v.getValue())))))); + return null; + } + + @Override + public @Nullable OrderedListState bindOrderedList( + String id, StateSpec> spec, Coder elemCoder) { + KvCoder coder = KvCoder.of(elemCoder, InstantCoder.of()); + res.set( + (accessor, key) -> + Iterables.transform( + ((OrderedListState) accessor).read(), + v -> + new StateValue( + key, + id, + ExceptionUtils.uncheckedFactory( + () -> + CoderUtils.encodeToByteArray( + coder, KV.of(v.getValue(), v.getTimestamp())))))); + return null; + } + + @Override + public @Nullable MultimapState bindMultimap( + String id, + StateSpec> spec, + Coder keyCoder, + Coder valueCoder) { + KvCoder coder = KvCoder.of(keyCoder, valueCoder); + res.set( + (accessor, key) -> + Iterables.transform( + ((MultimapState) accessor).entries().read(), + v -> + new StateValue( + key, + id, + ExceptionUtils.uncheckedFactory( + () -> + CoderUtils.encodeToByteArray( + coder, KV.of(v.getKey(), v.getValue())))))); + return null; + } + + @Override + public @Nullable + CombiningState bindCombining( + String id, + StateSpec> spec, + Coder accumCoder, + CombineFn combineFn) { + res.set( + (accessor, key) -> { + AccumT accum = ((CombiningState) accessor).getAccum(); + return Collections.singletonList( + new StateValue( + key, + id, + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.encodeToByteArray(accumCoder, accum)))); + }); + return null; + } + + @Override + public @Nullable + CombiningState bindCombiningWithContext( + String id, + StateSpec> spec, + Coder accumCoder, + CombineFnWithContext combineFn) { + res.set( + (accessor, key) -> { + AccumT accum = ((CombiningState) accessor).getAccum(); + return Collections.singletonList( + new StateValue( + key, + id, + ExceptionUtils.uncheckedFactory( + () -> CoderUtils.encodeToByteArray(accumCoder, accum)))); + }); + return null; + } + + @Override + public @Nullable WatermarkHoldState bindWatermark( + String id, StateSpec spec, TimestampCombiner timestampCombiner) { + return null; + } + }; + } + + private static class TimestampedOutputReceiver implements OutputReceiver { + + private final OutputReceiver parentReceiver; + private final Instant elementTimestamp; + + public TimestampedOutputReceiver(OutputReceiver parentReceiver, Instant timestamp) { + this.parentReceiver = parentReceiver; + this.elementTimestamp = timestamp; + } + + @Override + public void output(T output) { + outputWithTimestamp(output, elementTimestamp); + } + + @Override + public void outputWithTimestamp(T output, Instant timestamp) { + parentReceiver.outputWithTimestamp(output, timestamp); + } + } + + public static class MethodParameterArrayElement implements ArgumentLoader { + private final ParameterDescription parameterDescription; + private final int index; + + public MethodParameterArrayElement(ParameterDescription parameterDescription, int index) { + this.parameterDescription = parameterDescription; + this.index = index; + } + + @Override + public StackManipulation toStackManipulation( + ParameterDescription target, Assigner assigner, Assigner.Typing typing) { + StackManipulation stackManipulation = + new StackManipulation.Compound( + MethodVariableAccess.load(this.parameterDescription), + IntegerConstant.forValue(this.index), + ArrayAccess.of(this.parameterDescription.getType().getComponentType()).load(), + assigner.assign( + this.parameterDescription.getType().getComponentType(), + target.getType(), + Typing.DYNAMIC)); + if (!stackManipulation.isValid()) { + throw new IllegalStateException( + "Cannot assign " + + this.parameterDescription.getType().getComponentType() + + " to " + + target); + } else { + return stackManipulation; + } + } + } + + public static class ArrayArgumentProvider implements Factory, ArgumentProvider { + + public InstrumentedType prepare(InstrumentedType instrumentedType) { + return instrumentedType; + } + + public ArgumentProvider make(Implementation.Target implementationTarget) { + return this; + } + + public List resolve( + MethodDescription instrumentedMethod, MethodDescription invokedMethod) { + + ParameterDescription desc = instrumentedMethod.getParameters().get(1); + List res = new ArrayList<>(invokedMethod.getParameters().size()); + for (int i = 0; i < invokedMethod.getParameters().size(); ++i) { + res.add(new MethodParameterArrayElement(desc, i)); + } + return res; + } + } + + public interface MethodInvoker { + + static MethodInvoker of(Method method, ByteBuddy buddy) + throws NoSuchMethodException, + InvocationTargetException, + InstantiationException, + IllegalAccessException { + + return getInvoker(method, buddy); + } + + R invoke(T _this, Object[] args); + } + + public interface VoidMethodInvoker { + + static VoidMethodInvoker of(Method method, ByteBuddy buddy) + throws NoSuchMethodException, + InvocationTargetException, + InstantiationException, + IllegalAccessException { + + return getInvoker(method, buddy); + } + + void invoke(T _this, Object[] args); + } + + private static T getInvoker(Method method, ByteBuddy buddy) + throws NoSuchMethodException, + InvocationTargetException, + InstantiationException, + IllegalAccessException { + + Class declaringClass = method.getDeclaringClass(); + Class superClass = fromDeclaringClass(declaringClass); + String methodName = method.getName(); + Type returnType = method.getGenericReturnType(); + Generic implement = + returnType.equals(void.class) + ? Builder.parameterizedType(VoidMethodInvoker.class, declaringClass).build() + : Builder.parameterizedType(MethodInvoker.class, declaringClass, returnType).build(); + ClassLoadingStrategy strategy = ByteBuddyUtils.getClassLoadingStrategy(superClass); + String subclassName = declaringClass.getName() + "$" + methodName + "Invoker"; + try { + @SuppressWarnings("unchecked") + Class loaded = (Class) superClass.getClassLoader().loadClass(subclassName); + return newInstance(loaded); + } catch (Exception ex) { + // define the class + } + @SuppressWarnings("unchecked") + Class cls = + (Class) + buddy + .subclass(superClass) + .implement(implement) + .name(subclassName) + .defineMethod("invoke", returnType, Visibility.PUBLIC) + .withParameters(declaringClass, Object[].class) + .intercept( + MethodCall.invoke(method).onArgument(0).with(new ArrayArgumentProvider())) + .make() + .load(null, strategy) + .getLoaded(); + + return newInstance(cls); + } + + private static Class fromDeclaringClass(Class cls) { + if (cls.getEnclosingClass() != null && !hasDefaultConstructor(cls)) { + return fromDeclaringClass(cls.getEnclosingClass()); + } + return cls; + } + + private static boolean hasDefaultConstructor(Class cls) { + return Arrays.stream(cls.getDeclaredConstructors()).anyMatch(c -> c.getParameterCount() == 0); + } + + private static T newInstance(Class cls) + throws NoSuchMethodException, + InvocationTargetException, + InstantiationException, + IllegalAccessException { + + return cls.getDeclaredConstructor().newInstance(); + } + + private MethodCallUtils() {} +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java new file mode 100644 index 000000000..ebb7a1ae7 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java @@ -0,0 +1,156 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import static cz.o2.proxima.beam.util.state.ExternalStateExpander.bagStateFromInputType; +import static cz.o2.proxima.beam.util.state.MethodCallUtils.*; + +import cz.o2.proxima.core.util.Pair; +import cz.o2.proxima.internal.com.google.common.base.MoreObjects; +import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import cz.o2.proxima.internal.com.google.common.collect.Sets; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import net.bytebuddy.description.annotation.AnnotationDescription; +import net.bytebuddy.description.annotation.AnnotationDescription.Builder; +import net.bytebuddy.description.type.TypeDefinition; +import net.bytebuddy.description.type.TypeDescription; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.StateId; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.checkerframework.checker.nullness.qual.Nullable; + +interface OnWindowParameterExpander { + + static OnWindowParameterExpander of( + ParameterizedType inputType, + Method processElement, + @Nullable Method onWindowExpiration, + TupleTag mainTag, + Type outputType) { + + final LinkedHashMap> processArgs = extractArgs(processElement); + final LinkedHashMap> onWindowArgs = + extractArgs(onWindowExpiration); + final List> wrapperArgList = + createWrapperArgList(processArgs, onWindowArgs); + final LinkedHashMap> wrapperArgs = + createWrapperArgs(inputType, wrapperArgList); + final List>, Object>> processArgsGenerators = + projectArgs(wrapperArgs, processArgs, mainTag, outputType); + final List>, Object>> windowArgsGenerators = + projectArgs(wrapperArgs, onWindowArgs, mainTag, outputType); + + return new OnWindowParameterExpander() { + @Override + public List> getWrapperArgs() { + return new ArrayList<>(wrapperArgs.values()); + } + + @Override + public Object[] getProcessElementArgs( + TimestampedValue> input, Object[] wrapperArgs) { + return fromGenerators(input, processArgsGenerators, wrapperArgs); + } + + @Override + public Object[] getOnWindowExpirationArgs(Object[] wrapperArgs) { + return fromGenerators(null, windowArgsGenerators, wrapperArgs); + } + }; + } + + static List> createWrapperArgList( + LinkedHashMap> processArgs, + LinkedHashMap> onWindowArgs) { + + Set union = new HashSet<>(Sets.union(processArgs.keySet(), onWindowArgs.keySet())); + // @Element is not supported by @OnWindowExpiration + union.remove(TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Element.class).build())); + return union.stream() + .map( + t -> { + Pair processPair = processArgs.get(t); + Pair windowPair = onWindowArgs.get(t); + Type argType = MoreObjects.firstNonNull(processPair, windowPair).getSecond(); + Annotation processAnnotation = + Optional.ofNullable(processPair).map(Pair::getFirst).orElse(null); + Annotation windowAnnotation = + Optional.ofNullable(windowPair).map(Pair::getFirst).orElse(null); + Preconditions.checkState( + processPair == null + || windowPair == null + || processAnnotation == windowAnnotation + || processAnnotation.equals(windowAnnotation)); + return Pair.of(processAnnotation, argType); + }) + .collect(Collectors.toList()); + } + + static LinkedHashMap> createWrapperArgs( + ParameterizedType inputType, List> wrapperArgList) { + + LinkedHashMap> res = new LinkedHashMap<>(); + wrapperArgList.stream() + .map( + p -> + Pair.of( + p.getFirst() == null ? TypeId.of(p.getSecond()) : TypeId.of(p.getFirst()), + Pair.of( + p.getFirst() != null + ? (AnnotationDescription) + AnnotationDescription.ForLoadedAnnotation.of(p.getFirst()) + : null, + (TypeDefinition) + TypeDescription.Generic.Builder.of(p.getSecond()).build()))) + .forEachOrdered(p -> res.put(p.getFirst(), p.getSecond())); + + // add @StateId for buffer + AnnotationDescription buffer = + Builder.ofType(StateId.class) + .define("value", ExternalStateExpander.EXPANDER_BUF_STATE_NAME) + .build(); + res.put(TypeId.of(buffer), Pair.of(buffer, bagStateFromInputType(inputType))); + return res; + } + + /** + * Get arguments that must be declared by wrapper's call for both {@code @}ProcessElement and + * {@code @}OnWindowExpiration be callable. + */ + List> getWrapperArgs(); + + /** + * Get parameters that should be passed to {@code @}ProcessElement from wrapper's + * {@code @}OnWindowExpiration + */ + Object[] getProcessElementArgs(TimestampedValue> input, Object[] wrapperArgs); + + /** Get parameters that should be passed to {@code @}OnWindowExpiration from wrapper's call. */ + Object[] getOnWindowExpirationArgs(Object[] wrapperArgs); +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java new file mode 100644 index 000000000..02c547d80 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java @@ -0,0 +1,256 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import static cz.o2.proxima.beam.util.state.MethodCallUtils.*; +import static cz.o2.proxima.beam.util.state.MethodCallUtils.projectArgs; + +import cz.o2.proxima.core.functional.BiConsumer; +import cz.o2.proxima.core.functional.UnaryFunction; +import cz.o2.proxima.core.util.Pair; +import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Predicate; +import lombok.extern.slf4j.Slf4j; +import net.bytebuddy.description.annotation.AnnotationDescription; +import net.bytebuddy.description.annotation.AnnotationDescription.ForLoadedAnnotation; +import net.bytebuddy.description.type.TypeDefinition; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.description.type.TypeDescription.ForLoadedType; +import net.bytebuddy.description.type.TypeDescription.Generic; +import net.bytebuddy.description.type.TypeDescription.Generic.Builder; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.StateId; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.joda.time.Instant; + +interface ProcessElementParameterExpander { + + static ProcessElementParameterExpander of( + DoFn doFn, + Method processElement, + ParameterizedType inputType, + TupleTag mainTag, + Type outputType, + Instant stateWriteInstant) { + + final LinkedHashMap> processArgs = extractArgs(processElement); + final LinkedHashMap> wrapperArgs = + createWrapperArgs(inputType, outputType, processArgs.values()); + final List>, Object>> processArgsGenerators = + projectArgs(wrapperArgs, processArgs, mainTag, outputType); + + return new ProcessElementParameterExpander() { + @Override + public List> getWrapperArgs() { + return new ArrayList<>(wrapperArgs.values()); + } + + @Override + public Object[] getProcessElementArgs(Object[] wrapperArgs) { + return fromGenerators(processArgsGenerators, wrapperArgs); + } + + @Override + public UnaryFunction getProcessFn() { + return createProcessFn(wrapperArgs, doFn, processElement, stateWriteInstant); + } + }; + } + + /** Get arguments that must be declared by wrapper's call. */ + List> getWrapperArgs(); + + /** + * Get parameters that should be passed to {@code @}ProcessElement from wrapper's + * {@code @}ProcessElement + */ + Object[] getProcessElementArgs(Object[] wrapperArgs); + + /** Get function to process elements and delegate to original DoFn. */ + UnaryFunction getProcessFn(); + + private static UnaryFunction createProcessFn( + LinkedHashMap> wrapperArgs, + DoFn doFn, + Method method, + Instant stateWriteInstant) { + + Map> stateUpdaterMap = getStateUpdaters(doFn); + return new ProcessFn(stateWriteInstant, wrapperArgs, method, stateUpdaterMap); + } + + private static int findParameter(Collection args, Predicate predicate) { + int i = 0; + for (TypeId t : args) { + if (predicate.test(t)) { + return i; + } + i++; + } + return -1; + } + + static LinkedHashMap> createWrapperArgs( + ParameterizedType inputType, + Type outputType, + Collection> processArgs) { + + LinkedHashMap> res = new LinkedHashMap<>(); + processArgs.stream() + .map(p -> transformProcessArg(inputType, p)) + .filter(p -> !p.getFirst().isOutput(outputType) && !p.getFirst().isTimestamp()) + .forEachOrdered(p -> res.put(p.getFirst(), p.getSecond())); + + // add @Timestamp + AnnotationDescription timestampAnnotation = + AnnotationDescription.Builder.ofType(DoFn.Timestamp.class).build(); + res.put( + TypeId.of(timestampAnnotation), + Pair.of(timestampAnnotation, TypeDescription.ForLoadedType.of(Instant.class))); + // add @TimerId for flush timer + AnnotationDescription timerAnnotation = + AnnotationDescription.Builder.ofType(DoFn.TimerId.class) + .define("value", ExternalStateExpander.EXPANDER_TIMER_NAME) + .build(); + res.put( + TypeId.of(timerAnnotation), + Pair.of(timerAnnotation, TypeDescription.ForLoadedType.of(Timer.class))); + + // add @StateId for finished buffer + AnnotationDescription finishedAnnotation = + AnnotationDescription.Builder.ofType(DoFn.StateId.class) + .define("value", ExternalStateExpander.EXPANDER_FLUSH_STATE_NAME) + .build(); + res.put( + TypeId.of(finishedAnnotation), + Pair.of( + finishedAnnotation, + TypeDescription.Generic.Builder.parameterizedType(ValueState.class, Instant.class) + .build())); + + // add @StateId for buffer + AnnotationDescription stateAnnotation = + AnnotationDescription.Builder.ofType(StateId.class) + .define("value", ExternalStateExpander.EXPANDER_BUF_STATE_NAME) + .build(); + res.put( + TypeId.of(stateAnnotation), + Pair.of(stateAnnotation, ExternalStateExpander.bagStateFromInputType(inputType))); + + // add MultiOutputReceiver + TypeDescription receiver = ForLoadedType.of(MultiOutputReceiver.class); + res.put(TypeId.of(receiver), Pair.of(null, receiver)); + return res; + } + + static Pair> transformProcessArg( + ParameterizedType inputType, Pair p) { + + TypeId typeId = p.getFirst() == null ? TypeId.of(p.getSecond()) : TypeId.of(p.getFirst()); + AnnotationDescription annotation = + p.getFirst() != null ? ForLoadedAnnotation.of(p.getFirst()) : null; + Generic parameterType = Builder.of(p.getSecond()).build(); + if (typeId.equals( + TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Element.class).build()))) { + parameterType = getWrapperInputType(inputType); + } + return Pair.of(typeId, Pair.of(annotation, parameterType)); + } + + @Slf4j + class ProcessFn implements UnaryFunction { + private final int elementPos; + private final Instant stateWriteInstant; + private final LinkedHashMap> wrapperArgs; + private final Method method; + private final Map> stateUpdaterMap; + + public ProcessFn( + Instant stateWriteInstant, + LinkedHashMap> wrapperArgs, + Method method, + Map> stateUpdaterMap) { + + this.elementPos = findParameter(wrapperArgs.keySet(), TypeId::isElement); + this.stateWriteInstant = stateWriteInstant; + this.wrapperArgs = wrapperArgs; + this.method = method; + this.stateUpdaterMap = stateUpdaterMap; + Preconditions.checkState(elementPos >= 0, "Missing @Element annotation on method %s", method); + } + + @Override + public Boolean apply(Object[] args) { + @SuppressWarnings("unchecked") + KV> elem = (KV>) args[elementPos]; + Instant ts = (Instant) args[args.length - 5]; + Timer flushTimer = (Timer) args[args.length - 4]; + @SuppressWarnings("unchecked") + ValueState finishedState = (ValueState) args[args.length - 3]; + boolean isState = Objects.requireNonNull(elem.getValue(), "elem").isState(); + if (isState) { + StateValue state = elem.getValue().getState(); + String stateName = state.getName(); + // find state accessor + int statePos = findParameter(wrapperArgs.keySet(), a -> a.isState(stateName)); + Preconditions.checkArgument( + statePos < method.getParameterCount(), "Missing state accessor for %s", stateName); + Object stateAccessor = args[statePos]; + // find declaration of state to find coder + BiConsumer updater = stateUpdaterMap.get(stateName); + Preconditions.checkArgument( + updater != null, "Missing updater for state %s in %s", stateName, stateUpdaterMap); + updater.accept(stateAccessor, state); + return false; + } + Instant nextFlush = finishedState.read(); + if (nextFlush == null) { + // set the initial timer + flushTimer.set(stateWriteInstant); + } + boolean shouldBuffer = + nextFlush == null /* we have not finished reading state */ + || nextFlush.isBefore(ts) /* the timestamp if after next flush */; + if (shouldBuffer) { + log.debug("Buffering element {} at {} with nextFlush {}", elem, ts, nextFlush); + // store to state + @SuppressWarnings("unchecked") + BagState>> buffer = + (BagState>>) args[args.length - 2]; + buffer.add(TimestampedValue.of(KV.of(elem.getKey(), elem.getValue().getInput()), ts)); + return false; + } + return true; + } + } +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/StateOrInput.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/StateOrInput.java new file mode 100644 index 000000000..6e07be807 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/StateOrInput.java @@ -0,0 +1,98 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import cz.o2.proxima.beam.util.state.StateValue.StateValueCoder; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Objects; +import lombok.Value; +import org.apache.beam.sdk.coders.ByteCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CustomCoder; +import org.checkerframework.checker.nullness.qual.Nullable; + +@Value +public class StateOrInput { + + public static class StateOrInputCoder extends CustomCoder> { + + private final ByteCoder byteCoder = ByteCoder.of(); + private final StateValueCoder stateCoder = StateValue.coder(); + private final Coder inputCoder; + + private StateOrInputCoder(Coder inputCoder) { + this.inputCoder = inputCoder; + } + + @Override + public void encode(StateOrInput value, OutputStream outStream) throws IOException { + byteCoder.encode(value.getTag(), outStream); + if (value.isState()) { + stateCoder.encode(value.getState(), outStream); + } else { + inputCoder.encode(value.getInput(), outStream); + } + } + + @Override + public StateOrInput decode(InputStream inStream) throws IOException { + byte tag = byteCoder.decode(inStream); + if (tag == 0) { + return new StateOrInput<>(tag, stateCoder.decode(inStream), null); + } + return new StateOrInput<>(tag, null, inputCoder.decode(inStream)); + } + + @Override + public int hashCode() { + return Objects.hash(byteCoder, stateCoder, inputCoder); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof StateOrInputCoder)) { + return false; + } + StateOrInputCoder other = (StateOrInputCoder) obj; + return other.inputCoder.equals(this.inputCoder); + } + } + + public static StateOrInputCoder coder(Coder inputCoder) { + return new StateOrInputCoder<>(inputCoder); + } + + public static StateOrInput state(StateValue state) { + return new StateOrInput<>((byte) 0, state, null); + } + + public static StateOrInput input(T input) { + return new StateOrInput<>((byte) 1, null, input); + } + + byte tag; + @Nullable StateValue state; + @Nullable T input; + + boolean isState() { + return tag == 0; + } +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/StateValue.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/StateValue.java new file mode 100644 index 000000000..e4f4f2df7 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/StateValue.java @@ -0,0 +1,55 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import lombok.Value; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; + +@Value +public class StateValue { + public static class StateValueCoder extends CustomCoder { + private static final StateValueCoder INSTANCE = new StateValueCoder(); + private static final ByteArrayCoder BAC = ByteArrayCoder.of(); + private static final StringUtf8Coder SUC = StringUtf8Coder.of(); + + private StateValueCoder() {} + + @Override + public void encode(StateValue value, OutputStream outStream) throws IOException { + BAC.encode(value.getKey(), outStream); + SUC.encode(value.getName(), outStream); + BAC.encode(value.getValue(), outStream); + } + + @Override + public StateValue decode(InputStream inStream) throws IOException { + return new StateValue(BAC.decode(inStream), SUC.decode(inStream), BAC.decode(inStream)); + } + } + + public static StateValueCoder coder() { + return StateValueCoder.INSTANCE; + } + + byte[] key; + String name; + byte[] value; +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java new file mode 100644 index 000000000..028d65ba8 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java @@ -0,0 +1,115 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import cz.o2.proxima.internal.com.google.common.base.MoreObjects; +import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; +import net.bytebuddy.description.annotation.AnnotationDescription; +import net.bytebuddy.description.type.TypeDefinition; +import net.bytebuddy.description.type.TypeDescription; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.StateId; + +class TypeId { + + private static final TypeId TIMESTAMP_TYPE = + TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Timestamp.class).build()); + + private static final TypeId ELEMENT_TYPE = + TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Element.class).build()); + + private static final TypeId MULTI_OUTPUT_TYPE = + TypeId.of(TypeDescription.ForLoadedType.of(DoFn.MultiOutputReceiver.class)); + + public static TypeId of(Annotation annotation) { + return new TypeId(annotation.toString()); + } + + public static TypeId of(AnnotationDescription annotationDescription) { + return new TypeId(annotationDescription.toString()); + } + + public static TypeId of(Type type) { + Preconditions.checkArgument(!(type instanceof Annotation)); + return new TypeId(type.getTypeName()); + } + + public static TypeId of(TypeDefinition definition) { + return new TypeId(definition.getTypeName()); + } + + private final String stringId; + + private TypeId(String stringId) { + this.stringId = stringId; + } + + @Override + public int hashCode() { + return stringId.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof TypeId)) { + return false; + } + TypeId other = (TypeId) obj; + return this.stringId.equals(other.stringId); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("id", stringId).toString(); + } + + public boolean isElement() { + return equals(ELEMENT_TYPE); + } + + public boolean isState(String stateName) { + return equals( + TypeId.of( + AnnotationDescription.Builder.ofType(StateId.class) + .define("value", stateName) + .build())); + } + + public boolean isTimestamp() { + return equals(TIMESTAMP_TYPE); + } + + public boolean isOutputReceiver(Type outputType) { + return equals( + TypeId.of( + TypeDescription.Generic.Builder.parameterizedType(OutputReceiver.class, outputType) + .build())); + } + + public boolean isOutput(Type outputType) { + return isOutputReceiver(outputType) || isMultiOutput(); + } + + public boolean isMultiOutput() { + return equals(MULTI_OUTPUT_TYPE); + } +} diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java new file mode 100644 index 000000000..e2da7fd36 --- /dev/null +++ b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java @@ -0,0 +1,431 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import static com.mongodb.internal.connection.tlschannel.util.Util.assertTrue; +import static org.junit.Assert.assertEquals; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import cz.o2.proxima.core.util.SerializableScopedValue; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; +import java.util.UUID; +import org.apache.beam.runners.direct.DirectRunner; +import org.apache.beam.runners.flink.FlinkRunner; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineRunner; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Reify; +import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.jetbrains.annotations.NotNull; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class ExternalStateExpanderTest { + + @Parameters + public static List>> params() { + return Arrays.asList(DirectRunner.class, FlinkRunner.class); + } + + @Parameter public Class> runner; + + @Test + public void testSimpleExpand() { + Pipeline pipeline = createPipeline(); + PCollection inputs = pipeline.apply(Create.of("1", "2", "3")); + PCollection> withKeys = + inputs.apply( + WithKeys.of(e -> Integer.parseInt(e) % 2) + .withKeyType(TypeDescriptors.integers())); + PCollection count = withKeys.apply(ParDo.of(getSumFn())); + PAssert.that(count).containsInAnyOrder(2L, 4L); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + new Instant(0), + ign -> BoundedWindow.TIMESTAMP_MAX_VALUE, + dummy()); + expanded.run(); + } + + @Test + public void testSimpleExpandMultiOutput() { + Pipeline pipeline = createPipeline(); + PCollection inputs = pipeline.apply(Create.of("1", "2", "3")); + PCollection> withKeys = + inputs.apply( + WithKeys.of(e -> Integer.parseInt(e) % 2) + .withKeyType(TypeDescriptors.integers())); + TupleTag mainTag = new TupleTag<>(); + PCollection count = + withKeys + .apply(ParDo.of(getSumFn()).withOutputTags(mainTag, TupleTagList.empty())) + .get(mainTag); + PAssert.that(count).containsInAnyOrder(2L, 4L); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + new Instant(0), + ign -> BoundedWindow.TIMESTAMP_MAX_VALUE, + dummy()); + expanded.run(); + } + + @Test + public void testCompositeExpand() { + PTransform, PCollection> transform = + new PTransform<>() { + @Override + public PCollection expand(PCollection input) { + PCollection> withKeys = + input.apply( + WithKeys.of(e -> Integer.parseInt(e) % 2) + .withKeyType(TypeDescriptors.integers())); + return withKeys.apply(ParDo.of(getSumFn())); + } + }; + Pipeline pipeline = createPipeline(); + PCollection inputs = pipeline.apply(Create.of("1", "2", "3")); + PCollection count = inputs.apply(transform); + PAssert.that(count).containsInAnyOrder(2L, 4L); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + new Instant(0), + ign -> BoundedWindow.TIMESTAMP_MAX_VALUE, + dummy()); + expanded.run(); + } + + @Test + public void testSimpleExpandWithInitialState() throws CoderException { + Pipeline pipeline = createPipeline(); + PCollection inputs = pipeline.apply(Create.of("3", "4")); + PCollection> withKeys = + inputs.apply( + WithKeys.of(e -> Integer.parseInt(e) % 2) + .withKeyType(TypeDescriptors.integers())); + PCollection count = withKeys.apply("sum", ParDo.of(getSumFn())); + PAssert.that(count).containsInAnyOrder(6L, 4L); + VarIntCoder intCoder = VarIntCoder.of(); + VarLongCoder longCoder = VarLongCoder.of(); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.of( + KV.of( + "sum/ParMultiDo(Anonymous)", + new StateValue( + CoderUtils.encodeToByteArray(intCoder, 0), + "sum", + CoderUtils.encodeToByteArray(longCoder, 2L))), + KV.of( + "sum/ParMultiDo(Anonymous)", + new StateValue( + CoderUtils.encodeToByteArray(intCoder, 1), + "sum", + CoderUtils.encodeToByteArray(longCoder, 1L)))) + .withCoder(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + new Instant(0), + current -> BoundedWindow.TIMESTAMP_MAX_VALUE, + dummy()); + expanded.run(); + } + + @Test + public void testSimpleExpandWithStateStore() throws CoderException { + Pipeline pipeline = createPipeline(); + Instant now = new Instant(0); + PCollection inputs = + pipeline.apply( + Create.timestamped(TimestampedValue.of("1", now), TimestampedValue.of("2", now))); + PCollection> withKeys = + inputs.apply( + WithKeys.of(e -> Integer.parseInt(e) % 2) + .withKeyType(TypeDescriptors.integers())); + PCollection count = withKeys.apply("sum", ParDo.of(getSumFn())); + PAssert.that(count).containsInAnyOrder(1L, 2L); + PriorityQueue>> states = + // compare StateValue by toString, lombok's @Value has stable .toString() in this case + new PriorityQueue<>(Comparator.comparing(e -> e.getValue().getValue().toString())); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + now, + current -> current.equals(now) ? now.plus(1) : BoundedWindow.TIMESTAMP_MAX_VALUE, + collectStates(states)); + expanded.run(); + assertEquals(2, states.size()); + TimestampedValue> first = states.poll(); + assertEquals(new Instant(1), first.getTimestamp()); + assertTrue(first.getValue().getKey().startsWith("sum")); + assertEquals( + 0, + (int) + CoderUtils.decodeFromByteArray(VarIntCoder.of(), first.getValue().getValue().getKey())); + TimestampedValue> second = states.poll(); + assertEquals(new Instant(1), second.getTimestamp()); + assertTrue(second.getValue().getKey().startsWith("sum")); + assertEquals( + 1, + (int) + CoderUtils.decodeFromByteArray( + VarIntCoder.of(), second.getValue().getValue().getKey())); + } + + @Test + public void testStateWithElementEarly() throws CoderException { + Pipeline pipeline = createPipeline(); + Instant now = new Instant(0); + PCollection inputs = + pipeline.apply( + TestStream.create(StringUtf8Coder.of()) + // the second timestamped value MUST not be part of the state produced at 1 + .addElements(TimestampedValue.of("1", now), TimestampedValue.of("3", now.plus(2))) + .advanceWatermarkTo(new Instant(1)) + .advanceWatermarkToInfinity()); + PCollection> withKeys = + inputs.apply( + WithKeys.of(e -> Integer.parseInt(e) % 2) + .withKeyType(TypeDescriptors.integers())); + PCollection count = withKeys.apply("sum", ParDo.of(getSumFn())); + PAssert.that(count).containsInAnyOrder(4L); + List>> states = new ArrayList<>(); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + now, + current -> current.equals(now) ? now.plus(1) : BoundedWindow.TIMESTAMP_MAX_VALUE, + collectStates(states)); + expanded.run(); + assertEquals(1, states.size()); + TimestampedValue> first = states.get(0); + assertEquals( + 1L, + (long) + CoderUtils.decodeFromByteArray( + VarLongCoder.of(), first.getValue().getValue().getValue())); + } + + @Test + public void testBufferedTimestampInject() { + testTimestampInject(false); + } + + @Test + public void testBufferedTimestampInjectToMultiOutput() { + testTimestampInject(true); + } + + private void testTimestampInject(boolean multiOutput) { + Pipeline pipeline = createPipeline(); + Instant now = new Instant(0); + PCollection inputs = + pipeline.apply( + TestStream.create(StringUtf8Coder.of()) + // the second timestamped value MUST not be part of the state produced at 1 + .addElements(TimestampedValue.of("1", now)) + .advanceWatermarkTo(new Instant(0)) + .addElements(TimestampedValue.of("3", now.plus(10))) + .advanceWatermarkTo(new Instant(1)) + .advanceWatermarkToInfinity()); + PCollection> withKeys = + inputs.apply( + WithKeys.of(e -> Integer.parseInt(e) % 2) + .withKeyType(TypeDescriptors.integers())); + PCollection> outputs = + withKeys.apply("sum", bufferWithTimestamp(multiOutput)); + PAssert.that(outputs) + .containsInAnyOrder( + TimestampedValue.of("KV{1, 1}@0", now), + TimestampedValue.of("KV{1, 3}@10", now.plus(10))); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + now, + current -> + current.isBefore(now.plus(2)) ? current.plus(1) : BoundedWindow.TIMESTAMP_MAX_VALUE, + dummy()); + expanded.run(); + } + + private static PTransform>, PDone> collectStates( + Collection>> states) { + + String id = UUID.randomUUID().toString(); + final SerializableScopedValue>>> + val = new SerializableScopedValue<>(id, states); + return new PTransform<>() { + @Override + public PDone expand(PCollection> input) { + input.apply( + ParDo.of( + new DoFn, Void>() { + @ProcessElement + public void process(@Element KV elem, @Timestamp Instant ts) { + Collection>> m = val.get(id); + synchronized (m) { + m.add(TimestampedValue.of(elem, ts)); + } + } + })); + return PDone.in(input.getPipeline()); + } + }; + } + + private static PTransform>, PCollection>> + bufferWithTimestamp(boolean withMultiOutput) { + + if (withMultiOutput) { + return new PTransform<>() { + @Override + public PCollection> expand( + PCollection> input) { + TupleTag mainOutput = new TupleTag<>() {}; + return input + .apply( + ParDo.of( + new DoFn, String>() { + // just declare state to be expanded + @StateId("state") + private final StateSpec> buf = StateSpecs.value(); + + @ProcessElement + public void process( + @Element KV elem, + @Timestamp Instant ts, + MultiOutputReceiver output) { + + output.get(mainOutput).output(elem + "@" + ts.getMillis()); + } + }) + .withOutputTags(mainOutput, TupleTagList.empty())) + .get(mainOutput) + .apply(Reify.timestamps()); + } + }; + } + + return new PTransform<>() { + @Override + public PCollection> expand(PCollection> input) { + return input + .apply( + ParDo.of( + new DoFn, String>() { + // just declare state to be expanded + @StateId("state") + private final StateSpec> buf = StateSpecs.value(); + + @ProcessElement + public void process( + @Element KV elem, + @Timestamp Instant ts, + OutputReceiver output) { + + output.output(elem + "@" + ts.getMillis()); + } + })) + .apply(Reify.timestamps()); + } + }; + } + + private @NotNull Pipeline createPipeline() { + PipelineOptions opts = PipelineOptionsFactory.create(); + opts.setRunner(runner); + return Pipeline.create(opts); + } + + private static DoFn, Long> getSumFn() { + return new DoFn, Long>() { + @StateId("sum") + private final StateSpec> spec = StateSpecs.value(); + + @ProcessElement + public void process( + OutputReceiver ignored, + @Element KV element, + @StateId("sum") ValueState sum) { + + Preconditions.checkArgument(ignored instanceof OutputReceiver); + long current = MoreObjects.firstNonNull(sum.read(), 0L); + sum.write(current + Integer.parseInt(element.getValue())); + } + + @OnWindowExpiration + public void onExpiration(@StateId("sum") ValueState sum, OutputReceiver output) { + Long value = sum.read(); + if (value != null) { + output.output(value); + } + } + }; + } + + private PTransform>, PDone> dummy() { + return new PTransform<>() { + @Override + public PDone expand(PCollection> input) { + input.apply(MapElements.into(TypeDescriptors.voids()).via(a -> null)); + return PDone.in(input.getPipeline()); + } + }; + } +} diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/MethodCallUtilsTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/MethodCallUtilsTest.java new file mode 100644 index 000000000..88ddeeb7f --- /dev/null +++ b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/MethodCallUtilsTest.java @@ -0,0 +1,182 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import static com.mongodb.internal.connection.tlschannel.util.Util.assertTrue; +import static org.junit.Assert.assertEquals; + +import cz.o2.proxima.beam.util.state.MethodCallUtils.MethodInvoker; +import cz.o2.proxima.beam.util.state.MethodCallUtils.VoidMethodInvoker; +import cz.o2.proxima.core.functional.BiConsumer; +import cz.o2.proxima.core.functional.Consumer; +import cz.o2.proxima.core.functional.Factory; +import cz.o2.proxima.core.functional.UnaryFunction; +import cz.o2.proxima.core.util.ExceptionUtils; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import net.bytebuddy.ByteBuddy; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.state.StateBinder; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.junit.Test; + +public class MethodCallUtilsTest { + + @Test + public void testBinders() { + AtomicReference>> tmp = new AtomicReference<>(); + AtomicReference> tmp2 = new AtomicReference<>(); + testBinder(MethodCallUtils.createStateReaderBinder(tmp)); + testBinder(MethodCallUtils.createUpdaterBinder(tmp2)); + } + + @Test + public void testMethodInvoker() + throws NoSuchMethodException, + InvocationTargetException, + InstantiationException, + IllegalAccessException { + + testMethodInvokerWith(Sum::new, Integer::valueOf, int.class); + } + + @Test + public void testMethodInvokerLong() + throws NoSuchMethodException, + InvocationTargetException, + InstantiationException, + IllegalAccessException { + + testMethodInvokerWith(Sum2::new, Long::valueOf, Long.class); + } + + @Test + public void testMethodInvokerWithVoid() + throws NoSuchMethodException, + InvocationTargetException, + InstantiationException, + IllegalAccessException { + + ByteBuddy buddy = new ByteBuddy(); + SumCollect s = new SumCollect(); + Method method = s.getClass().getDeclaredMethod("apply", int.class, int.class, Consumer.class); + VoidMethodInvoker invoker = VoidMethodInvoker.of(method, buddy); + List list = new ArrayList<>(); + Consumer c = list::add; + invoker.invoke(s, new Object[] {1, 2, c}); + assertEquals(3, (int) list.get(0)); + + long start = System.nanoTime(); + Consumer ign = dummy -> {}; + for (int i = 0; i < 1_000_000; i++) { + invoker.invoke(s, new Object[] {1, 2, ign}); + } + long duration = System.nanoTime() - start; + assertTrue(duration < 1_000_000_000); + } + + @Test + public void testNonStaticSubclass() + throws InvocationTargetException, + NoSuchMethodException, + InstantiationException, + IllegalAccessException { + Sum s = + new Sum() { + final MethodInvoker invoker = + MethodInvoker.of( + ExceptionUtils.uncheckedFactory( + () -> Delegate.class.getDeclaredMethod("apply", int.class, int.class)), + new ByteBuddy()); + + class Delegate { + Integer apply(int a, int b) { + return a + b; + } + } + + @Override + public Integer apply(int a, int b) { + return invoker.invoke(new Delegate(), new Object[] {a, b}); + } + }; + assertEquals(3, (int) s.apply(1, 2)); + } + + void testMethodInvokerWith( + Factory instanceFactory, UnaryFunction valueFactory, Class paramType) + throws NoSuchMethodException, + InvocationTargetException, + InstantiationException, + IllegalAccessException { + + ByteBuddy buddy = new ByteBuddy(); + T s = instanceFactory.apply(); + Method method = s.getClass().getDeclaredMethod("apply", paramType, paramType); + MethodInvoker invoker = MethodInvoker.of(method, buddy); + assertEquals( + valueFactory.apply(3), + invoker.invoke( + instanceFactory.apply(), new Object[] {valueFactory.apply(1), valueFactory.apply(2)})); + + long start = System.nanoTime(); + for (int i = 0; i < 1_000_000; i++) { + invoker.invoke(s, new Object[] {valueFactory.apply(i), valueFactory.apply(i)}); + } + long duration = System.nanoTime() - start; + assertTrue(duration < 1_000_000_000); + } + + private void testBinder(StateBinder binder) { + List> specs = + Arrays.asList( + StateSpecs.bag(), + StateSpecs.value(), + StateSpecs.map(), + StateSpecs.multimap(), + StateSpecs.combining(org.apache.beam.sdk.transforms.Sum.ofIntegers()), + StateSpecs.orderedList(VarIntCoder.of())); + specs.forEach(s -> testBinder(s, binder)); + } + + private void testBinder(StateSpec s, StateBinder binder) { + s.bind("dummy", binder); + } + + public static class Sum { + public Integer apply(int a, int b) { + return a + b; + } + } + + public static class Sum2 { + public Long apply(Long a, Long b) { + return a + b; + } + } + + public static class SumCollect { + public void apply(int a, int b, Consumer result) { + result.accept(a + b); + } + } +} diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/TypeIdTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/TypeIdTest.java new file mode 100644 index 000000000..0f0fe2de0 --- /dev/null +++ b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/TypeIdTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import cz.o2.proxima.core.util.Optionals; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.util.Arrays; +import net.bytebuddy.description.annotation.AnnotationDescription; +import net.bytebuddy.description.type.TypeDescription; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.values.KV; +import org.junit.Test; + +public class TypeIdTest { + + @Test + public void testAnnotations() { + DoFn doFn = + new DoFn() { + @ProcessElement + public void process( + @Element KV elem, + @StateId("state") ValueState s, + OutputReceiver out) {} + }; + + Method process = + Optionals.get( + Arrays.stream(doFn.getClass().getDeclaredMethods()) + .filter(m -> m.getName().equals("process")) + .findAny()); + Type[] parameterTypes = process.getGenericParameterTypes(); + Annotation[][] annotations = process.getParameterAnnotations(); + assertEqualType( + TypeId.of(annotations[0][0]), + TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Element.class).build())); + assertEqualType( + TypeId.of(annotations[1][0]), + TypeId.of( + AnnotationDescription.Builder.ofType(DoFn.StateId.class) + .define("value", "state") + .build())); + + assertNotEqualType( + TypeId.of(annotations[1][0]), + TypeId.of( + AnnotationDescription.Builder.ofType(DoFn.StateId.class) + .define("value", "state2") + .build())); + + assertEqualType( + TypeId.of(parameterTypes[2]), + TypeId.of( + TypeDescription.Generic.Builder.parameterizedType(OutputReceiver.class, String.class) + .build())); + } + + private void assertNotEqualType(TypeId first, TypeId second) { + assertNotEquals(first, second); + assertNotEquals(first.hashCode(), second.hashCode()); + } + + private void assertEqualType(TypeId first, TypeId second) { + assertEquals(first, second); + assertEquals(first.hashCode(), second.hashCode()); + } +} diff --git a/core/src/main/java/cz/o2/proxima/core/util/SerializableScopedValue.java b/core/src/main/java/cz/o2/proxima/core/util/SerializableScopedValue.java index cb4708da3..831b4e0e8 100644 --- a/core/src/main/java/cz/o2/proxima/core/util/SerializableScopedValue.java +++ b/core/src/main/java/cz/o2/proxima/core/util/SerializableScopedValue.java @@ -23,6 +23,7 @@ import java.util.Objects; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import javax.annotation.Nullable; /** * A value that holds a {@link Serializable} value and scopes its value to given context. @@ -35,20 +36,25 @@ public final class SerializableScopedValue implements Serializable { private static final Map> VALUE_MAP = new ConcurrentHashMap<>(); private final String uuid = UUID.randomUUID().toString(); - private final Factory factory; + private final @Nullable Factory factory; public SerializableScopedValue(Factory what) { this.factory = Objects.requireNonNull(what); VALUE_MAP.putIfAbsent(uuid, new ConcurrentHashMap<>()); } + public SerializableScopedValue(C context, V value) { + this.factory = null; + VALUE_MAP.compute(uuid, (k, v) -> new ConcurrentHashMap<>()).put(context, value); + } + @SuppressWarnings("unchecked") public V get(C context) { return (V) VALUE_MAP.get(uuid).computeIfAbsent(context, t -> cloneOriginal()); } private V cloneOriginal() { - return factory.apply(); + return Objects.requireNonNull(factory).apply(); } /** diff --git a/core/src/test/java/cz/o2/proxima/core/util/SerializableScopedValueTest.java b/core/src/test/java/cz/o2/proxima/core/util/SerializableScopedValueTest.java index 87e2da830..483f4519f 100644 --- a/core/src/test/java/cz/o2/proxima/core/util/SerializableScopedValueTest.java +++ b/core/src/test/java/cz/o2/proxima/core/util/SerializableScopedValueTest.java @@ -33,6 +33,14 @@ public void testSerializable() throws IOException, ClassNotFoundException { TestUtils.assertHashCodeAndEquals(value, other); } + @Test + public void testSerializableWithValue() throws IOException, ClassNotFoundException { + SerializableScopedValue value = new SerializableScopedValue<>(1, 2); + SerializableScopedValue other = TestUtils.assertSerializable(value); + TestUtils.assertHashCodeAndEquals(value, other); + assertEquals(value.get(1), other.get(1)); + } + @Test public void testContextLocality() throws IOException, ClassNotFoundException { BlockingQueue results = new LinkedBlockingDeque<>();