From b5d126c242d65c6bb097e66beea581bc3ac757ff Mon Sep 17 00:00:00 2001 From: Jozef Vilcek Date: Mon, 27 Jan 2025 15:17:01 +0100 Subject: [PATCH] [spark] Skip unused outputs of ParDo in SparkRunner --- .../spark/DependentTransformsVisitor.java | 52 +++++++ .../beam/runners/spark/SparkRunner.java | 10 +- .../spark/translation/EvaluationContext.java | 21 +++ .../translation/TransformTranslator.java | 45 +++++- .../spark/DependentTransformsVisitorTest.java | 141 ++++++++++++++++++ .../spark/translation/RDDTreeParser.java | 6 +- .../translation/TransformTranslatorTest.java | 45 +++++- 7 files changed, 314 insertions(+), 6 deletions(-) create mode 100644 runners/spark/src/main/java/org/apache/beam/runners/spark/DependentTransformsVisitor.java create mode 100644 runners/spark/src/test/java/org/apache/beam/runners/spark/DependentTransformsVisitorTest.java diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/DependentTransformsVisitor.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/DependentTransformsVisitor.java new file mode 100644 index 000000000000..9957655fa60d --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/DependentTransformsVisitor.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark; + +import java.util.Map; +import org.apache.beam.runners.spark.translation.EvaluationContext; +import org.apache.beam.runners.spark.translation.SparkPipelineTranslator; +import org.apache.beam.sdk.runners.TransformHierarchy; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TupleTag; + +/** + * Traverses the pipeline to populate information on how many {@link + * org.apache.beam.sdk.transforms.PTransform}s consume / depend on each {@link PCollection} in + * the pipeline.
+ */ +class DependentTransformsVisitor extends SparkRunner.Evaluator { + + DependentTransformsVisitor( + SparkPipelineTranslator translator, EvaluationContext evaluationContext) { + super(translator, evaluationContext); + } + + @Override + public void doVisitTransform(TransformHierarchy.Node node) { + + Map, Integer> dependentTransforms = ctxt.getDependentTransforms(); + for (Map.Entry, PCollection> entry : node.getInputs().entrySet()) { + int dependants = dependentTransforms.getOrDefault(entry.getValue(), 0); + dependentTransforms.put(entry.getValue(), dependants + 1); + } + + for (PCollection pOut : node.getOutputs().values()) { + dependentTransforms.computeIfAbsent(pOut, k -> 0); + } + } +} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java index 2b72ffb0f225..04a92eac859c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java @@ -214,6 +214,7 @@ public SparkPipelineResult run(final Pipeline pipeline) { // update the cache candidates updateCacheCandidates(pipeline, translator, evaluationContext); + updateDependentTransforms(pipeline, translator, evaluationContext); // update GBK candidates for memory optimized transform pipeline.traverseTopologically(new GroupByKeyVisitor(translator, evaluationContext)); @@ -275,8 +276,13 @@ static void detectTranslationMode(Pipeline pipeline, SparkPipelineOptions pipeli /** Evaluator that update/populate the cache candidates. */ public static void updateCacheCandidates( Pipeline pipeline, SparkPipelineTranslator translator, EvaluationContext evaluationContext) { - CacheVisitor cacheVisitor = new CacheVisitor(translator, evaluationContext); - pipeline.traverseTopologically(cacheVisitor); + pipeline.traverseTopologically(new CacheVisitor(translator, evaluationContext)); + } + + /** Evaluator that update/populate information about dependent transforms for pCollections. */ + public static void updateDependentTransforms( + Pipeline pipeline, SparkPipelineTranslator translator, EvaluationContext evaluationContext) { + pipeline.traverseTopologically(new DependentTransformsVisitor(translator, evaluationContext)); } /** The translation mode of the Beam Pipeline. */ diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 1f44a29002ef..50c5214d9c26 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -61,6 +61,7 @@ public class EvaluationContext { private final Map datasets = new LinkedHashMap<>(); private final Map pcollections = new LinkedHashMap<>(); private final Set leaves = new LinkedHashSet<>(); + private final Map, Integer> dependentTransforms = new HashMap<>(); private final Map pobjects = new LinkedHashMap<>(); private AppliedPTransform currentTransform; private final SparkPCollectionView pviews = new SparkPCollectionView(); @@ -307,6 +308,26 @@ public boolean isCandidateForGroupByKeyAndWindow(GroupByKey transfo return groupByKeyCandidatesForMemoryOptimizedTranslation.containsKey(transform); } + /** + * Get the map of dependent transforms hold by the evaluation context. 
+ * + * @return The current {@link Map} of dependent transforms. + */ + public Map, Integer> getDependentTransforms() { + return this.dependentTransforms; + } + + /** + * Get if given {@link PCollection} is a leaf or not. {@link PCollection} is a leaf when there is + * no other {@link PTransform} consuming it / depending on it. + * + * @param pCollection to be checked if it is a leaf + * @return true if pCollection is leaf; otherwise false + */ + public boolean isLeaf(PCollection pCollection) { + return this.dependentTransforms.get(pCollection) == 0; + } + Iterable> getWindowedValues(PCollection pcollection) { @SuppressWarnings("unchecked") BoundedDataset boundedDataset = (BoundedDataset) datasets.get(pcollection); diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index ebfcecf030b7..6ec5c6557fc6 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.spark.translation; import static org.apache.beam.runners.spark.translation.TranslationUtils.canAvoidRddSerialization; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import java.util.Arrays; @@ -70,6 +71,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; @@ -77,6 +79,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.spark.HashPartitioner; import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaPairRDD; @@ -428,13 +431,14 @@ public void evaluate( Map> sideInputMapping = ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + TupleTag mainOutputTag = transform.getMainOutputTag(); MultiDoFnFunction multiDoFnFunction = new MultiDoFnFunction<>( metricsAccum, stepName, doFn, context.getSerializableOptions(), - transform.getMainOutputTag(), + mainOutputTag, transform.getAdditionalOutputTags().getAll(), inputCoder, outputCoders, @@ -460,7 +464,13 @@ public void evaluate( all = inRDD.mapPartitionsToPair(multiDoFnFunction); } - Map, PCollection> outputs = context.getOutputs(transform); + // Filter out obsolete PCollections to only cache when absolutely necessary + Map, PCollection> outputs = + skipObsoleteOutputs( + context.getOutputs(transform), + mainOutputTag, + transform.getAdditionalOutputTags(), + context); if (hasMultipleOutputs(outputs)) { StorageLevel level = StorageLevel.fromString(context.storageLevel()); if (canAvoidRddSerialization(level)) { @@ -498,6 +508,37 @@ private boolean hasMultipleOutputs(Map, PCollection> outputs) { return outputs.size() > 1; 
} + /** + * Filter out obsolete, unused output tags except for {@code mainTag}. + * + * <p>
This can help to avoid unnecessary caching in case of multiple outputs if only {@code + * mainTag} is consumed. + */ + private Map, PCollection> skipObsoleteOutputs( + Map, PCollection> outputs, + TupleTag mainTag, + TupleTagList otherTags, + EvaluationContext cxt) { + switch (outputs.size()) { + case 1: + return outputs; // always keep main output + case 2: + TupleTag otherTag = otherTags.get(0); + return cxt.isLeaf(checkStateNotNull(outputs.get(otherTag))) + ? Collections.singletonMap(mainTag, checkStateNotNull(outputs.get(mainTag))) + : outputs; + default: + Map, PCollection> filtered = + Maps.newHashMapWithExpectedSize(outputs.size()); + for (Map.Entry, PCollection> e : outputs.entrySet()) { + if (e.getKey().equals(mainTag) || !cxt.isLeaf(e.getValue())) { + filtered.put(e.getKey(), e.getValue()); + } + } + return filtered; + } + } + @Override public String toNativeString() { return "mapPartitions(new ())"; diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/DependentTransformsVisitorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/DependentTransformsVisitorTest.java new file mode 100644 index 000000000000..3b2740bcb5b3 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/DependentTransformsVisitorTest.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark; + +import static org.junit.Assert.assertEquals; + +import java.util.List; +import java.util.Objects; +import org.apache.beam.runners.spark.translation.EvaluationContext; +import org.apache.beam.runners.spark.translation.TransformTranslator; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +/** Tests of {@link DependentTransformsVisitor}. 
*/ +public class DependentTransformsVisitorTest { + + @ClassRule public static SparkContextRule contextRule = new SparkContextRule(); + + @Rule public TemporaryFolder tmpFolder = new TemporaryFolder(); + + @Test + public void testCountDependentTransformsOnApplyAndSideInputs() { + SparkPipelineOptions options = contextRule.createPipelineOptions(); + Pipeline pipeline = Pipeline.create(options); + PCollection pCollection = pipeline.apply(Create.of("foo", "bar")); + + // First use of pCollection. + PCollection leaf1 = pCollection.apply(Count.globally()); + // Second use of pCollection. + PCollectionView> view = pCollection.apply("yyy", View.asList()); + + PCollection leaf2 = + pipeline + .apply(Create.of("foo", "baz")) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext processContext) { + if (processContext.sideInput(view).contains(processContext.element())) { + processContext.output(processContext.element()); + } + } + }) + .withSideInputs(view)); + + EvaluationContext ctxt = + new EvaluationContext(contextRule.getSparkContext(), pipeline, options); + TransformTranslator.Translator translator = new TransformTranslator.Translator(); + pipeline.traverseTopologically(new DependentTransformsVisitor(translator, ctxt)); + + assertEquals(2, ctxt.getDependentTransforms().get(pCollection).intValue()); + assertEquals(0, ctxt.getDependentTransforms().get(leaf1).intValue()); + assertEquals(0, ctxt.getDependentTransforms().get(leaf2).intValue()); + assertEquals(2, ctxt.getDependentTransforms().get(view.getPCollection()).intValue()); + } + + @Test + public void testCountDependentTransformsOnSideOutputs() { + SparkPipelineOptions options = contextRule.createPipelineOptions(); + Pipeline pipeline = Pipeline.create(options); + + TupleTag passOutTag = new TupleTag<>("passOut"); + TupleTag lettersCountOutTag = new TupleTag<>("lettersOut"); + TupleTag wordCountOutTag = new TupleTag<>("wordsOut"); + + PCollectionTuple result = + pipeline + .apply(Create.of("foo", "baz")) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext processContext) { + String element = processContext.element(); + processContext.output(element); + processContext.output( + lettersCountOutTag, + (long) Objects.requireNonNull(element).length()); + processContext.output(wordCountOutTag, 1L); + } + }) + .withOutputTags( + passOutTag, + TupleTagList.of(Lists.newArrayList(lettersCountOutTag, wordCountOutTag)))); + + // consume main output and words side output. 
leave letters side output left alone + result.get(wordCountOutTag).setCoder(VarLongCoder.of()).apply(Sum.longsGlobally()); + result.get(lettersCountOutTag).setCoder(VarLongCoder.of()); + result + .get(passOutTag) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext processContext) { + // do nothing + } + })); + + EvaluationContext ctxt = + new EvaluationContext(contextRule.getSparkContext(), pipeline, options); + TransformTranslator.Translator translator = new TransformTranslator.Translator(); + pipeline.traverseTopologically(new DependentTransformsVisitor(translator, ctxt)); + + assertEquals(1, ctxt.getDependentTransforms().get(result.get(passOutTag)).intValue()); + assertEquals(1, ctxt.getDependentTransforms().get(result.get(wordCountOutTag)).intValue()); + assertEquals(0, ctxt.getDependentTransforms().get(result.get(lettersCountOutTag)).intValue()); + } +} diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/RDDTreeParser.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/RDDTreeParser.java index 26419aff9f97..215cdbed5a93 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/RDDTreeParser.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/RDDTreeParser.java @@ -30,7 +30,7 @@ public static List parse(String debugString) { for (String line : lines) { line = line.trim(); - if (line.isEmpty()) { + if (line.isEmpty() || isStatsLine(line)) { continue; } @@ -48,6 +48,10 @@ public static List parse(String debugString) { return list; } + private static boolean isStatsLine(String line) { + return line.contains("MemorySize:") && line.contains("DiskSize:"); + } + private static int extractId(String line) { String idPart = line.substring(line.indexOf('[') + 1, line.indexOf(']')); return Integer.parseInt(idPart); diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java index 2f84c2b23fac..2faf38a106df 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java @@ -20,6 +20,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import java.io.Serializable; @@ -37,6 +38,7 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -165,7 +167,7 @@ public void testSingleOutputParDoHasNoFilter() { } @Test - public void testMultipleOutputPardoHaveFilter() { + public void testMultipleOutputParDoShouldNotHaveFilterWhenSideOutputIsNotConsumed() { Pipeline p = Pipeline.create(); TupleTag tag1 = new TupleTag("tag1") {}; TupleTag tag2 = new TupleTag("tag2") {}; @@ -186,6 +188,47 @@ public void testMultipleOutputPardoHaveFilter() { EvaluationContext ctxt = new EvaluationContext(contextRule.getSparkContext(), p, options); SparkRunner.initAccumulators(options, ctxt.getSparkContext()); 
+ SparkRunner.updateDependentTransforms(p, translator, ctxt); + + p.traverseTopologically(new SparkRunner.Evaluator(translator, ctxt)); + + // check main output for filter + @SuppressWarnings("unchecked") + BoundedDataset dataset = + (BoundedDataset) ctxt.borrowDataset(pCollectionTuple.get(tag1)); + List parsed = RDDTreeParser.parse(dataset.getRDD().toDebugString()); + assertThat(parsed.stream().map(RDDNode::getOperator)).doesNotContain("filter"); + + // check that second tag is not present + assertNull(ctxt.borrowDataset(pCollectionTuple.get(tag2))); + } + + @Test + public void testMultipleOutputParDoShouldHaveFilterWhenSideOutputIsConsumed() { + Pipeline p = Pipeline.create(); + TupleTag tag1 = new TupleTag("tag1") {}; + TupleTag tag2 = new TupleTag("tag2") {}; + + SparkPipelineOptions options = contextRule.createPipelineOptions(); + TransformTranslator.Translator translator = new TransformTranslator.Translator(); + + PTransform> createTransform = Create.of("foo", "bar"); + + PassThrough.MultipleOutput passThroughTransform = + PassThrough.ofMultipleOutput(tag1, tag2); + + PCollectionTuple pCollectionTuple = + p.apply("Create Values", createTransform) + .apply("Multiple Output PassThrough", passThroughTransform); + + // consume side output + pCollectionTuple.get(tag2).apply(Count.globally()); + + p.replaceAll(SparkTransformOverrides.getDefaultOverrides(false)); + + EvaluationContext ctxt = new EvaluationContext(contextRule.getSparkContext(), p, options); + SparkRunner.initAccumulators(options, ctxt.getSparkContext()); + SparkRunner.updateDependentTransforms(p, translator, ctxt); p.traverseTopologically(new SparkRunner.Evaluator(translator, ctxt));
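For context, the scenario this change targets can be sketched with a minimal pipeline. This sketch is not part of the patch; the class name UnusedSideOutputPipeline and the tag names are invented for illustration, and only standard Beam SDK APIs are used. It builds a multi-output ParDo whose additional output is never consumed: after this patch, DependentTransformsVisitor counts zero dependent transforms for that output, EvaluationContext.isLeaf() reports it as a leaf, and skipObsoleteOutputs() drops it, so the Spark translation neither adds a per-tag filter nor treats the multi-output RDD as a caching candidate for it.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

public class UnusedSideOutputPipeline {

  public static void main(String[] args) {
    // Assumed invocation: pass --runner=SparkRunner so the translation path
    // changed by this patch is exercised; the pipeline shape itself is runner-agnostic.
    Pipeline pipeline = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    TupleTag<String> mainTag = new TupleTag<String>("main") {};
    TupleTag<String> debugTag = new TupleTag<String>("debug") {};

    PCollectionTuple outputs =
        pipeline
            .apply(Create.of("foo", "bar", "baz"))
            .apply(
                ParDo.of(
                        new DoFn<String, String>() {
                          @ProcessElement
                          public void processElement(ProcessContext c) {
                            c.output(c.element());
                            // Additional output that nothing downstream reads.
                            c.output(debugTag, "seen: " + c.element());
                          }
                        })
                    .withOutputTags(mainTag, TupleTagList.of(debugTag)));

    // Only the main output has a consumer; the "debug" PCollection remains a leaf
    // (zero dependants recorded by DependentTransformsVisitor), so the runner can
    // skip both the per-tag filter and the multi-output caching for it.
    outputs.get(mainTag).apply(Count.globally());

    pipeline.run().waitUntilFinish();
  }
}

Before this change, any additional output tag forced the translated RDD through a per-tag filter and made it a caching candidate even when, as here, only the main output is consumed; that is the behaviour the new tests above pin down.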