diff --git a/beam/core/build.gradle b/beam/core/build.gradle index a2f754b2a..a533ea8b3 100644 --- a/beam/core/build.gradle +++ b/beam/core/build.gradle @@ -24,6 +24,8 @@ dependencies { api libraries.beam_core implementation "cz.o2.proxima:proxima-vendor:${project.version}" implementation libraries.beam_extensions_kryo + provided libraries.beam_runners_flink + provided libraries.beam_runners_spark testImplementation project(path: ':proxima-core', configuration: 'testsJar') testImplementation project(path: ':proxima-core') testImplementation project(path: ':proxima-direct-core', configuration: 'testsJar') @@ -32,7 +34,6 @@ dependencies { testImplementation project(path: ':proxima-direct-io-kafka', configuration: 'testsJar') testImplementation project(path: ':proxima-scheme-proto-testing') testImplementation libraries.beam_runners_direct - testImplementation libraries.beam_runners_flink testImplementation libraries.beam_sql testImplementation libraries.junit4 testImplementation libraries.hamcrest @@ -51,4 +52,8 @@ protobuf { protoc { artifact = libraries.protoc } } +test { + jvmArgs '-Dsun.io.serialization.extendedDebugInfo=true' +} + publishArtifacts(project, "default") diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/RunnerUtils.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/RunnerUtils.java new file mode 100644 index 000000000..2dc8747f6 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/RunnerUtils.java @@ -0,0 +1,117 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package cz.o2.proxima.beam.util;
+
+import cz.o2.proxima.core.annotations.Internal;
+import cz.o2.proxima.core.util.ExceptionUtils;
+import java.io.ByteArrayInputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.jar.JarEntry;
+import java.util.jar.JarOutputStream;
+import java.util.stream.Collectors;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.beam.repackaged.core.org.apache.commons.compress.utils.IOUtils;
+import org.apache.beam.runners.flink.FlinkPipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+
+/** Utilities for making additional jars available to the runner that executes a pipeline. */
+@Internal
+@Slf4j
+public class RunnerUtils {
+
+  /**
+   * Register given set of jars to runner.
+   *
+   * <p>For {@code FlinkRunner} and {@code SparkRunner} the absolute paths of the jars are merged
+   * into the runner's files-to-stage option. For any other runner the jars are injected into the
+   * context class loader of the calling thread; a warning is logged when that runner is not
+   * {@code DirectRunner}.
+   *
+   * @param opts pipeline options that are updated for Flink and Spark
+   * @param runnerName simple name of the runner, must not be {@code null}
+   * @param paths jars to register
+   */
+  public static void registerToPipeline(
+      PipelineOptions opts, String runnerName, Collection<File> paths) {
+    log.info("Adding jars {} into classpath for runner {}", paths, runnerName);
+    List<String> filesToStage =
+        paths.stream().map(File::getAbsolutePath).collect(Collectors.toList());
+    if (runnerName.equals("FlinkRunner")) {
+      FlinkPipelineOptions flinkOpts = opts.as(FlinkPipelineOptions.class);
+      flinkOpts.setFilesToStage(addToList(filesToStage, flinkOpts.getFilesToStage()));
+    } else if (runnerName.equals("SparkRunner")) {
+      SparkCommonPipelineOptions sparkOpts = opts.as(SparkCommonPipelineOptions.class);
+      sparkOpts.setFilesToStage(addToList(filesToStage, sparkOpts.getFilesToStage()));
+    } else {
+      if (!runnerName.equals("DirectRunner")) {
+        log.warn(
+            "Injecting jar into unknown runner {}. It might not work as expected. "
+                + "If you are experiencing issues with running and/or job submission, "
+                + "please fill github issue reporting the name of the runner.",
+            runnerName);
+      }
+      injectJarIntoContextClassLoader(paths);
+    }
+  }
+
+  /**
+   * Inject given paths to class loader of given (local) runner.
+   *
+   * <p>The context class loader of the current thread is replaced by a {@link URLClassLoader}
+   * that serves the given paths and delegates to the previous context class loader.
+   *
+   * @param paths jars to make loadable
+   */
+  public static void injectJarIntoContextClassLoader(Collection<File> paths) {
+    ClassLoader loader = Thread.currentThread().getContextClassLoader();
+    URL[] urls =
+        paths.stream()
+            .map(p -> ExceptionUtils.uncheckedFactory(() -> p.toURI().toURL()))
+            .toArray(URL[]::new);
+    Thread.currentThread().setContextClassLoader(new URLClassLoader(urls, loader));
+  }
+
+  /**
+   * Generate jar from provided map of dynamic classes.
+   *
+   * <p>The jar is written to a temporary file that is marked for deletion on JVM exit. Every
+   * class is stored under its binary name with dots replaced by slashes and a {@code .class}
+   * suffix.
+   *
+   * @param classes map of class to bytecode
+   * @return generated {@link File}
+   * @throws IOException when the temporary file cannot be created or written
+   */
+  public static File createJarFromDynamicClasses(Map<Class<?>, byte[]> classes)
+      throws IOException {
+    File out = File.createTempFile("proxima-beam-dynamic", ".jar");
+    out.deleteOnExit();
+    // both streams are managed resources, so the file handle is released even if wrapping fails
+    try (FileOutputStream fileOutput = new FileOutputStream(out);
+        JarOutputStream output = new JarOutputStream(fileOutput)) {
+      long now = System.currentTimeMillis();
+      for (Map.Entry<Class<?>, byte[]> e : classes.entrySet()) {
+        String name = e.getKey().getName().replace('.', '/') + ".class";
+        JarEntry entry = new JarEntry(name);
+        entry.setTime(now);
+        output.putNextEntry(entry);
+        InputStream input = new ByteArrayInputStream(e.getValue());
+        IOUtils.copy(input, output);
+        output.closeEntry();
+      }
+    }
+    return out;
+  }
+
+  /**
+   * Merge two lists into one without duplicates.
+   *
+   * <p>Elements keep the order in which they are first seen: all of {@code first}, followed by
+   * those elements of {@code second} that were not already present.
+   *
+   * @param first list whose elements come first in the result
+   * @param second optional list to append, {@code null} is treated as empty
+   * @return new mutable list holding the distinct elements of both inputs
+   */
+  private static List<String> addToList(
+      @Nonnull List<String> first, @Nullable List<String> second) {
+    Collection<String> seen = new HashSet<>();
+    List<String> res = new ArrayList<>();
+    for (String item : first) {
+      if (seen.add(item)) {
+        res.add(item);
+      }
+    }
+    if (second != null) {
+      for (String item : second) {
+        if (seen.add(item)) {
+          res.add(item);
+        }
+      }
+    }
+    return res;
+  }
+
+  private RunnerUtils() {}
+}
diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ClassCollector.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ClassCollector.java
new file mode 100644
index 000000000..478475807
--- /dev/null
+++ 
b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ClassCollector.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2017-2024 O2 Czech Republic, a.s.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package cz.o2.proxima.beam.util.state;
+
+/** Callback that receives classes generated at runtime together with their bytecode. */
+interface ClassCollector {
+
+  /**
+   * Record a generated class so that it can be dispatched to runners.
+   *
+   * @param cls the generated class
+   * @param byteCode bytecode that defines {@code cls}
+   */
+  void collect(Class cls, byte[] byteCode);
+}
diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExpandContext.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExpandContext.java
new file mode 100644
index 000000000..e6976f66a
--- /dev/null
+++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExpandContext.java
@@ -0,0 +1,1015 @@
+/*
+ * Copyright 2017-2024 O2 Czech Republic, a.s.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package cz.o2.proxima.beam.util.state; + +import static cz.o2.proxima.beam.util.state.MethodCallUtils.getStateReaders; +import static cz.o2.proxima.beam.util.state.MethodCallUtils.getWrapperInputType; + +import cz.o2.proxima.beam.util.state.MethodCallUtils.VoidMethodInvoker; +import cz.o2.proxima.core.functional.BiFunction; +import cz.o2.proxima.core.functional.UnaryFunction; +import cz.o2.proxima.core.util.ExceptionUtils; +import cz.o2.proxima.core.util.Pair; +import cz.o2.proxima.internal.com.google.common.base.MoreObjects; +import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import cz.o2.proxima.internal.com.google.common.collect.Iterables; +import java.io.Serializable; +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.annotation.AnnotationDescription; +import net.bytebuddy.description.modifier.FieldManifestation; +import net.bytebuddy.description.modifier.Visibility; +import net.bytebuddy.description.type.TypeDefinition; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.description.type.TypeDescription.Generic; +import net.bytebuddy.dynamic.DynamicType.Builder; +import net.bytebuddy.dynamic.DynamicType.Builder.MethodDefinition; +import net.bytebuddy.dynamic.DynamicType.Loaded; +import 
net.bytebuddy.dynamic.DynamicType.Unloaded; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.FieldAccessor; +import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.implementation.Implementation.Composable; +import net.bytebuddy.implementation.MethodCall; +import net.bytebuddy.implementation.MethodDelegation; +import net.bytebuddy.implementation.bind.annotation.AllArguments; +import net.bytebuddy.implementation.bind.annotation.RuntimeType; +import net.bytebuddy.implementation.bind.annotation.This; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.runners.PTransformOverride; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement; +import org.apache.beam.sdk.runners.TransformHierarchy; +import org.apache.beam.sdk.state.BagState; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.StateId; +import org.apache.beam.sdk.transforms.DoFn.TimerId; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.WithKeys; +import 
org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.ByteBuddyUtils; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.construction.ReplacementOutputs; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TimestampedValue.TimestampedValueCoder; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.reflect.TypeToken; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; + +@Slf4j +class ExpandContext { + + static final String EXPANDER_BUF_STATE_SPEC = "expanderBufStateSpec"; + static final String EXPANDER_BUF_STATE_NAME = "_expanderBuf"; + static final String EXPANDER_FLUSH_STATE_SPEC = "expanderFlushStateSpec"; + static final String EXPANDER_FLUSH_STATE_NAME = "_expanderFlush"; + static final String EXPANDER_TIMER_SPEC = "expanderTimerSpec"; + static final String EXPANDER_TIMER_NAME = "_expanderTimer"; + static final String DELEGATE_FIELD_NAME = "delegate"; + static final String PROCESS_ELEMENT_INTERCEPTOR_FIELD_NAME = "__processElementInterceptor"; + static final String ON_WINDOW_INTERCEPTOR_FIELD_NAME = "__onWindowInterceptor"; + static final String FLUSH_TIMER_INTERCEPTOR_FIELD_NAME = "__flushTimerInterceptor"; + + static final TupleTag 
STATE_TUPLE_TAG = new StateTupleTag() {}; + + private static class StateTupleTag extends TupleTag {} + + @Getter + class DoFnExpandContext { + + private final DoFn, ?> doFn; + private final TupleTag mainTag; + private final KvCoder inputCoder; + private final Coder keyCoder; + + private final Class, ?>> doFnClass; + + private final ParameterizedType inputType; + private final Type outputType; + private final Generic doFnGeneric; + + private final Method processElement; + private final @Nullable Method onWindowMethod; + + private final FlushTimerParameterExpander flushExpander; + private final OnWindowParameterExpander onWindowExpander; + private final ProcessElementParameterExpander processElementExpander; + + private final FlushTimerInterceptor flushTimerInterceptor; + private final OnWindowExpirationInterceptor onWindowExpirationInterceptor; + private final ProcessElementInterceptor processElementInterceptor; + + @SuppressWarnings("unchecked") + DoFnExpandContext(DoFn, ?> doFn, KvCoder inputCoder, TupleTag mainTag) + throws InvocationTargetException, + NoSuchMethodException, + InstantiationException, + IllegalAccessException { + + this.doFn = doFn; + this.inputCoder = inputCoder; + this.mainTag = mainTag; + this.keyCoder = inputCoder.getKeyCoder(); + this.doFnClass = (Class, ?>>) doFn.getClass(); + this.processElement = findMethod(doFn, DoFn.ProcessElement.class); + this.onWindowMethod = findMethod(doFn, DoFn.OnWindowExpiration.class); + + ParameterizedType parameterizedSuperClass = getParameterizedDoFn(doFnClass); + this.inputType = (ParameterizedType) parameterizedSuperClass.getActualTypeArguments()[0]; + Preconditions.checkArgument( + inputType.getRawType().equals(KV.class), + "Input type to stateful DoFn must be KV, go %s", + inputType); + + this.outputType = parameterizedSuperClass.getActualTypeArguments()[1]; + Generic wrapperInput = getWrapperInputType(inputType); + + this.doFnGeneric = + Generic.Builder.parameterizedType( + 
TypeDescription.ForLoadedType.of(DoFn.class), + wrapperInput, + TypeDescription.Generic.Builder.of(outputType).build()) + .build(); + + this.flushExpander = + FlushTimerParameterExpander.of(doFn, inputType, processElement, mainTag, outputType); + this.onWindowExpander = + OnWindowParameterExpander.of( + inputType, processElement, onWindowMethod, mainTag, outputType); + this.processElementExpander = + ProcessElementParameterExpander.of( + doFn, processElement, inputType, mainTag, outputType, stateWriteInstant); + + this.flushTimerInterceptor = + new FlushTimerInterceptor<>( + doFn, + processElement, + flushExpander, + keyCoder, + STATE_TUPLE_TAG, + nextFlushInstantFn, + buddy, + collector); + VoidMethodInvoker processElementInvoker = + VoidMethodInvoker.of(processElement, buddy, collector); + @Nullable VoidMethodInvoker onWindowInvoker = + onWindowMethod == null ? null : VoidMethodInvoker.of(onWindowMethod, buddy, collector); + this.onWindowExpirationInterceptor = + new OnWindowExpirationInterceptor<>( + processElementInvoker, onWindowInvoker, onWindowExpander); + this.processElementInterceptor = + new ProcessElementInterceptor<>(processElementExpander, processElementInvoker); + } + } + + private final ByteBuddy buddy = new ByteBuddy(); + private final Map, byte[]> generatedClasses = new HashMap<>(); + private final ClassCollector collector = generatedClasses::put; + + private final PTransform>> inputs; + private final Instant stateWriteInstant; + private final UnaryFunction nextFlushInstantFn; + private final PTransform>, PDone> stateSink; + + public ExpandContext( + PTransform>> inputs, + Instant stateWriteInstant, + UnaryFunction nextFlushInstantFn, + PTransform>, PDone> stateSink) { + + this.inputs = inputs; + this.stateWriteInstant = stateWriteInstant; + this.nextFlushInstantFn = nextFlushInstantFn; + this.stateSink = stateSink; + } + + Pipeline expand(Pipeline pipeline) { + validatePipeline(pipeline); + // collect generated classes + 
pipeline.getCoderRegistry().registerCoderForClass(StateValue.class, StateValue.coder()); + PCollection> inputsMaterialized = pipeline.apply(inputs); + // replace all MultiParDos + pipeline.replaceAll(Collections.singletonList(statefulParMultiDoOverride(inputsMaterialized))); + // collect all StateValues + List>> stateValues = new ArrayList<>(); + + pipeline.traverseTopologically( + new PipelineVisitor.Defaults() { + @Override + public void visitPrimitiveTransform(TransformHierarchy.Node node) { + if (node.getTransform() instanceof ParDo.MultiOutput) { + node.getOutputs().entrySet().stream() + .filter(e -> e.getKey() instanceof StateTupleTag) + .map(Entry::getValue) + .findAny() + .ifPresent( + p -> + stateValues.add( + Pair.of(node.getFullName(), (PCollection) p))); + } + } + }); + if (!stateValues.isEmpty()) { + PCollectionList> list = PCollectionList.empty(pipeline); + for (Pair> p : stateValues) { + PCollection> mapped = p.getSecond().apply(WithKeys.of(p.getFirst())); + list = list.and(mapped); + } + list.apply(Flatten.pCollections()).apply(stateSink); + } + return pipeline; + } + + Map, byte[]> getGeneratedClasses() { + return generatedClasses; + } + + private static void validatePipeline(org.apache.beam.sdk.Pipeline pipeline) { + // check that all nodes have unique names + Set names = new HashSet<>(); + pipeline.traverseTopologically( + new PipelineVisitor.Defaults() { + + @Override + public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) { + Preconditions.checkState(names.add(node.getFullName())); + return CompositeBehavior.ENTER_TRANSFORM; + } + + @Override + public void visitPrimitiveTransform(TransformHierarchy.Node node) { + Preconditions.checkState(names.add(node.getFullName())); + } + }); + } + + private PTransformOverride statefulParMultiDoOverride( + PCollection> inputs) { + + return PTransformOverride.of( + application -> application.getTransform() instanceof ParDo.MultiOutput, + parMultiDoReplacementFactory(inputs)); + } + + 
@SuppressWarnings({"unchecked", "rawtypes"}) + private PTransformOverrideFactory parMultiDoReplacementFactory( + PCollection> inputs) { + + return new PTransformOverrideFactory<>() { + @Override + public PTransformReplacement getReplacementTransform(AppliedPTransform transform) { + return replaceParMultiDo(transform, inputs); + } + + @SuppressWarnings("unchecked") + @Override + public Map, ReplacementOutput> mapOutputs(Map outputs, POutput newOutput) { + return ReplacementOutputs.tagged(outputs, newOutput); + } + }; + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private PTransformReplacement replaceParMultiDo( + AppliedPTransform transform, PCollection> inputs) { + + ParDo.MultiOutput rawTransform = + (ParDo.MultiOutput) (PTransform) transform.getTransform(); + DoFn, ?> doFn = (DoFn) rawTransform.getFn(); + PInput pMainInput = getMainInput(transform); + if (!DoFnSignatures.isStateful(doFn)) { + return PTransformReplacement.of(pMainInput, (PTransform) transform.getTransform()); + } + String transformName = transform.getFullName(); + PCollection transformInputs = + inputs + .apply(Filter.by(kv -> kv.getKey().equals(transformName))) + .apply(MapElements.into(TypeDescriptor.of(StateValue.class)).via(KV::getValue)); + TupleTag mainOutputTag = (TupleTag) rawTransform.getMainOutputTag(); + return PTransformReplacement.of( + pMainInput, + transformedParDo( + transformInputs, + (DoFn) doFn, + mainOutputTag, + TupleTagList.of( + transform.getOutputs().keySet().stream() + .filter(t -> !t.equals(mainOutputTag)) + .collect(Collectors.toList())))); + } + + @SuppressWarnings("unchecked") + private > + PTransform, PCollectionTuple> transformedParDo( + PCollection transformInputs, + DoFn, ?> doFn, + TupleTag mainOutputTag, + TupleTagList otherOutputs) { + + return new PTransform<>() { + @Override + public PCollectionTuple expand(PCollection input) { + @SuppressWarnings("unchecked") + KvCoder coder = (KvCoder) input.getCoder(); + Coder keyCoder = coder.getKeyCoder(); + Coder 
valueCoder = coder.getValueCoder(); + TypeDescriptor> valueDescriptor = + new TypeDescriptor<>(new TypeToken>() {}) {}; + PCollection>> state = + transformInputs + .apply( + MapElements.into( + TypeDescriptors.kvs( + keyCoder.getEncodedTypeDescriptor(), valueDescriptor)) + .via( + e -> + ExceptionUtils.uncheckedFactory( + () -> + KV.of( + CoderUtils.decodeFromByteArray(keyCoder, e.getKey()), + StateOrInput.state(e))))) + .setCoder(KvCoder.of(keyCoder, StateOrInput.coder(valueCoder))); + PCollection>> inputs = + input + .apply( + MapElements.into( + TypeDescriptors.kvs( + keyCoder.getEncodedTypeDescriptor(), valueDescriptor)) + .via(e -> KV.of(e.getKey(), StateOrInput.input(e.getValue())))) + .setCoder(KvCoder.of(keyCoder, StateOrInput.coder(valueCoder))); + PCollection>> flattened = + PCollectionList.of(state).and(inputs).apply(Flatten.pCollections()); + // FIXME: add name? + PCollectionTuple tuple = + flattened.apply( + ParDo.of(transformedDoFn(doFn, (KvCoder) input.getCoder(), mainOutputTag)) + .withOutputTags(mainOutputTag, otherOutputs.and(STATE_TUPLE_TAG))); + PCollectionTuple res = PCollectionTuple.empty(input.getPipeline()); + for (Entry, PCollection> e : + (Set, PCollection>>) (Set) tuple.getAll().entrySet()) { + if (!e.getKey().equals(STATE_TUPLE_TAG)) { + res = res.and(e.getKey(), e.getValue()); + } + } + return res; + } + }; + } + + private >> DoFn transformedDoFn( + DoFn, ?> doFn, KvCoder inputCoder, TupleTag mainTag) { + + DoFnExpandContext context = + ExceptionUtils.uncheckedFactory(() -> new DoFnExpandContext<>(doFn, inputCoder, mainTag)); + + Class, ?>> doFnClass = context.getDoFnClass(); + ClassLoadingStrategy strategy = ByteBuddyUtils.getClassLoadingStrategy(doFnClass); + final String className = doFnClass.getName() + "$Expanded"; + final ClassLoader classLoader = ExternalStateExpander.class.getClassLoader(); + try { + @SuppressWarnings("unchecked") + Class> aClass = + (Class>) classLoader.loadClass(className); + // class found, return instance + 
return newInstance(aClass, context); + } catch (ClassNotFoundException e) { + // class not found, create it + } + + @SuppressWarnings("unchecked") + Builder> builder = + (Builder>) + buddy.subclass(context.getDoFnGeneric()).name(className).implement(DoFnProvider.class); + + ParameterizedType inputType = context.getInputType(); + builder = defineInvokerFields(doFnClass, inputType, builder); + builder = addStateAndTimers(doFnClass, inputType, builder); + builder = + builder + .defineConstructor(Visibility.PUBLIC) + .withParameters( + doFnClass, + context.getFlushTimerInterceptor().getClass(), + context.getOnWindowExpirationInterceptor().getClass(), + context.getProcessElementInterceptor().getClass()) + .intercept( + addStateAndTimerValues( + doFn, + inputCoder, + MethodCall.invoke( + ExceptionUtils.uncheckedFactory(() -> DoFn.class.getConstructor())) + .andThen(FieldAccessor.ofField(DELEGATE_FIELD_NAME).setsArgumentAt(0)) + .andThen( + FieldAccessor.ofField(FLUSH_TIMER_INTERCEPTOR_FIELD_NAME) + .setsArgumentAt(1)) + .andThen( + FieldAccessor.ofField(ON_WINDOW_INTERCEPTOR_FIELD_NAME) + .setsArgumentAt(2)) + .andThen( + FieldAccessor.ofField(PROCESS_ELEMENT_INTERCEPTOR_FIELD_NAME) + .setsArgumentAt(3)))); + + builder = addProcessingMethods(context, builder); + builder = implementDoFnProvider(builder); + + Unloaded> dynamicClass = builder.make(); + Loaded> unloaded = dynamicClass.load(null, strategy); + Class> cls = unloaded.getLoaded(); + collector.collect(cls, unloaded.getBytes()); + return newInstance(cls, context); + } + + private >> + Builder> implementDoFnProvider(Builder> builder) { + + return builder + .defineMethod("getDoFn", DoFn.class, Visibility.PUBLIC) + .intercept(FieldAccessor.ofField(DELEGATE_FIELD_NAME)); + } + + private >, V, K> DoFn newInstance( + Class> cls, DoFnExpandContext context) { + + try { + Constructor> ctor = + cls.getDeclaredConstructor( + context.getDoFnClass(), + context.getFlushTimerInterceptor().getClass(), + 
context.getOnWindowExpirationInterceptor().getClass(), + context.getProcessElementInterceptor().getClass()); + @SuppressWarnings("unchecked") + DoFn instance = + (DoFn) + ctor.newInstance( + context.getDoFn(), + context.getFlushTimerInterceptor(), + context.getOnWindowExpirationInterceptor(), + context.getProcessElementInterceptor()); + return instance; + } catch (Exception ex) { + throw new IllegalStateException(String.format("Cannot instantiate class %s", cls), ex); + } + } + + private >> + Builder> defineInvokerFields( + Class, ?>> doFnClass, + ParameterizedType inputType, + Builder> builder) { + + int privateFinal = Visibility.PRIVATE.getMask() + FieldManifestation.FINAL.getMask(); + Type keyType = inputType.getActualTypeArguments()[0]; + Type valueType = inputType.getActualTypeArguments()[1]; + Generic processInterceptor = + Generic.Builder.parameterizedType(ProcessElementInterceptor.class, keyType, valueType) + .build(); + Generic onWindowInterceptor = + Generic.Builder.parameterizedType(OnWindowExpirationInterceptor.class, keyType, valueType) + .build(); + Generic flushTimerInterceptor = + Generic.Builder.parameterizedType(FlushTimerInterceptor.class, keyType, valueType).build(); + + return builder + .defineField(DELEGATE_FIELD_NAME, doFnClass, privateFinal) + .defineField(PROCESS_ELEMENT_INTERCEPTOR_FIELD_NAME, processInterceptor, privateFinal) + .defineField(ON_WINDOW_INTERCEPTOR_FIELD_NAME, onWindowInterceptor, privateFinal) + .defineField(FLUSH_TIMER_INTERCEPTOR_FIELD_NAME, flushTimerInterceptor, privateFinal); + } + + private static , OutputT> Implementation addStateAndTimerValues( + DoFn doFn, Coder> inputCoder, Composable delegate) { + + List> acceptable = Arrays.asList(StateId.class, TimerId.class); + @SuppressWarnings("unchecked") + Class> doFnClass = + (Class>) doFn.getClass(); + for (Field f : doFnClass.getDeclaredFields()) { + if (!Modifier.isStatic(f.getModifiers()) + && acceptable.stream().anyMatch(a -> f.getAnnotation(a) != null)) { + 
f.setAccessible(true);
+        Object value = ExceptionUtils.uncheckedFactory(() -> f.get(doFn));
+        delegate = delegate.andThen(FieldAccessor.ofField(f.getName()).setsValue(value));
+      }
+    }
+    delegate =
+        delegate
+            .andThen(
+                FieldAccessor.ofField(EXPANDER_BUF_STATE_SPEC)
+                    .setsValue(StateSpecs.bag(TimestampedValueCoder.of(inputCoder))))
+            .andThen(FieldAccessor.ofField(EXPANDER_FLUSH_STATE_SPEC).setsValue(StateSpecs.value()))
+            .andThen(
+                FieldAccessor.ofField(EXPANDER_TIMER_SPEC)
+                    .setsValue(TimerSpecs.timer(TimeDomain.EVENT_TIME)));
+    return delegate;
+  }
+
+  /**
+   * Find the first parameterized superclass on the way up the hierarchy of given DoFn class.
+   *
+   * <p>Superclasses that are declared without type arguments are skipped as long as they are
+   * still {@link DoFn}s, so a non-generic subclass of a concrete DoFn resolves to the type
+   * arguments declared by its ancestor. The returned type is whatever parameterized superclass is
+   * met first, which is not necessarily {@link DoFn} itself.
+   *
+   * @param doFnClass class to inspect
+   * @return the first parameterized generic superclass
+   * @throws IllegalStateException when the walk leaves the {@link DoFn} hierarchy without finding
+   *     a parameterized superclass (e.g. the class extends raw {@code DoFn})
+   */
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private static <InputT, OutputT> ParameterizedType getParameterizedDoFn(
+      Class<? extends DoFn<InputT, OutputT>> doFnClass) {
+
+    Type type = doFnClass.getGenericSuperclass();
+    if (type instanceof ParameterizedType) {
+      return (ParameterizedType) type;
+    }
+    Class superClass = doFnClass.getSuperclass();
+    // Continue only while the superclass is DoFn or one of its subclasses. The receiver of
+    // isAssignableFrom must be DoFn.class; the reversed check matched only DoFn and Object,
+    // rejecting valid hierarchies and running into a null superclass at the root.
+    if (superClass != null && DoFn.class.isAssignableFrom(superClass)) {
+      return getParameterizedDoFn((Class) superClass);
+    }
+    throw new IllegalStateException("Cannot get parameterized type of " + doFnClass);
+  }
+
+  private >>
+      Builder> addProcessingMethods(
+          DoFnExpandContext context, Builder> builder) {
+
+    DoFn, ?> doFn = context.getDoFn();
+    builder = addProcessingMethod(doFn, DoFn.Setup.class, builder);
+    builder = addProcessingMethod(doFn, DoFn.StartBundle.class, builder);
+    builder = addProcessElementMethod(context, builder);
+    builder = addOnWindowExpirationMethod(context, builder);
+    builder = addTimerFlushMethod(context, builder);
+    builder = addProcessingMethod(doFn, DoFn.FinishBundle.class, builder);
+    builder = addProcessingMethod(doFn, DoFn.Teardown.class, builder);
+    builder = addProcessingMethod(doFn, DoFn.GetInitialRestriction.class, builder);
+    builder = addProcessingMethod(doFn, DoFn.SplitRestriction.class, builder);
+    builder = addProcessingMethod(doFn, DoFn.GetRestrictionCoder.class, builder);
+    builder = addProcessingMethod(doFn, DoFn.GetWatermarkEstimatorStateCoder.class, builder);
+    builder =
addProcessingMethod(doFn, DoFn.GetInitialWatermarkEstimatorState.class, builder); + builder = addProcessingMethod(doFn, DoFn.NewWatermarkEstimator.class, builder); + builder = addProcessingMethod(doFn, DoFn.NewTracker.class, builder); + builder = addProcessingMethod(doFn, DoFn.OnTimer.class, builder); + + return builder; + } + + private >> + Builder> addProcessElementMethod( + DoFnExpandContext context, Builder> builder) { + + Method method = Objects.requireNonNull(context.getProcessElement()); + ProcessElementParameterExpander expander = context.getProcessElementExpander(); + List> wrapperArgs = expander.getWrapperArgs(); + Preconditions.checkArgument(void.class.isAssignableFrom(method.getReturnType())); + MethodDefinition> methodDefinition = + builder + .defineMethod(method.getName(), void.class, Visibility.PUBLIC) + .withParameters(wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) + .intercept(MethodDelegation.toField(PROCESS_ELEMENT_INTERCEPTOR_FIELD_NAME)); + + for (int i = 0; i < wrapperArgs.size(); i++) { + Pair arg = wrapperArgs.get(i); + if (arg.getFirst() != null) { + methodDefinition = methodDefinition.annotateParameter(i, arg.getFirst()); + } + } + return methodDefinition.annotateMethod(method.getDeclaredAnnotations()); + } + + private >> + Builder> addOnWindowExpirationMethod( + DoFnExpandContext context, Builder> builder) { + + Class annotation = DoFn.OnWindowExpiration.class; + @Nullable Method onWindowExpirationMethod = context.getOnWindowMethod(); + OnWindowParameterExpander expander = context.getOnWindowExpander(); + List> wrapperArgs = expander.getWrapperArgs(); + MethodDefinition> methodDefinition = + builder + .defineMethod( + onWindowExpirationMethod == null + ? 
"_onWindowExpiration" + : onWindowExpirationMethod.getName(), + void.class, + Visibility.PUBLIC) + .withParameters(wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) + .intercept(MethodDelegation.toField(ON_WINDOW_INTERCEPTOR_FIELD_NAME)); + + // retrieve parameter annotations and apply them + for (int i = 0; i < wrapperArgs.size(); i++) { + AnnotationDescription ann = wrapperArgs.get(i).getFirst(); + if (ann != null) { + methodDefinition = methodDefinition.annotateParameter(i, ann); + } + } + return methodDefinition.annotateMethod( + AnnotationDescription.Builder.ofType(annotation).build()); + } + + private >> + Builder> addTimerFlushMethod( + DoFnExpandContext context, Builder> builder) { + + FlushTimerParameterExpander expander = context.getFlushExpander(); + List> wrapperArgs = expander.getWrapperArgs(); + MethodDefinition> methodDefinition = + builder + .defineMethod("expanderFlushTimer", void.class, Visibility.PUBLIC) + .withParameters(wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) + .intercept(MethodDelegation.toField(FLUSH_TIMER_INTERCEPTOR_FIELD_NAME)); + int i = 0; + for (Pair p : wrapperArgs) { + if (p.getFirst() != null) { + methodDefinition = methodDefinition.annotateParameter(i, p.getFirst()); + } + i++; + } + return methodDefinition.annotateMethod( + AnnotationDescription.Builder.ofType(DoFn.OnTimer.class) + .define("value", EXPANDER_TIMER_NAME) + .build()); + } + + private static @Nullable Method findMethod( + DoFn, OutputT> doFn, Class annotation) { + + return Iterables.getOnlyElement( + Arrays.stream(doFn.getClass().getMethods()) + .filter(m -> m.getAnnotation(annotation) != null) + .collect(Collectors.toList()), + null); + } + + private static >, T extends Annotation> + Builder> addProcessingMethod( + DoFn, ?> doFn, Class annotation, Builder> builder) { + + Method method = findMethod(doFn, annotation); + if (method != null) { + MethodDefinition> methodDefinition = + builder + 
.defineMethod(method.getName(), method.getReturnType(), Visibility.PUBLIC) + .withParameters(method.getGenericParameterTypes()) + .intercept(MethodCall.invoke(method).onField(DELEGATE_FIELD_NAME).withAllArguments()); + + // retrieve parameter annotations and apply them + Annotation[][] parameterAnnotations = method.getParameterAnnotations(); + for (int i = 0; i < parameterAnnotations.length; i++) { + for (Annotation paramAnnotation : parameterAnnotations[i]) { + methodDefinition = methodDefinition.annotateParameter(i, paramAnnotation); + } + } + return methodDefinition.annotateMethod( + AnnotationDescription.Builder.ofType(annotation).build()); + } + return builder; + } + + private static >> + Builder> addStateAndTimers( + Class, ?>> doFnClass, + ParameterizedType inputType, + Builder> builder) { + + builder = cloneFields(doFnClass, StateId.class, builder); + builder = cloneFields(doFnClass, TimerId.class, builder); + builder = addBufferingStatesAndTimer(inputType, builder); + return builder; + } + + /** Add state that buffers inputs until we process all state updates. 
*/ + private static >> + Builder> addBufferingStatesAndTimer( + ParameterizedType inputType, Builder> builder) { + + // type: StateSpec>> + Generic bufStateSpecFieldType = + Generic.Builder.parameterizedType( + TypeDescription.ForLoadedType.of(StateSpec.class), bagStateFromInputType(inputType)) + .build(); + // type: StateSpec> + Generic finishedStateSpecFieldType = + Generic.Builder.parameterizedType( + TypeDescription.ForLoadedType.of(StateSpec.class), + Generic.Builder.parameterizedType(ValueState.class, Instant.class).build()) + .build(); + + Generic timerSpecFieldType = Generic.Builder.of(TimerSpec.class).build(); + + builder = + defineStateField( + builder, + bufStateSpecFieldType, + DoFn.StateId.class, + EXPANDER_BUF_STATE_SPEC, + EXPANDER_BUF_STATE_NAME); + builder = + defineStateField( + builder, + finishedStateSpecFieldType, + DoFn.StateId.class, + EXPANDER_FLUSH_STATE_SPEC, + EXPANDER_FLUSH_STATE_NAME); + builder = + defineStateField( + builder, + timerSpecFieldType, + DoFn.TimerId.class, + EXPANDER_TIMER_SPEC, + EXPANDER_TIMER_NAME); + + return builder; + } + + private static >> + Builder> defineStateField( + Builder> builder, + Generic stateSpecFieldType, + Class annotation, + String fieldName, + String name) { + + return builder + .defineField( + fieldName, + stateSpecFieldType, + Visibility.PUBLIC.getMask() + FieldManifestation.FINAL.getMask()) + .annotateField( + AnnotationDescription.Builder.ofType(annotation).define("value", name).build()); + } + + private static >, T extends Annotation> + Builder> cloneFields( + Class, ?>> doFnClass, + Class annotationClass, + Builder> builder) { + + for (Field f : doFnClass.getDeclaredFields()) { + if (!Modifier.isStatic(f.getModifiers()) && f.getAnnotation(annotationClass) != null) { + builder = + builder + .defineField(f.getName(), f.getGenericType(), f.getModifiers()) + .annotateField(f.getDeclaredAnnotations()); + } + } + return builder; + } + + private static PInput getMainInput(AppliedPTransform transform) 
{ + Map, PCollection> mainInputs = transform.getMainInputs(); + if (mainInputs.size() == 1) { + return Iterables.getOnlyElement(mainInputs.values()); + } + return asTuple(mainInputs); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static @NonNull PInput asTuple(Map, PCollection> mainInputs) { + Preconditions.checkArgument(!mainInputs.isEmpty()); + PCollectionTuple res = null; + for (Map.Entry, PCollection> e : mainInputs.entrySet()) { + if (res == null) { + res = PCollectionTuple.of((TupleTag) e.getKey(), e.getValue()); + } else { + res = res.and((TupleTag) e.getKey(), e.getValue()); + } + } + return Objects.requireNonNull(res); + } + + public static class ProcessElementInterceptor implements Serializable { + + private final ProcessElementParameterExpander expander; + private final UnaryFunction processFn; + private final VoidMethodInvoker, ?>> invoker; + + private ProcessElementInterceptor( + ProcessElementParameterExpander expander, + VoidMethodInvoker, ?>> processInvoker) { + + this.expander = expander; + this.processFn = expander.getProcessFn(); + this.invoker = processInvoker; + } + + @RuntimeType + public void intercept( + @This DoFn>, ?> proxy, @AllArguments Object[] allArgs) { + + if (Boolean.TRUE.equals(processFn.apply(allArgs))) { + Object[] methodArgs = expander.getProcessElementArgs(allArgs); + DoFn, ?> doFn = ((DoFnProvider) proxy).getDoFn(); + ExceptionUtils.unchecked(() -> invoker.invoke(doFn, methodArgs)); + } + } + } + + public static class OnWindowExpirationInterceptor implements Serializable { + + private final VoidMethodInvoker, ?>> processElement; + private final @Nullable VoidMethodInvoker, ?>> onWindowExpiration; + private final OnWindowParameterExpander expander; + + public OnWindowExpirationInterceptor( + VoidMethodInvoker, ?>> processElementInvoker, + @Nullable VoidMethodInvoker, ?>> onWindowExpirationInvoker, + OnWindowParameterExpander expander) { + + this.processElement = + Objects.requireNonNull(processElementInvoker, 
"Missing @ProcessElement"); + this.onWindowExpiration = onWindowExpirationInvoker; + this.expander = expander; + } + + @RuntimeType + public void intercept( + @This DoFn>, ?> proxy, @AllArguments Object[] allArgs) { + + @SuppressWarnings("unchecked") + BagState>> buf = + (BagState>>) allArgs[allArgs.length - 1]; + Iterable>> buffered = buf.read(); + // feed all data to @ProcessElement + for (TimestampedValue> kv : buffered) { + DoFn, ?> doFn = ((DoFnProvider) proxy).getDoFn(); + ExceptionUtils.unchecked( + () -> processElement.invoke(doFn, expander.getProcessElementArgs(kv, allArgs))); + } + // invoke onWindowExpiration + if (onWindowExpiration != null) { + DoFn, ?> doFn = ((DoFnProvider) proxy).getDoFn(); + ExceptionUtils.unchecked( + () -> onWindowExpiration.invoke(doFn, expander.getOnWindowExpirationArgs(allArgs))); + } + } + } + + public static class FlushTimerInterceptor implements Serializable { + + private final LinkedHashMap>> + stateReaders; + private final VoidMethodInvoker processElementMethod; + private final FlushTimerParameterExpander expander; + private final Coder keyCoder; + private final TupleTag stateTag; + private final UnaryFunction nextFlushInstantFn; + + FlushTimerInterceptor( + DoFn, ?> doFn, + Method processElementMethod, + FlushTimerParameterExpander expander, + Coder keyCoder, + TupleTag stateTag, + UnaryFunction nextFlushInstantFn, + ByteBuddy buddy, + ClassCollector collector) { + + this.stateReaders = getStateReaders(doFn); + this.processElementMethod = + ExceptionUtils.uncheckedFactory( + () -> VoidMethodInvoker.of(processElementMethod, buddy, collector)); + this.expander = expander; + this.keyCoder = keyCoder; + this.stateTag = stateTag; + this.nextFlushInstantFn = nextFlushInstantFn; + } + + @RuntimeType + public void intercept(@This DoFn>, ?> doFn, @AllArguments Object[] args) { + Instant now = (Instant) args[args.length - 6]; + @SuppressWarnings("unchecked") + K key = (K) args[args.length - 5]; + @SuppressWarnings("unchecked") + 
ValueState nextFlushState = (ValueState) args[args.length - 3]; + Timer flushTimer = (Timer) args[args.length - 4]; + Instant nextFlush = nextFlushInstantFn.apply(now); + Instant lastFlush = nextFlushState.read(); + boolean isNextScheduled = + nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE); + if (isNextScheduled) { + flushTimer.set(nextFlush); + nextFlushState.write(nextFlush); + } + @SuppressWarnings("unchecked") + BagState>> bufState = + (BagState>>) args[args.length - 2]; + @SuppressWarnings({"unchecked", "rawtypes"}) + List>> pushedBackElements = + processBuffer( + (DoFnProvider) doFn, + args, + (Iterable) bufState.read(), + MoreObjects.firstNonNull(lastFlush, now)); + bufState.clear(); + // if we have already processed state data + if (lastFlush != null) { + MultiOutputReceiver outputReceiver = (MultiOutputReceiver) args[args.length - 1]; + OutputReceiver output = outputReceiver.get(stateTag); + byte[] keyBytes = + ExceptionUtils.uncheckedFactory(() -> CoderUtils.encodeToByteArray(keyCoder, key)); + int i = 0; + for (BiFunction> f : stateReaders.values()) { + Object accessor = args[i++]; + Iterable values = f.apply(accessor, keyBytes); + values.forEach(output::output); + } + } + @SuppressWarnings({"unchecked", "rawtypes"}) + List>> remaining = + processBuffer( + (DoFnProvider) doFn, + args, + (List) pushedBackElements, + MoreObjects.firstNonNull(nextFlush, BoundedWindow.TIMESTAMP_MAX_VALUE)); + remaining.forEach(bufState::add); + } + + private List>> processBuffer( + DoFnProvider provider, + Object[] args, + Iterable>> buffer, + Instant maxTs) { + + List>> pushedBackElements = new ArrayList<>(); + buffer.forEach( + kv -> { + if (kv.getTimestamp().isBefore(maxTs)) { + Object[] processArgs = expander.getProcessElementArgs(kv, args); + DoFn, ?> doFn = provider.getDoFn(); + ExceptionUtils.unchecked(() -> processElementMethod.invoke(doFn, processArgs)); + } else { + // return back to buffer + log.debug("Returning element {} to flush 
buffer", kv); + pushedBackElements.add(kv); + } + }); + @SuppressWarnings({"unchecked", "rawtypes"}) + List>> res = (List) pushedBackElements; + return res; + } + } + + static Generic bagStateFromInputType(ParameterizedType inputType) { + return Generic.Builder.parameterizedType( + TypeDescription.ForLoadedType.of(BagState.class), + Generic.Builder.parameterizedType(TimestampedValue.class, inputType).build()) + .build(); + } + + public interface DoFnProvider { + DoFn, ?> getDoFn(); + } +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java index 5f80db317..9f38b3f43 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java @@ -17,117 +17,23 @@ import static cz.o2.proxima.beam.util.state.MethodCallUtils.*; +import cz.o2.proxima.beam.util.RunnerUtils; import cz.o2.proxima.core.functional.UnaryFunction; -import cz.o2.proxima.core.util.ExceptionUtils; -import cz.o2.proxima.core.util.Pair; -import cz.o2.proxima.internal.com.google.common.annotations.VisibleForTesting; -import cz.o2.proxima.internal.com.google.common.base.MoreObjects; -import cz.o2.proxima.internal.com.google.common.base.Preconditions; -import cz.o2.proxima.internal.com.google.common.collect.Iterables; -import java.lang.annotation.Annotation; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; -import java.util.ArrayList; -import java.util.Arrays; +import java.io.File; +import java.io.IOException; import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Objects; -import java.util.Set; -import 
java.util.function.BiFunction; -import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import net.bytebuddy.ByteBuddy; -import net.bytebuddy.description.annotation.AnnotationDescription; -import net.bytebuddy.description.modifier.FieldManifestation; -import net.bytebuddy.description.modifier.Visibility; -import net.bytebuddy.description.type.TypeDefinition; -import net.bytebuddy.description.type.TypeDescription; -import net.bytebuddy.description.type.TypeDescription.Generic; -import net.bytebuddy.dynamic.DynamicType.Builder; -import net.bytebuddy.dynamic.DynamicType.Builder.MethodDefinition; -import net.bytebuddy.dynamic.DynamicType.Unloaded; -import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; -import net.bytebuddy.implementation.FieldAccessor; -import net.bytebuddy.implementation.Implementation; -import net.bytebuddy.implementation.Implementation.Composable; -import net.bytebuddy.implementation.MethodCall; -import net.bytebuddy.implementation.MethodDelegation; -import net.bytebuddy.implementation.bind.annotation.AllArguments; -import net.bytebuddy.implementation.bind.annotation.RuntimeType; -import net.bytebuddy.implementation.bind.annotation.This; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.Pipeline.PipelineVisitor; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.runners.AppliedPTransform; -import org.apache.beam.sdk.runners.PTransformOverride; -import org.apache.beam.sdk.runners.PTransformOverrideFactory; -import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement; -import org.apache.beam.sdk.runners.TransformHierarchy; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.StateSpec; -import org.apache.beam.sdk.state.StateSpecs; -import org.apache.beam.sdk.state.TimeDomain; -import org.apache.beam.sdk.state.Timer; -import org.apache.beam.sdk.state.TimerSpec; -import org.apache.beam.sdk.state.TimerSpecs; 
-import org.apache.beam.sdk.state.ValueState; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; -import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; -import org.apache.beam.sdk.transforms.DoFn.ProcessElement; -import org.apache.beam.sdk.transforms.DoFn.StateId; -import org.apache.beam.sdk.transforms.DoFn.TimerId; -import org.apache.beam.sdk.transforms.Filter; -import org.apache.beam.sdk.transforms.Flatten; -import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.WithKeys; -import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.util.ByteBuddyUtils; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.sdk.util.construction.ReplacementOutputs; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionList; -import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PDone; -import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.TimestampedValue; -import org.apache.beam.sdk.values.TimestampedValue.TimestampedValueCoder; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TupleTagList; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.sdk.values.TypeDescriptors; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.reflect.TypeToken; -import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; @Slf4j public class ExternalStateExpander { - static final String EXPANDER_BUF_STATE_SPEC = "expanderBufStateSpec"; 
- static final String EXPANDER_BUF_STATE_NAME = "_expanderBuf"; - static final String EXPANDER_FLUSH_STATE_SPEC = "expanderFlushStateSpec"; - static final String EXPANDER_FLUSH_STATE_NAME = "_expanderFlush"; - static final String EXPANDER_TIMER_SPEC = "expanderTimerSpec"; - static final String EXPANDER_TIMER_NAME = "_expanderTimer"; - static final String DELEGATE_FIELD_NAME = "delegate"; - - static final TupleTag STATE_TUPLE_TAG = new StateTupleTag() {}; - /** * Expand the given @{link Pipeline} to support external state store and restore * @@ -142,822 +48,20 @@ public static Pipeline expand( PTransform>> inputs, Instant stateWriteInstant, UnaryFunction nextFlushInstantFn, - PTransform>, PDone> stateSink) { - - validatePipeline(pipeline); - pipeline.getCoderRegistry().registerCoderForClass(StateValue.class, StateValue.coder()); - PCollection> inputsMaterialized = pipeline.apply(inputs); - // replace all MultiParDos - pipeline.replaceAll( - Collections.singletonList( - statefulParMultiDoOverride(inputsMaterialized, stateWriteInstant, nextFlushInstantFn))); - // collect all StateValues - List>> stateValues = new ArrayList<>(); - pipeline.traverseTopologically( - new PipelineVisitor.Defaults() { - @Override - public void visitPrimitiveTransform(TransformHierarchy.Node node) { - if (node.getTransform() instanceof ParDo.MultiOutput) { - node.getOutputs().entrySet().stream() - .filter(e -> e.getKey() instanceof StateTupleTag) - .map(Entry::getValue) - .findAny() - .ifPresent( - p -> - stateValues.add( - Pair.of(node.getFullName(), (PCollection) p))); - } - } - }); - if (!stateValues.isEmpty()) { - PCollectionList> list = PCollectionList.empty(pipeline); - for (Pair> p : stateValues) { - PCollection> mapped = p.getSecond().apply(WithKeys.of(p.getFirst())); - list = list.and(mapped); - } - list.apply(Flatten.pCollections()).apply(stateSink); - } - - return pipeline; - } - - private static void validatePipeline(Pipeline pipeline) { - // check that all nodes have unique 
names - Set names = new HashSet<>(); - pipeline.traverseTopologically( - new PipelineVisitor.Defaults() { - - @Override - public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) { - Preconditions.checkState(names.add(node.getFullName())); - return CompositeBehavior.ENTER_TRANSFORM; - } - - @Override - public void visitPrimitiveTransform(TransformHierarchy.Node node) { - Preconditions.checkState(names.add(node.getFullName())); - } - }); - } - - private static PTransformOverride statefulParMultiDoOverride( - PCollection> inputs, - Instant stateWriteInstant, - UnaryFunction nextFlushInstantFn) { - - return PTransformOverride.of( - application -> application.getTransform() instanceof ParDo.MultiOutput, - parMultiDoReplacementFactory(inputs, stateWriteInstant, nextFlushInstantFn)); - } - - @SuppressWarnings({"unchecked", "rawtypes"}) - private static PTransformOverrideFactory parMultiDoReplacementFactory( - PCollection> inputs, - Instant stateWriteInstant, - UnaryFunction nextFlushInstantFn) { - - return new PTransformOverrideFactory<>() { - @Override - public PTransformReplacement getReplacementTransform(AppliedPTransform transform) { - return replaceParMultiDo(transform, inputs, stateWriteInstant, nextFlushInstantFn); - } - - @SuppressWarnings("unchecked") - @Override - public Map, ReplacementOutput> mapOutputs(Map outputs, POutput newOutput) { - return ReplacementOutputs.tagged(outputs, newOutput); - } - }; - } - - @SuppressWarnings({"unchecked", "rawtypes"}) - private static PTransformReplacement replaceParMultiDo( - AppliedPTransform transform, - PCollection> inputs, - Instant stateWriteInstant, - UnaryFunction nextFlushInstantFn) { - - ParDo.MultiOutput rawTransform = - (ParDo.MultiOutput) (PTransform) transform.getTransform(); - DoFn, ?> doFn = (DoFn) rawTransform.getFn(); - PInput pMainInput = getMainInput(transform); - if (!DoFnSignatures.isStateful(doFn)) { - return PTransformReplacement.of(pMainInput, (PTransform) 
transform.getTransform()); - } - String transformName = transform.getFullName(); - PCollection transformInputs = - inputs - .apply(Filter.by(kv -> kv.getKey().equals(transformName))) - .apply(MapElements.into(TypeDescriptor.of(StateValue.class)).via(KV::getValue)); - TupleTag mainOutputTag = rawTransform.getMainOutputTag(); - return PTransformReplacement.of( - pMainInput, - transformedParDo( - transformInputs, - (DoFn) doFn, - mainOutputTag, - TupleTagList.of( - transform.getOutputs().keySet().stream() - .filter(t -> !t.equals(mainOutputTag)) - .collect(Collectors.toList())), - stateWriteInstant, - nextFlushInstantFn)); - } - - @SuppressWarnings("unchecked") - private static , OutputT> - PTransform, PCollectionTuple> transformedParDo( - PCollection transformInputs, - DoFn, OutputT> doFn, - TupleTag mainOutputTag, - TupleTagList otherOutputs, - Instant stateWriteInstant, - UnaryFunction nextFlushInstantFn) { - - return new PTransform<>() { - @Override - public PCollectionTuple expand(PCollection input) { - @SuppressWarnings("unchecked") - KvCoder coder = (KvCoder) input.getCoder(); - Coder keyCoder = coder.getKeyCoder(); - Coder valueCoder = coder.getValueCoder(); - TypeDescriptor> valueDescriptor = - new TypeDescriptor<>(new TypeToken>() {}) {}; - PCollection>> state = - transformInputs - .apply( - MapElements.into( - TypeDescriptors.kvs( - keyCoder.getEncodedTypeDescriptor(), valueDescriptor)) - .via( - e -> - ExceptionUtils.uncheckedFactory( - () -> - KV.of( - CoderUtils.decodeFromByteArray(keyCoder, e.getKey()), - StateOrInput.state(e))))) - .setCoder(KvCoder.of(keyCoder, StateOrInput.coder(valueCoder))); - PCollection>> inputs = - input - .apply( - MapElements.into( - TypeDescriptors.kvs( - keyCoder.getEncodedTypeDescriptor(), valueDescriptor)) - .via(e -> KV.of(e.getKey(), StateOrInput.input(e.getValue())))) - .setCoder(KvCoder.of(keyCoder, StateOrInput.coder(valueCoder))); - PCollection>> flattened = - 
PCollectionList.of(state).and(inputs).apply(Flatten.pCollections()); - PCollectionTuple tuple = - flattened.apply( - ParDo.of( - transformedDoFn( - doFn, - (KvCoder) input.getCoder(), - mainOutputTag, - stateWriteInstant, - nextFlushInstantFn)) - .withOutputTags(mainOutputTag, otherOutputs.and(STATE_TUPLE_TAG))); - PCollectionTuple res = PCollectionTuple.empty(input.getPipeline()); - for (Entry, PCollection> e : - (Set, PCollection>>) (Set) tuple.getAll().entrySet()) { - if (!e.getKey().equals(STATE_TUPLE_TAG)) { - res = res.and(e.getKey(), e.getValue()); - } - } - return res; - } - }; - } - - @VisibleForTesting - static >, OutputT> - DoFn transformedDoFn( - DoFn, OutputT> doFn, - KvCoder inputCoder, - TupleTag mainTag, - Instant stateWriteInstant, - UnaryFunction nextFlushInstantFn) { - - @SuppressWarnings("unchecked") - Class, OutputT>> doFnClass = - (Class, OutputT>>) doFn.getClass(); - - ClassLoadingStrategy strategy = ByteBuddyUtils.getClassLoadingStrategy(doFnClass); - final String className = - doFnClass.getName() - + "$Expanded" - + (Objects.hash(stateWriteInstant, nextFlushInstantFn) & Integer.MAX_VALUE); - final ClassLoader classLoader = ExternalStateExpander.class.getClassLoader(); - try { - @SuppressWarnings("unchecked") - Class> aClass = - (Class>) classLoader.loadClass(className); - // class found, return instance - return ExceptionUtils.uncheckedFactory( - () -> aClass.getConstructor(doFnClass).newInstance(doFn)); - } catch (ClassNotFoundException e) { - // class not found, create it - } - - ByteBuddy buddy = new ByteBuddy(); - @SuppressWarnings("unchecked") - ParameterizedType parameterizedSuperClass = - getParameterizedDoFn((Class, OutputT>>) doFn.getClass()); - ParameterizedType inputType = - (ParameterizedType) parameterizedSuperClass.getActualTypeArguments()[0]; - Preconditions.checkArgument( - inputType.getRawType().equals(KV.class), - "Input type to stateful DoFn must be KV, go %s", - inputType); - - Type outputType = 
parameterizedSuperClass.getActualTypeArguments()[1]; - Generic wrapperInput = getWrapperInputType(inputType); - - Generic doFnGeneric = - Generic.Builder.parameterizedType( - TypeDescription.ForLoadedType.of(DoFn.class), - wrapperInput, - TypeDescription.Generic.Builder.of(outputType).build()) - .build(); - @SuppressWarnings("unchecked") - Builder> builder = - (Builder>) - buddy - .subclass(doFnGeneric) - .name(className) - .defineField(DELEGATE_FIELD_NAME, doFnClass, Visibility.PRIVATE); - builder = addStateAndTimers(doFnClass, inputType, builder); - builder = - builder - .defineConstructor(Visibility.PUBLIC) - .withParameters(doFnClass) - .intercept( - addStateAndTimerValues( - doFn, - inputCoder, - MethodCall.invoke( - ExceptionUtils.uncheckedFactory(() -> DoFn.class.getConstructor())) - .andThen(FieldAccessor.ofField(DELEGATE_FIELD_NAME).setsArgumentAt(0)))); - - builder = - addProcessingMethods( - doFn, - inputType, - inputCoder.getKeyCoder(), - mainTag, - outputType, - stateWriteInstant, - nextFlushInstantFn, - buddy, - builder); - Unloaded> dynamicClass = builder.make(); - return ExceptionUtils.uncheckedFactory( - () -> - dynamicClass - .load(null, strategy) - .getLoaded() - .getDeclaredConstructor(doFnClass) - .newInstance(doFn)); - } - - private static , OutputT> Implementation addStateAndTimerValues( - DoFn doFn, Coder> inputCoder, Composable delegate) { - - List> acceptable = Arrays.asList(StateId.class, TimerId.class); - @SuppressWarnings("unchecked") - Class> doFnClass = - (Class>) doFn.getClass(); - for (Field f : doFnClass.getDeclaredFields()) { - if (!Modifier.isStatic(f.getModifiers()) - && acceptable.stream().anyMatch(a -> f.getAnnotation(a) != null)) { - f.setAccessible(true); - Object value = ExceptionUtils.uncheckedFactory(() -> f.get(doFn)); - delegate = delegate.andThen(FieldAccessor.ofField(f.getName()).setsValue(value)); - } - } - delegate = - delegate - .andThen( - FieldAccessor.ofField(EXPANDER_BUF_STATE_SPEC) - 
.setsValue(StateSpecs.bag(TimestampedValueCoder.of(inputCoder)))) - .andThen(FieldAccessor.ofField(EXPANDER_FLUSH_STATE_SPEC).setsValue(StateSpecs.value())) - .andThen( - FieldAccessor.ofField(EXPANDER_TIMER_SPEC) - .setsValue(TimerSpecs.timer(TimeDomain.EVENT_TIME))); - return delegate; - } - - @SuppressWarnings({"unchecked", "rawtypes"}) - private static ParameterizedType getParameterizedDoFn( - Class> doFnClass) { - - Type type = doFnClass.getGenericSuperclass(); - if (type instanceof ParameterizedType) { - return (ParameterizedType) type; - } - if (doFnClass.getSuperclass().isAssignableFrom(DoFn.class)) { - return getParameterizedDoFn((Class) doFnClass.getGenericSuperclass()); - } - throw new IllegalStateException("Cannot get parameterized type of " + doFnClass); - } - - private static >, OutputT> - Builder> addProcessingMethods( - DoFn, OutputT> doFn, - ParameterizedType inputType, - Coder keyCoder, - TupleTag mainTag, - Type outputType, - Instant stateWriteInstant, - UnaryFunction nextFlushInstantFn, - ByteBuddy buddy, - Builder> builder) { - - builder = addProcessingMethod(doFn, DoFn.Setup.class, builder); - builder = addProcessingMethod(doFn, DoFn.StartBundle.class, builder); - builder = - addProcessElementMethod( - doFn, inputType, mainTag, outputType, stateWriteInstant, buddy, builder); - builder = addProcessingMethod(doFn, DoFn.FinishBundle.class, builder); - builder = addProcessingMethod(doFn, DoFn.Teardown.class, builder); - builder = addOnWindowExpirationMethod(doFn, inputType, mainTag, buddy, builder); - builder = addProcessingMethod(doFn, DoFn.GetInitialRestriction.class, builder); - builder = addProcessingMethod(doFn, DoFn.SplitRestriction.class, builder); - builder = addProcessingMethod(doFn, DoFn.GetRestrictionCoder.class, builder); - builder = addProcessingMethod(doFn, DoFn.GetWatermarkEstimatorStateCoder.class, builder); - builder = addProcessingMethod(doFn, DoFn.GetInitialWatermarkEstimatorState.class, builder); - builder = 
addProcessingMethod(doFn, DoFn.NewWatermarkEstimator.class, builder); - builder = addProcessingMethod(doFn, DoFn.NewTracker.class, builder); - builder = addProcessingMethod(doFn, DoFn.OnTimer.class, builder); - builder = - addTimerFlushMethod( - doFn, - inputType, - keyCoder, - mainTag, - STATE_TUPLE_TAG, - outputType, - nextFlushInstantFn, - buddy, - builder); - return builder; - } - - private static >, OutputT> - Builder> addProcessElementMethod( - DoFn, OutputT> doFn, - ParameterizedType inputType, - TupleTag mainTag, - Type outputType, - Instant stateWriteInstant, - ByteBuddy buddy, - Builder> builder) { - - Class annotation = ProcessElement.class; - Method method = findMethod(doFn, annotation); - if (method != null) { - ProcessElementParameterExpander expander = - ProcessElementParameterExpander.of( - doFn, method, inputType, mainTag, outputType, stateWriteInstant); - List> wrapperArgs = expander.getWrapperArgs(); - MethodDefinition> methodDefinition = - builder - .defineMethod(method.getName(), method.getReturnType(), Visibility.PUBLIC) - .withParameters( - wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) - .intercept( - MethodDelegation.to( - new ProcessElementInterceptor<>(doFn, expander, method, buddy))); - - for (int i = 0; i < wrapperArgs.size(); i++) { - Pair arg = wrapperArgs.get(i); - if (arg.getFirst() != null) { - methodDefinition = methodDefinition.annotateParameter(i, arg.getFirst()); - } - } - return methodDefinition.annotateMethod( - AnnotationDescription.Builder.ofType(annotation).build()); - } - return builder; - } - - private static >, OutputT> - Builder> addOnWindowExpirationMethod( - DoFn, OutputT> doFn, - ParameterizedType inputType, - TupleTag mainTag, - ByteBuddy buddy, - Builder> builder) { - - Class annotation = DoFn.OnWindowExpiration.class; - @Nullable Method onWindowExpirationMethod = findMethod(doFn, annotation); - Method processElementMethod = findMethod(doFn, DoFn.ProcessElement.class); - Type outputType = 
doFn.getOutputTypeDescriptor().getType(); - if (processElementMethod != null) { - OnWindowParameterExpander expander = - OnWindowParameterExpander.of( - inputType, processElementMethod, onWindowExpirationMethod, mainTag, outputType); - List> wrapperArgs = expander.getWrapperArgs(); - MethodDefinition> methodDefinition = - builder - .defineMethod( - onWindowExpirationMethod == null - ? "_onWindowExpiration" - : onWindowExpirationMethod.getName(), - void.class, - Visibility.PUBLIC) - .withParameters( - wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) - .intercept( - MethodDelegation.to( - new OnWindowExpirationInterceptor<>( - doFn, processElementMethod, onWindowExpirationMethod, expander, buddy))); - - // retrieve parameter annotations and apply them - for (int i = 0; i < wrapperArgs.size(); i++) { - AnnotationDescription ann = wrapperArgs.get(i).getFirst(); - if (ann != null) { - methodDefinition = methodDefinition.annotateParameter(i, ann); - } - } - return methodDefinition.annotateMethod( - AnnotationDescription.Builder.ofType(annotation).build()); - } - return builder; - } - - private static >> - Builder> addTimerFlushMethod( - DoFn, OutputT> doFn, - ParameterizedType inputType, - Coder keyCoder, - TupleTag mainTag, - TupleTag stateTag, - Type outputType, - UnaryFunction nextFlushInstantFn, - ByteBuddy buddy, - Builder> builder) { - - Method processElement = findMethod(doFn, ProcessElement.class); - FlushTimerParameterExpander expander = - FlushTimerParameterExpander.of(doFn, inputType, processElement, mainTag, outputType); - List> wrapperArgs = expander.getWrapperArgs(); - MethodDefinition> methodDefinition = - builder - .defineMethod("expanderFlushTimer", void.class, Visibility.PUBLIC) - .withParameters(wrapperArgs.stream().map(Pair::getSecond).collect(Collectors.toList())) - .intercept( - MethodDelegation.to( - new FlushTimerInterceptor<>( - doFn, - processElement, - expander, - keyCoder, - stateTag, - nextFlushInstantFn, - buddy))); - 
int i = 0; - for (Pair p : wrapperArgs) { - if (p.getFirst() != null) { - methodDefinition = methodDefinition.annotateParameter(i, p.getFirst()); - } - i++; - } - return methodDefinition.annotateMethod( - AnnotationDescription.Builder.ofType(DoFn.OnTimer.class) - .define("value", EXPANDER_TIMER_NAME) - .build()); - } - - private static Method findMethod( - DoFn, OutputT> doFn, Class annotation) { - - return Iterables.getOnlyElement( - Arrays.stream(doFn.getClass().getMethods()) - .filter(m -> m.getAnnotation(annotation) != null) - .collect(Collectors.toList()), - null); - } - - private static >, OutputT, T extends Annotation> - Builder> addProcessingMethod( - DoFn, OutputT> doFn, - Class annotation, - Builder> builder) { - - Method method = findMethod(doFn, annotation); - if (method != null) { - MethodDefinition> methodDefinition = - builder - .defineMethod(method.getName(), method.getReturnType(), Visibility.PUBLIC) - .withParameters(method.getGenericParameterTypes()) - .intercept(MethodCall.invoke(method).onField(DELEGATE_FIELD_NAME).withAllArguments()); - - // retrieve parameter annotations and apply them - Annotation[][] parameterAnnotations = method.getParameterAnnotations(); - for (int i = 0; i < parameterAnnotations.length; i++) { - for (Annotation paramAnnotation : parameterAnnotations[i]) { - methodDefinition = methodDefinition.annotateParameter(i, paramAnnotation); - } - } - return methodDefinition.annotateMethod( - AnnotationDescription.Builder.ofType(annotation).build()); - } - return builder; - } - - private static >, OutputT> - Builder> addStateAndTimers( - Class, OutputT>> doFnClass, - ParameterizedType inputType, - Builder> builder) { - - builder = cloneFields(doFnClass, StateId.class, builder); - builder = cloneFields(doFnClass, TimerId.class, builder); - builder = addBufferingStatesAndTimer(inputType, builder); - return builder; - } - - /** Add state that buffers inputs until we process all state updates. 
*/ - private static >, OutputT> - Builder> addBufferingStatesAndTimer( - ParameterizedType inputType, Builder> builder) { - - // type: StateSpec>> - Generic bufStateSpecFieldType = - Generic.Builder.parameterizedType( - TypeDescription.ForLoadedType.of(StateSpec.class), bagStateFromInputType(inputType)) - .build(); - // type: StateSpec> - Generic finishedStateSpecFieldType = - Generic.Builder.parameterizedType( - TypeDescription.ForLoadedType.of(StateSpec.class), - Generic.Builder.parameterizedType(ValueState.class, Instant.class).build()) - .build(); - - Generic timerSpecFieldType = Generic.Builder.of(TimerSpec.class).build(); - - builder = - defineStateField( - builder, - bufStateSpecFieldType, - DoFn.StateId.class, - EXPANDER_BUF_STATE_SPEC, - EXPANDER_BUF_STATE_NAME); - builder = - defineStateField( - builder, - finishedStateSpecFieldType, - DoFn.StateId.class, - EXPANDER_FLUSH_STATE_SPEC, - EXPANDER_FLUSH_STATE_NAME); - builder = - defineStateField( - builder, - timerSpecFieldType, - DoFn.TimerId.class, - EXPANDER_TIMER_SPEC, - EXPANDER_TIMER_NAME); - - return builder; - } - - private static >, OutputT> - Builder> defineStateField( - Builder> builder, - Generic stateSpecFieldType, - Class annotation, - String fieldName, - String name) { - - return builder - .defineField( - fieldName, - stateSpecFieldType, - Visibility.PUBLIC.getMask() + FieldManifestation.FINAL.getMask()) - .annotateField( - AnnotationDescription.Builder.ofType(annotation).define("value", name).build()); + PTransform>, PDone> stateSink) + throws IOException { + + ExpandContext context = + new ExpandContext(inputs, stateWriteInstant, nextFlushInstantFn, stateSink); + Pipeline expanded = context.expand(pipeline); + File dynamicJar = RunnerUtils.createJarFromDynamicClasses(context.getGeneratedClasses()); + RunnerUtils.registerToPipeline( + expanded.getOptions(), + expanded.getOptions().getRunner().getSimpleName(), + Collections.singletonList(dynamicJar)); + return expanded; } - private static >, 
OutputT, T extends Annotation> - Builder> cloneFields( - Class, OutputT>> doFnClass, - Class annotationClass, - Builder> builder) { - - for (Field f : doFnClass.getDeclaredFields()) { - if (!Modifier.isStatic(f.getModifiers()) && f.getAnnotation(annotationClass) != null) { - builder = - builder - .defineField(f.getName(), f.getGenericType(), f.getModifiers()) - .annotateField(f.getDeclaredAnnotations()); - } - } - return builder; - } - - private static PInput getMainInput(AppliedPTransform transform) { - Map, PCollection> mainInputs = transform.getMainInputs(); - if (mainInputs.size() == 1) { - return Iterables.getOnlyElement(mainInputs.values()); - } - return asTuple(mainInputs); - } - - @SuppressWarnings({"unchecked", "rawtypes"}) - private static @NonNull PInput asTuple(Map, PCollection> mainInputs) { - Preconditions.checkArgument(!mainInputs.isEmpty()); - PCollectionTuple res = null; - for (Map.Entry, PCollection> e : mainInputs.entrySet()) { - if (res == null) { - res = PCollectionTuple.of((TupleTag) e.getKey(), e.getValue()); - } else { - res = res.and((TupleTag) e.getKey(), e.getValue()); - } - } - return Objects.requireNonNull(res); - } - - private static class ProcessElementInterceptor { - - private final DoFn, ?> doFn; - private final ProcessElementParameterExpander expander; - private final UnaryFunction processFn; - private final VoidMethodInvoker invoker; - - ProcessElementInterceptor( - DoFn, ?> doFn, - ProcessElementParameterExpander expander, - Method process, - ByteBuddy buddy) { - - this.doFn = doFn; - this.expander = expander; - this.processFn = expander.getProcessFn(); - this.invoker = ExceptionUtils.uncheckedFactory(() -> VoidMethodInvoker.of(process, buddy)); - } - - @RuntimeType - public void intercept( - @This DoFn>, ?> proxy, @AllArguments Object[] allArgs) { - - if (processFn.apply(allArgs)) { - Object[] methodArgs = expander.getProcessElementArgs(allArgs); - ExceptionUtils.unchecked(() -> invoker.invoke(doFn, methodArgs)); - } - } - } - - 
private static class OnWindowExpirationInterceptor { - private final DoFn, ?> doFn; - private final VoidMethodInvoker, ?>> processElement; - private final @Nullable VoidMethodInvoker, ?>> onWindowExpiration; - private final OnWindowParameterExpander expander; - - public OnWindowExpirationInterceptor( - DoFn, ?> doFn, - Method processElementMethod, - @Nullable Method onWindowExpirationMethod, - OnWindowParameterExpander expander, - ByteBuddy buddy) { - - this.doFn = doFn; - this.processElement = - ExceptionUtils.uncheckedFactory(() -> VoidMethodInvoker.of(processElementMethod, buddy)); - this.onWindowExpiration = - onWindowExpirationMethod == null - ? null - : ExceptionUtils.uncheckedFactory( - () -> VoidMethodInvoker.of(onWindowExpirationMethod, buddy)); - this.expander = expander; - } - - @RuntimeType - public void intercept( - @This DoFn>, ?> proxy, @AllArguments Object[] allArgs) { - - @SuppressWarnings("unchecked") - BagState>> buf = - (BagState>>) allArgs[allArgs.length - 1]; - Iterable>> buffered = buf.read(); - // feed all data to @ProcessElement - for (TimestampedValue> kv : buffered) { - ExceptionUtils.unchecked( - () -> processElement.invoke(doFn, expander.getProcessElementArgs(kv, allArgs))); - } - // invoke onWindowExpiration - if (onWindowExpiration != null) { - ExceptionUtils.unchecked( - () -> onWindowExpiration.invoke(doFn, expander.getOnWindowExpirationArgs(allArgs))); - } - } - } - - private static class FlushTimerInterceptor { - - private final DoFn, ?> doFn; - private final LinkedHashMap>> - stateReaders; - private final VoidMethodInvoker processElementMethod; - private final FlushTimerParameterExpander expander; - private final Coder keyCoder; - private final TupleTag stateTag; - private final UnaryFunction nextFlushInstantFn; - - FlushTimerInterceptor( - DoFn, ?> doFn, - Method processElementMethod, - FlushTimerParameterExpander expander, - Coder keyCoder, - TupleTag stateTag, - UnaryFunction nextFlushInstantFn, - ByteBuddy buddy) { - - 
this.doFn = doFn; - this.stateReaders = getStateReaders(doFn); - this.processElementMethod = - ExceptionUtils.uncheckedFactory(() -> VoidMethodInvoker.of(processElementMethod, buddy)); - this.expander = expander; - this.keyCoder = keyCoder; - this.stateTag = stateTag; - this.nextFlushInstantFn = nextFlushInstantFn; - } - - @RuntimeType - public void intercept(@This DoFn>, ?> doFn, @AllArguments Object[] args) { - Instant now = (Instant) args[args.length - 6]; - @SuppressWarnings("unchecked") - K key = (K) args[args.length - 5]; - @SuppressWarnings("unchecked") - ValueState nextFlushState = (ValueState) args[args.length - 3]; - Timer flushTimer = (Timer) args[args.length - 4]; - Instant nextFlush = nextFlushInstantFn.apply(now); - Instant lastFlush = nextFlushState.read(); - boolean isNextScheduled = - nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE); - if (isNextScheduled) { - flushTimer.set(nextFlush); - nextFlushState.write(nextFlush); - } - @SuppressWarnings("unchecked") - BagState>> bufState = - (BagState>>) args[args.length - 2]; - @SuppressWarnings({"unchecked", "rawtypes"}) - List>> pushedBackElements = - processBuffer(args, (Iterable) bufState.read(), MoreObjects.firstNonNull(lastFlush, now)); - bufState.clear(); - // if we have already processed state data - if (lastFlush != null) { - MultiOutputReceiver outputReceiver = (MultiOutputReceiver) args[args.length - 1]; - OutputReceiver output = outputReceiver.get(stateTag); - byte[] keyBytes = - ExceptionUtils.uncheckedFactory(() -> CoderUtils.encodeToByteArray(keyCoder, key)); - int i = 0; - for (BiFunction> f : stateReaders.values()) { - Object accessor = args[i++]; - Iterable values = f.apply(accessor, keyBytes); - values.forEach(output::output); - } - } - @SuppressWarnings({"unchecked", "rawtypes"}) - List>> remaining = - processBuffer( - args, - (List) pushedBackElements, - MoreObjects.firstNonNull(nextFlush, BoundedWindow.TIMESTAMP_MAX_VALUE)); - 
remaining.forEach(bufState::add); - } - - private List>> processBuffer( - Object[] args, Iterable>> buffer, Instant maxTs) { - - List>> pushedBackElements = new ArrayList<>(); - buffer.forEach( - kv -> { - if (kv.getTimestamp().isBefore(maxTs)) { - Object[] processArgs = expander.getProcessElementArgs(kv, args); - ExceptionUtils.unchecked(() -> processElementMethod.invoke(this.doFn, processArgs)); - } else { - // return back to buffer - log.debug("Returning element {} to flush buffer", kv); - pushedBackElements.add(kv); - } - }); - @SuppressWarnings({"unchecked", "rawtypes"}) - List>> res = (List) pushedBackElements; - return res; - } - } - - static Generic bagStateFromInputType(ParameterizedType inputType) { - return Generic.Builder.parameterizedType( - TypeDescription.ForLoadedType.of(BagState.class), - Generic.Builder.parameterizedType(TimestampedValue.class, inputType).build()) - .build(); - } - - private static class StateTupleTag extends TupleTag {} - // do not construct private ExternalStateExpander() {} } diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java index 3d81293cd..38a2ae2eb 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java @@ -15,11 +15,13 @@ */ package cz.o2.proxima.beam.util.state; -import static cz.o2.proxima.beam.util.state.ExternalStateExpander.*; +import static cz.o2.proxima.beam.util.state.ExpandContext.*; import static cz.o2.proxima.beam.util.state.MethodCallUtils.*; +import cz.o2.proxima.core.functional.BiFunction; import cz.o2.proxima.core.util.Pair; import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import java.io.Serializable; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; @@ 
-40,7 +42,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; -interface FlushTimerParameterExpander { +interface FlushTimerParameterExpander extends Serializable { static FlushTimerParameterExpander of( DoFn doFn, @@ -49,24 +51,13 @@ static FlushTimerParameterExpander of( TupleTag mainTag, Type outputType) { - final LinkedHashMap> processArgs = extractArgs(processElement); - final LinkedHashMap> wrapperArgs = + final LinkedHashMap> createdProcessArgs = + extractArgs(processElement); + final LinkedHashMap> createdWrapperArgs = createWrapperArgs(doFn, inputType); - final List>, Object>> - processArgsGenerators = projectArgs(wrapperArgs, processArgs, mainTag, outputType); - - return new FlushTimerParameterExpander() { - @Override - public List> getWrapperArgs() { - return new ArrayList<>(wrapperArgs.values()); - } - - @Override - public Object[] getProcessElementArgs( - TimestampedValue> input, Object[] wrapperArgs) { - return fromGenerators(input, processArgsGenerators, wrapperArgs); - } - }; + + return new FlushTimerParameterExpanderImpl( + extractArgs(processElement), createWrapperArgs(doFn, inputType), mainTag, outputType); } private static LinkedHashMap> @@ -147,4 +138,32 @@ public Object[] getProcessElementArgs( * {@code @}OnWindowExpiration */ Object[] getProcessElementArgs(TimestampedValue> input, Object[] wrapperArgs); + + class FlushTimerParameterExpanderImpl implements FlushTimerParameterExpander { + + final transient LinkedHashMap> processArgs; + final transient LinkedHashMap> wrapperArgs; + final List>, Object>> processArgsGenerators; + + private FlushTimerParameterExpanderImpl( + LinkedHashMap> processArgs, + LinkedHashMap> wrapperArgs, + TupleTag mainTag, + Type outputType) { + + this.processArgs = processArgs; + this.wrapperArgs = wrapperArgs; + processArgsGenerators = projectArgs(wrapperArgs, processArgs, mainTag, outputType); + } + + @Override + public List> getWrapperArgs() { + return new 
ArrayList<>(wrapperArgs.values()); + } + + @Override + public Object[] getProcessElementArgs(TimestampedValue> input, Object[] wrapperArgs) { + return fromGenerators(input, processArgsGenerators, wrapperArgs); + } + } } diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java index 26daf5c57..b89d4d4fd 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java @@ -16,11 +16,13 @@ package cz.o2.proxima.beam.util.state; import cz.o2.proxima.core.functional.BiConsumer; +import cz.o2.proxima.core.functional.BiFunction; import cz.o2.proxima.core.util.ExceptionUtils; import cz.o2.proxima.core.util.Pair; import cz.o2.proxima.internal.com.google.common.annotations.VisibleForTesting; import cz.o2.proxima.internal.com.google.common.base.Preconditions; import cz.o2.proxima.internal.com.google.common.collect.Iterables; +import java.io.Serializable; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; @@ -36,7 +38,6 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import net.bytebuddy.ByteBuddy; @@ -48,6 +49,7 @@ import net.bytebuddy.description.type.TypeDescription.ForLoadedType; import net.bytebuddy.description.type.TypeDescription.Generic; import net.bytebuddy.description.type.TypeDescription.Generic.Builder; +import net.bytebuddy.dynamic.DynamicType.Unloaded; import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; import net.bytebuddy.dynamic.scaffold.InstrumentedType; import net.bytebuddy.implementation.Implementation; @@ -123,13 +125,13 @@ static Object[] fromGenerators( wrapperArgList.values().stream() .map(p -> p.getFirst() 
!= null ? p.getFirst().getAnnotationType() : p.getSecond()) .collect(Collectors.toList()); - for (Map.Entry> e : argsMap.entrySet()) { - int wrapperArg = findArgIndex(wrapperArgList.keySet(), e.getKey()); + for (TypeId t : argsMap.keySet()) { + int wrapperArg = findArgIndex(wrapperArgList.keySet(), t); if (wrapperArg < 0) { addUnknownWrapperArgument( - mainTag, outputType, e.getKey(), wrapperParamsIds, wrapperArgList.keySet(), res); + mainTag, outputType, t, wrapperParamsIds, wrapperArgList.keySet(), res); } else { - remapKnownArgument(outputType, e.getKey(), wrapperArg, res); + remapKnownArgument(outputType, t, wrapperArg, res); } } return res; @@ -747,35 +749,35 @@ public List resolve( } } - public interface MethodInvoker { + public interface MethodInvoker extends Serializable { - static MethodInvoker of(Method method, ByteBuddy buddy) + static MethodInvoker of(Method method, ByteBuddy buddy, ClassCollector collector) throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { - return getInvoker(method, buddy); + return getInvoker(method, buddy, collector); } R invoke(T _this, Object[] args); } - public interface VoidMethodInvoker { + public interface VoidMethodInvoker extends Serializable { - static VoidMethodInvoker of(Method method, ByteBuddy buddy) + static VoidMethodInvoker of(Method method, ByteBuddy buddy, ClassCollector collector) throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { - return getInvoker(method, buddy); + return getInvoker(method, buddy, collector); } void invoke(T _this, Object[] args); } - private static T getInvoker(Method method, ByteBuddy buddy) + private static T getInvoker(Method method, ByteBuddy buddy, ClassCollector collector) throws NoSuchMethodException, InvocationTargetException, InstantiationException, @@ -798,21 +800,18 @@ private static T getInvoker(Method method, ByteBuddy buddy) } catch (Exception ex) { // define the class } 
+ Unloaded unloaded = + buddy + .subclass(superClass) + .implement(implement) + .name(subclassName) + .defineMethod("invoke", returnType, Visibility.PUBLIC) + .withParameters(declaringClass, Object[].class) + .intercept(MethodCall.invoke(method).onArgument(0).with(new ArrayArgumentProvider())) + .make(); @SuppressWarnings("unchecked") - Class cls = - (Class) - buddy - .subclass(superClass) - .implement(implement) - .name(subclassName) - .defineMethod("invoke", returnType, Visibility.PUBLIC) - .withParameters(declaringClass, Object[].class) - .intercept( - MethodCall.invoke(method).onArgument(0).with(new ArrayArgumentProvider())) - .make() - .load(null, strategy) - .getLoaded(); - + Class cls = (Class) unloaded.load(null, strategy).getLoaded(); + collector.collect(cls, unloaded.getBytes()); return newInstance(cls); } diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java index ebb7a1ae7..86530808f 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java @@ -15,13 +15,15 @@ */ package cz.o2.proxima.beam.util.state; -import static cz.o2.proxima.beam.util.state.ExternalStateExpander.bagStateFromInputType; +import static cz.o2.proxima.beam.util.state.ExpandContext.bagStateFromInputType; import static cz.o2.proxima.beam.util.state.MethodCallUtils.*; +import cz.o2.proxima.core.functional.BiFunction; import cz.o2.proxima.core.util.Pair; import cz.o2.proxima.internal.com.google.common.base.MoreObjects; import cz.o2.proxima.internal.com.google.common.base.Preconditions; import cz.o2.proxima.internal.com.google.common.collect.Sets; +import java.io.Serializable; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; @@ -32,7 +34,6 @@ import java.util.List; 
import java.util.Optional; import java.util.Set; -import java.util.function.BiFunction; import java.util.stream.Collectors; import net.bytebuddy.description.annotation.AnnotationDescription; import net.bytebuddy.description.annotation.AnnotationDescription.Builder; @@ -45,7 +46,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.checkerframework.checker.nullness.qual.Nullable; -interface OnWindowParameterExpander { +interface OnWindowParameterExpander extends Serializable { static OnWindowParameterExpander of( ParameterizedType inputType, @@ -66,23 +67,8 @@ static OnWindowParameterExpander of( final List>, Object>> windowArgsGenerators = projectArgs(wrapperArgs, onWindowArgs, mainTag, outputType); - return new OnWindowParameterExpander() { - @Override - public List> getWrapperArgs() { - return new ArrayList<>(wrapperArgs.values()); - } - - @Override - public Object[] getProcessElementArgs( - TimestampedValue> input, Object[] wrapperArgs) { - return fromGenerators(input, processArgsGenerators, wrapperArgs); - } - - @Override - public Object[] getOnWindowExpirationArgs(Object[] wrapperArgs) { - return fromGenerators(null, windowArgsGenerators, wrapperArgs); - } - }; + return new OnWindowParameterExpanderImpl( + wrapperArgs, processArgsGenerators, windowArgsGenerators); } static List> createWrapperArgList( @@ -133,7 +119,7 @@ static LinkedHashMap> create // add @StateId for buffer AnnotationDescription buffer = Builder.ofType(StateId.class) - .define("value", ExternalStateExpander.EXPANDER_BUF_STATE_NAME) + .define("value", ExpandContext.EXPANDER_BUF_STATE_NAME) .build(); res.put(TypeId.of(buffer), Pair.of(buffer, bagStateFromInputType(inputType))); return res; @@ -153,4 +139,39 @@ static LinkedHashMap> create /** Get parameters that should be passed to {@code @}OnWindowExpiration from wrapper's call. 
*/ Object[] getOnWindowExpirationArgs(Object[] wrapperArgs); + + class OnWindowParameterExpanderImpl implements OnWindowParameterExpander { + + private final transient LinkedHashMap> + wrapperArgs; + private final List>, Object>> + processArgsGenerators; + private final List>, Object>> + windowArgsGenerators; + + public OnWindowParameterExpanderImpl( + LinkedHashMap> wrapperArgs, + List>, Object>> processArgsGenerators, + List>, Object>> windowArgsGenerators) { + + this.wrapperArgs = wrapperArgs; + this.processArgsGenerators = processArgsGenerators; + this.windowArgsGenerators = windowArgsGenerators; + } + + @Override + public List> getWrapperArgs() { + return new ArrayList<>(wrapperArgs.values()); + } + + @Override + public Object[] getProcessElementArgs(TimestampedValue> input, Object[] wrapperArgs) { + return fromGenerators(input, processArgsGenerators, wrapperArgs); + } + + @Override + public Object[] getOnWindowExpirationArgs(Object[] wrapperArgs) { + return fromGenerators(null, windowArgsGenerators, wrapperArgs); + } + } } diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java index 02c547d80..e94a16dca 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ProcessElementParameterExpander.java @@ -19,9 +19,11 @@ import static cz.o2.proxima.beam.util.state.MethodCallUtils.projectArgs; import cz.o2.proxima.core.functional.BiConsumer; +import cz.o2.proxima.core.functional.BiFunction; import cz.o2.proxima.core.functional.UnaryFunction; import cz.o2.proxima.core.util.Pair; import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import java.io.Serializable; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; @@ -32,8 +34,6 @@ import 
java.util.List; import java.util.Map; import java.util.Objects; -import java.util.function.BiFunction; -import java.util.function.Predicate; import lombok.extern.slf4j.Slf4j; import net.bytebuddy.description.annotation.AnnotationDescription; import net.bytebuddy.description.annotation.AnnotationDescription.ForLoadedAnnotation; @@ -53,7 +53,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; -interface ProcessElementParameterExpander { +interface ProcessElementParameterExpander extends Serializable { static ProcessElementParameterExpander of( DoFn doFn, @@ -63,28 +63,8 @@ static ProcessElementParameterExpander of( Type outputType, Instant stateWriteInstant) { - final LinkedHashMap> processArgs = extractArgs(processElement); - final LinkedHashMap> wrapperArgs = - createWrapperArgs(inputType, outputType, processArgs.values()); - final List>, Object>> processArgsGenerators = - projectArgs(wrapperArgs, processArgs, mainTag, outputType); - - return new ProcessElementParameterExpander() { - @Override - public List> getWrapperArgs() { - return new ArrayList<>(wrapperArgs.values()); - } - - @Override - public Object[] getProcessElementArgs(Object[] wrapperArgs) { - return fromGenerators(processArgsGenerators, wrapperArgs); - } - - @Override - public UnaryFunction getProcessFn() { - return createProcessFn(wrapperArgs, doFn, processElement, stateWriteInstant); - } - }; + return new ProcessElementParameterExpanderImpl( + doFn, processElement, inputType, mainTag, outputType, stateWriteInstant); } /** Get arguments that must be declared by wrapper's call. 
*/ @@ -109,10 +89,11 @@ private static UnaryFunction createProcessFn( return new ProcessFn(stateWriteInstant, wrapperArgs, method, stateUpdaterMap); } - private static int findParameter(Collection args, Predicate predicate) { + private static int findParameter( + Collection args, UnaryFunction predicate) { int i = 0; for (TypeId t : args) { - if (predicate.test(t)) { + if (predicate.apply(t)) { return i; } i++; @@ -140,7 +121,7 @@ static LinkedHashMap> create // add @TimerId for flush timer AnnotationDescription timerAnnotation = AnnotationDescription.Builder.ofType(DoFn.TimerId.class) - .define("value", ExternalStateExpander.EXPANDER_TIMER_NAME) + .define("value", ExpandContext.EXPANDER_TIMER_NAME) .build(); res.put( TypeId.of(timerAnnotation), @@ -149,7 +130,7 @@ static LinkedHashMap> create // add @StateId for finished buffer AnnotationDescription finishedAnnotation = AnnotationDescription.Builder.ofType(DoFn.StateId.class) - .define("value", ExternalStateExpander.EXPANDER_FLUSH_STATE_NAME) + .define("value", ExpandContext.EXPANDER_FLUSH_STATE_NAME) .build(); res.put( TypeId.of(finishedAnnotation), @@ -161,11 +142,11 @@ static LinkedHashMap> create // add @StateId for buffer AnnotationDescription stateAnnotation = AnnotationDescription.Builder.ofType(StateId.class) - .define("value", ExternalStateExpander.EXPANDER_BUF_STATE_NAME) + .define("value", ExpandContext.EXPANDER_BUF_STATE_NAME) .build(); res.put( TypeId.of(stateAnnotation), - Pair.of(stateAnnotation, ExternalStateExpander.bagStateFromInputType(inputType))); + Pair.of(stateAnnotation, ExpandContext.bagStateFromInputType(inputType))); // add MultiOutputReceiver TypeDescription receiver = ForLoadedType.of(MultiOutputReceiver.class); @@ -191,8 +172,8 @@ static Pair> transformProces class ProcessFn implements UnaryFunction { private final int elementPos; private final Instant stateWriteInstant; - private final LinkedHashMap> wrapperArgs; - private final Method method; + private final List wrapperArgsKeys; + 
private final int methodParameterCount; private final Map> stateUpdaterMap; public ProcessFn( @@ -203,8 +184,8 @@ public ProcessFn( this.elementPos = findParameter(wrapperArgs.keySet(), TypeId::isElement); this.stateWriteInstant = stateWriteInstant; - this.wrapperArgs = wrapperArgs; - this.method = method; + this.wrapperArgsKeys = new ArrayList<>(wrapperArgs.keySet()); + this.methodParameterCount = method.getParameterCount(); this.stateUpdaterMap = stateUpdaterMap; Preconditions.checkState(elementPos >= 0, "Missing @Element annotation on method %s", method); } @@ -222,9 +203,9 @@ public Boolean apply(Object[] args) { StateValue state = elem.getValue().getState(); String stateName = state.getName(); // find state accessor - int statePos = findParameter(wrapperArgs.keySet(), a -> a.isState(stateName)); + int statePos = findParameter(wrapperArgsKeys, a -> a.isState(stateName)); Preconditions.checkArgument( - statePos < method.getParameterCount(), "Missing state accessor for %s", stateName); + statePos < methodParameterCount, "Missing state accessor for %s", stateName); Object stateAccessor = args[statePos]; // find declaration of state to find coder BiConsumer updater = stateUpdaterMap.get(stateName); @@ -253,4 +234,41 @@ public Boolean apply(Object[] args) { return true; } } + + class ProcessElementParameterExpanderImpl implements ProcessElementParameterExpander { + + final transient LinkedHashMap> processArgs; + final transient LinkedHashMap> wrapperArgs; + final List>, Object>> processArgsGenerators; + final UnaryFunction processFn; + + public ProcessElementParameterExpanderImpl( + DoFn doFn, + Method processElement, + ParameterizedType inputType, + TupleTag mainTag, + Type outputType, + Instant stateWriteInstant) { + + processArgs = extractArgs(processElement); + wrapperArgs = createWrapperArgs(inputType, outputType, processArgs.values()); + processArgsGenerators = projectArgs(wrapperArgs, processArgs, mainTag, outputType); + processFn = 
createProcessFn(wrapperArgs, doFn, processElement, stateWriteInstant); + } + + @Override + public List> getWrapperArgs() { + return new ArrayList<>(wrapperArgs.values()); + } + + @Override + public Object[] getProcessElementArgs(Object[] wrapperArgs) { + return fromGenerators(processArgsGenerators, wrapperArgs); + } + + @Override + public UnaryFunction getProcessFn() { + return processFn; + } + } } diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java index 028d65ba8..2b8610638 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java @@ -17,6 +17,7 @@ import cz.o2.proxima.internal.com.google.common.base.MoreObjects; import cz.o2.proxima.internal.com.google.common.base.Preconditions; +import java.io.Serializable; import java.lang.annotation.Annotation; import java.lang.reflect.Type; import net.bytebuddy.description.annotation.AnnotationDescription; @@ -26,7 +27,7 @@ import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFn.StateId; -class TypeId { +class TypeId implements Serializable { private static final TypeId TIMESTAMP_TYPE = TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Timestamp.class).build()); diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java index e2da7fd36..9e973562f 100644 --- a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java +++ b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java @@ -21,6 +21,7 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import cz.o2.proxima.core.util.SerializableScopedValue; +import java.io.IOException; import java.util.ArrayList; import 
java.util.Arrays; import java.util.Collection; @@ -32,7 +33,6 @@ import org.apache.beam.runners.flink.FlinkRunner; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineRunner; -import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; @@ -79,7 +79,7 @@ public static List>> params() { @Parameter public Class> runner; @Test - public void testSimpleExpand() { + public void testSimpleExpand() throws IOException { Pipeline pipeline = createPipeline(); PCollection inputs = pipeline.apply(Create.of("1", "2", "3")); PCollection> withKeys = @@ -99,7 +99,7 @@ public void testSimpleExpand() { } @Test - public void testSimpleExpandMultiOutput() { + public void testSimpleExpandMultiOutput() throws IOException { Pipeline pipeline = createPipeline(); PCollection inputs = pipeline.apply(Create.of("1", "2", "3")); PCollection> withKeys = @@ -123,7 +123,7 @@ public void testSimpleExpandMultiOutput() { } @Test - public void testCompositeExpand() { + public void testCompositeExpand() throws IOException { PTransform, PCollection> transform = new PTransform<>() { @Override @@ -150,7 +150,7 @@ public PCollection expand(PCollection input) { } @Test - public void testSimpleExpandWithInitialState() throws CoderException { + public void testSimpleExpandWithInitialState() throws IOException { Pipeline pipeline = createPipeline(); PCollection inputs = pipeline.apply(Create.of("3", "4")); PCollection> withKeys = @@ -185,7 +185,7 @@ public void testSimpleExpandWithInitialState() throws CoderException { } @Test - public void testSimpleExpandWithStateStore() throws CoderException { + public void testSimpleExpandWithStateStore() throws IOException { Pipeline pipeline = createPipeline(); Instant now = new Instant(0); PCollection inputs = @@ -227,7 +227,7 @@ public void testSimpleExpandWithStateStore() throws CoderException { } @Test - public void 
testStateWithElementEarly() throws CoderException { + public void testStateWithElementEarly() throws IOException { Pipeline pipeline = createPipeline(); Instant now = new Instant(0); PCollection inputs = @@ -262,16 +262,16 @@ public void testStateWithElementEarly() throws CoderException { } @Test - public void testBufferedTimestampInject() { + public void testBufferedTimestampInject() throws IOException { testTimestampInject(false); } @Test - public void testBufferedTimestampInjectToMultiOutput() { + public void testBufferedTimestampInjectToMultiOutput() throws IOException { testTimestampInject(true); } - private void testTimestampInject(boolean multiOutput) { + private void testTimestampInject(boolean multiOutput) throws IOException { Pipeline pipeline = createPipeline(); Instant now = new Instant(0); PCollection inputs = diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/MethodCallUtilsTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/MethodCallUtilsTest.java index 88ddeeb7f..fcb6fafbc 100644 --- a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/MethodCallUtilsTest.java +++ b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/MethodCallUtilsTest.java @@ -21,6 +21,7 @@ import cz.o2.proxima.beam.util.state.MethodCallUtils.MethodInvoker; import cz.o2.proxima.beam.util.state.MethodCallUtils.VoidMethodInvoker; import cz.o2.proxima.core.functional.BiConsumer; +import cz.o2.proxima.core.functional.BiFunction; import cz.o2.proxima.core.functional.Consumer; import cz.o2.proxima.core.functional.Factory; import cz.o2.proxima.core.functional.UnaryFunction; @@ -31,7 +32,6 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; import net.bytebuddy.ByteBuddy; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.state.StateBinder; @@ -79,7 +79,9 @@ public void testMethodInvokerWithVoid() ByteBuddy buddy = new ByteBuddy(); SumCollect 
s = new SumCollect(); Method method = s.getClass().getDeclaredMethod("apply", int.class, int.class, Consumer.class); - VoidMethodInvoker invoker = VoidMethodInvoker.of(method, buddy); + List> generated = new ArrayList<>(); + ClassCollector collector = (cls, code) -> generated.add(cls); + VoidMethodInvoker invoker = VoidMethodInvoker.of(method, buddy, collector); List list = new ArrayList<>(); Consumer c = list::add; invoker.invoke(s, new Object[] {1, 2, c}); @@ -92,6 +94,7 @@ public void testMethodInvokerWithVoid() } long duration = System.nanoTime() - start; assertTrue(duration < 1_000_000_000); + assertEquals(1, generated.size()); } @Test @@ -100,13 +103,17 @@ public void testNonStaticSubclass() NoSuchMethodException, InstantiationException, IllegalAccessException { + + List> generated = new ArrayList<>(); + ClassCollector collector = (cls, code) -> generated.add(cls); Sum s = new Sum() { final MethodInvoker invoker = MethodInvoker.of( ExceptionUtils.uncheckedFactory( () -> Delegate.class.getDeclaredMethod("apply", int.class, int.class)), - new ByteBuddy()); + new ByteBuddy(), + collector); class Delegate { Integer apply(int a, int b) { @@ -120,6 +127,7 @@ public Integer apply(int a, int b) { } }; assertEquals(3, (int) s.apply(1, 2)); + assertEquals(1, generated.size()); } void testMethodInvokerWith( @@ -132,7 +140,9 @@ void testMethodInvokerWith( ByteBuddy buddy = new ByteBuddy(); T s = instanceFactory.apply(); Method method = s.getClass().getDeclaredMethod("apply", paramType, paramType); - MethodInvoker invoker = MethodInvoker.of(method, buddy); + List> generated = new ArrayList<>(); + ClassCollector collector = (cls, code) -> generated.add(cls); + MethodInvoker invoker = MethodInvoker.of(method, buddy, collector); assertEquals( valueFactory.apply(3), invoker.invoke( @@ -144,6 +154,7 @@ void testMethodInvokerWith( } long duration = System.nanoTime() - start; assertTrue(duration < 1_000_000_000); + assertEquals(1, generated.size()); } private void 
testBinder(StateBinder binder) { diff --git a/beam/tools/src/main/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProvider.java b/beam/tools/src/main/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProvider.java index 3bd6797a1..82ab8f35c 100644 --- a/beam/tools/src/main/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProvider.java +++ b/beam/tools/src/main/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProvider.java @@ -16,6 +16,8 @@ package cz.o2.proxima.beam.tools.groovy; import static cz.o2.proxima.beam.tools.groovy.BeamStream.dehydrate; +import static cz.o2.proxima.beam.util.RunnerUtils.createJarFromDynamicClasses; +import static cz.o2.proxima.beam.util.RunnerUtils.registerToPipeline; import com.google.api.client.util.Lists; import com.google.auto.service.AutoService; @@ -28,6 +30,7 @@ import cz.o2.proxima.core.storage.commitlog.Position; import cz.o2.proxima.core.util.Classpath; import cz.o2.proxima.core.util.ExceptionUtils; +import cz.o2.proxima.core.util.Pair; import cz.o2.proxima.internal.com.google.common.annotations.VisibleForTesting; import cz.o2.proxima.internal.com.google.common.base.MoreObjects; import cz.o2.proxima.internal.com.google.common.base.Preconditions; @@ -36,34 +39,22 @@ import cz.o2.proxima.tools.groovy.ToolsClassLoader; import cz.o2.proxima.tools.groovy.WindowedStream; import groovy.lang.Closure; -import java.io.ByteArrayInputStream; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; -import java.io.InputStream; import java.net.URI; -import java.net.URL; -import java.net.URLClassLoader; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.function.Supplier; -import java.util.jar.JarEntry; -import java.util.jar.JarOutputStream; import java.util.stream.Collectors; -import javax.annotation.Nonnull; import 
javax.annotation.Nullable; import lombok.AccessLevel; import lombok.Getter; import lombok.extern.slf4j.Slf4j; -import org.apache.beam.repackaged.core.org.apache.commons.compress.utils.IOUtils; -import org.apache.beam.runners.flink.FlinkPipelineOptions; -import org.apache.beam.runners.spark.SparkCommonPipelineOptions; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineRunner; import org.apache.beam.sdk.options.ExperimentalOptions; @@ -225,17 +216,6 @@ protected Supplier getPipelineOptionsFactory() { return PipelineOptionsFactory::create; } - /** - * List all UDFs created. - * - * @return set of all UDFs - */ - protected Set listUdfClassNames() { - return Optional.ofNullable(getToolsClassLoader()) - .map(ToolsClassLoader::getDefinedClasses) - .orElse(Collections.emptySet()); - } - Factory getJarRegisteringPipelineFactory() { Supplier factory = getPipelineOptionsFactory(); UnaryFunction createPipeline = getCreatePipelineFromOpts(); @@ -272,10 +252,12 @@ void createUdfJarAndRegisterToPipeline(PipelineOptions opts) { String runnerName = opts.getRunner().getSimpleName(); try { File path = createJarFromUdfs(); - log.info("Created jar {} with generated classes.", path); - List files = new ArrayList<>(Collections.singletonList(path)); - getAddedJars().stream().map(u -> new File(u.getPath())).forEach(files::add); - registerToPipeline(opts, runnerName, files); + if (path != null) { + log.info("Created jar {} with generated classes.", path); + List files = new ArrayList<>(Collections.singletonList(path)); + getAddedJars().stream().map(u -> new File(u.getPath())).forEach(files::add); + registerToPipeline(opts, runnerName, files); + } } catch (IOException ex) { throw new RuntimeException(ex); } @@ -287,67 +269,19 @@ private Collection getAddedJars() { .orElse(Collections.emptySet()); } - private void registerToPipeline(PipelineOptions opts, String runnerName, Collection paths) { - log.info("Adding jars {} into classpath for runner {}", paths, runnerName); - 
List filesToStage = - paths.stream().map(File::getAbsolutePath).collect(Collectors.toList()); - if (runnerName.equals("FlinkRunner")) { - FlinkPipelineOptions flinkOpts = opts.as(FlinkPipelineOptions.class); - flinkOpts.setFilesToStage(addToList(filesToStage, flinkOpts.getFilesToStage())); - } else if (runnerName.equals("SparkRunner")) { - SparkCommonPipelineOptions sparkOpts = opts.as(SparkCommonPipelineOptions.class); - sparkOpts.setFilesToStage(addToList(filesToStage, sparkOpts.getFilesToStage())); - } else { - if (!runnerName.equals("DirectRunner")) { - log.warn( - "Injecting jar into unknown runner {}. It might not work as expected. " - + "If you are experiencing issues with sub run and/or submission, " - + "please fill github issue reporting the name of the runner.", - runnerName); - } - injectJarIntoContextClassLoader(paths); - } - } - - private List addToList(@Nonnull List first, @Nullable List second) { - Collection res = new HashSet<>(first); - if (second != null) { - res.addAll(second); - } - return new ArrayList<>(res); - } - - private File createJarFromUdfs() throws IOException { - Set classes = listUdfClassNames(); - File out = File.createTempFile("proxima-tools", ".jar"); + private @Nullable File createJarFromUdfs() throws IOException { ToolsClassLoader loader = getToolsClassLoader(); - log.info("Building jar from classes {} retrieved from {}", classes, loader); - - out.deleteOnExit(); - try (JarOutputStream output = new JarOutputStream(new FileOutputStream(out))) { - long now = System.currentTimeMillis(); - for (String cls : classes) { - String name = cls.replace('.', '/') + ".class"; - JarEntry entry = new JarEntry(name); - entry.setTime(now); - output.putNextEntry(entry); - InputStream input = new ByteArrayInputStream(loader.getClassByteCode(cls)); - IOUtils.copy(input, output); - output.closeEntry(); - } - } - return out; - } - - @VisibleForTesting - static void injectJarIntoContextClassLoader(Collection paths) { - ClassLoader loader = 
Thread.currentThread().getContextClassLoader(); - URL[] urls = - paths.stream() - .map(p -> ExceptionUtils.uncheckedFactory(() -> p.toURI().toURL())) - .collect(Collectors.toList()) - .toArray(new URL[] {}); - Thread.currentThread().setContextClassLoader(new URLClassLoader(urls, loader)); + Map, byte[]> codeMap = + Optional.ofNullable(loader) + .map( + l -> + l.getDefinedClasses().stream() + .map(name -> ExceptionUtils.uncheckedFactory(() -> loader.loadClass(name))) + .map(cls -> Pair.of(cls, loader.getClassByteCode(cls.getName()))) + .collect(Collectors.toMap(Pair::getFirst, Pair::getSecond))) + .orElse(Collections.emptyMap()); + log.info("Building jar from classes {} retrieved from {}", codeMap, loader); + return createJarFromDynamicClasses(codeMap); } private @Nullable ToolsClassLoader getToolsClassLoader() { diff --git a/beam/tools/src/test/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProviderTest.java b/beam/tools/src/test/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProviderTest.java index 71c3fddf0..84a628004 100644 --- a/beam/tools/src/test/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProviderTest.java +++ b/beam/tools/src/test/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProviderTest.java @@ -15,6 +15,7 @@ */ package cz.o2.proxima.beam.tools.groovy; +import static cz.o2.proxima.beam.util.RunnerUtils.injectJarIntoContextClassLoader; import static org.junit.Assert.*; import cz.o2.proxima.core.repository.Repository; @@ -82,7 +83,7 @@ public void testRunnerUnknownRunnerJarInject() { @Test public void testInjectPathToClassloader() throws IOException { File f = File.createTempFile("dummy", ".tmp"); - BeamStreamProvider.injectJarIntoContextClassLoader(Collections.singletonList(f)); + injectJarIntoContextClassLoader(Collections.singletonList(f)); ClassLoader loader = Thread.currentThread().getContextClassLoader(); assertTrue(loader instanceof URLClassLoader); }