diff --git a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java new file mode 100644 index 0000000000000..b0aed4d08d387 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +import java.io.Serializable; + +/** + * Represents metadata about where shuffle blocks were written in a single map task. + *
+ * This is optionally returned by shuffle writers. The inner shuffle locations may + * be accessed by shuffle readers. Shuffle locations are only necessary when the + * location of shuffle blocks needs to be managed by the driver; shuffle plugins + * may choose to use an external database or other metadata management systems to + * track the locations of shuffle blocks instead. + */ +@Experimental +public interface MapShuffleLocations extends Serializable { + + /** + * Get the location for a given shuffle block written by this map task. + */ + ShuffleLocation getLocationForBlock(int reduceId); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java new file mode 100644 index 0000000000000..a312831cb6282 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.api.java.Optional; + +import java.util.Objects; + +/** + * :: Experimental :: + * An object defining the shuffle block and length metadata associated with the block. 
+ * @since 3.0.0 + */ +public class ShuffleBlockInfo { + private final int shuffleId; + private final int mapId; + private final int reduceId; + private final long length; + private final Optional shuffleLocation; + + public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length, + Optional shuffleLocation) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.reduceId = reduceId; + this.length = length; + this.shuffleLocation = shuffleLocation; + } + + public int getShuffleId() { + return shuffleId; + } + + public int getMapId() { + return mapId; + } + + public int getReduceId() { + return reduceId; + } + + public long getLength() { + return length; + } + + public Optional getShuffleLocation() { + return shuffleLocation; + } + + @Override + public boolean equals(Object other) { + return other instanceof ShuffleBlockInfo + && shuffleId == ((ShuffleBlockInfo) other).shuffleId + && mapId == ((ShuffleBlockInfo) other).mapId + && reduceId == ((ShuffleBlockInfo) other).reduceId + && length == ((ShuffleBlockInfo) other).length + && Objects.equals(shuffleLocation, ((ShuffleBlockInfo) other).shuffleLocation); + } + + @Override + public int hashCode() { + return Objects.hash(shuffleId, mapId, reduceId, length, shuffleLocation); + } +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java new file mode 100644 index 0000000000000..dd7c0ac7320cb --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * An interface for launching Shuffle related components + * + * @since 3.0.0 + */ +@Experimental +public interface ShuffleDataIO { + String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."; + + ShuffleDriverComponents driver(); + ShuffleExecutorComponents executor(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java new file mode 100644 index 0000000000000..6a0ec8d44fd4f --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; +import java.util.Map; + +public interface ShuffleDriverComponents { + + /** + * @return additional SparkConf values necessary for the executors. + */ + Map initializeApplication(); + + void cleanupApplication() throws IOException; + + void removeShuffleData(int shuffleId, boolean blocking) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java new file mode 100644 index 0000000000000..a5fa032bf651d --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +import java.util.Map; + +/** + * :: Experimental :: + * An interface for building shuffle support for Executors + * + * @since 3.0.0 + */ +@Experimental +public interface ShuffleExecutorComponents { + void initializeExecutor(String appId, String execId, Map extraConfigs); + + ShuffleWriteSupport writes(); + + ShuffleReadSupport reads(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java new file mode 100644 index 0000000000000..d06c11b3c01ee --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.shuffle; + +/** + * Marker interface representing a location of a shuffle block. Implementations of shuffle readers + * and writers are expected to cast this down to an implementation-specific representation. + */ +public interface ShuffleLocation {} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..062cf4ff0fba9 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.api.java.Optional; + +/** + * :: Experimental :: + * An interface for creating and managing shuffle partition writers + * + * @since 3.0.0 + */ +@Experimental +public interface ShuffleMapOutputWriter { + ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException; + + Optional commitAllPartitions() throws IOException; + + void abort(Throwable error) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java new file mode 100644 index 0000000000000..74c928b0b9c8f --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * An interface for giving streams / channels for shuffle writes. 
+ * + * @since 3.0.0 + */ +@Experimental +public interface ShufflePartitionWriter { + + /** + * Opens and returns an underlying {@link OutputStream} that can write bytes to the underlying + * data store. + */ + OutputStream openStream() throws IOException; + + /** + * Get the number of bytes written by this writer's stream returned by {@link #openStream()}. + */ + long getNumBytesWritten(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java new file mode 100644 index 0000000000000..9cd8fde09064b --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +import java.io.IOException; +import java.io.InputStream; + +/** + * :: Experimental :: + * An interface for reading shuffle records. + * @since 3.0.0 + */ +@Experimental +public interface ShuffleReadSupport { + /** + * Returns an underlying {@link Iterable} that will iterate + * through shuffle data, given an iterable for the shuffle blocks to fetch. + */ + Iterable getPartitionReaders(Iterable blockMetadata) + throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java new file mode 100644 index 0000000000000..7e2b6cf4133fd --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * An interface for deploying a shuffle map output writer + * + * @since 3.0.0 + */ +@Experimental +public interface ShuffleWriteSupport { + ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java new file mode 100644 index 0000000000000..866b61d0bafd9 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * Indicates that partition writers can transfer bytes directly from input byte channels to + * output channels that stream data to the underlying shuffle partition storage medium. + *
+ * This API is separated out for advanced users because it only needs to be used for + * specific low-level optimizations. The idea is that the returned channel can transfer bytes + * from the input file channel out to the backing storage system without copying data into + * memory. + *
+ * Most shuffle plugin implementations should use {@link ShufflePartitionWriter} instead. + * + * @since 3.0.0 + */ +@Experimental +public interface SupportsTransferTo extends ShufflePartitionWriter { + + /** + * Opens and returns a {@link TransferrableWritableByteChannel} for transferring bytes from + * input byte channels to the underlying shuffle data store. + */ + TransferrableWritableByteChannel openTransferrableChannel() throws IOException; + + /** + * Returns the number of bytes written either by this writer's output stream opened by + * {@link #openStream()} or the byte channel opened by {@link #openTransferrableChannel()}. + */ + @Override + long getNumBytesWritten(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java new file mode 100644 index 0000000000000..18234d7c4c944 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.Closeable; +import java.io.IOException; + +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * Represents an output byte channel that can copy bytes from input file channels to some + * arbitrary storage system. + *
+ * This API is provided for advanced users who can transfer bytes from a file channel to + * some output sink without copying data into memory. Most users should not need to use + * this functionality; this is primarily provided for the built-in shuffle storage backends + * that persist shuffle files on local disk. + *
+ * For a simpler alternative, see {@link ShufflePartitionWriter}. + * + * @since 3.0.0 + */ +@Experimental +public interface TransferrableWritableByteChannel extends Closeable { + + /** + * Copy all bytes from the source readable byte channel into this byte channel. + * + * @param source File to transfer bytes from. Do not call anything on this channel other than + * {@link FileChannel#transferTo(long, long, WritableByteChannel)}. + * @param transferStartPosition Start position of the input file to transfer from. + * @param numBytesToTransfer Number of bytes to transfer from the given source. + */ + void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer) + throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 32b446785a9f0..128b90429209e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -19,8 +19,10 @@ import java.io.File; import java.io.FileInputStream; -import java.io.FileOutputStream; import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; import javax.annotation.Nullable; import scala.None$; @@ -34,16 +36,22 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.internal.config.package$; import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.SupportsTransferTo; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShufflePartitionWriter; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -82,7 +90,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int shuffleId; private final int mapId; private final Serializer serializer; - private final IndexShuffleBlockResolver shuffleBlockResolver; + private final ShuffleWriteSupport shuffleWriteSupport; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; @@ -99,11 +107,11 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BypassMergeSortShuffleWriter( BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, int mapId, SparkConf conf, - ShuffleWriteMetricsReporter writeMetrics) { + ShuffleWriteMetricsReporter writeMetrics, + ShuffleWriteSupport shuffleWriteSupport) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; 
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); @@ -115,58 +123,67 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); - this.shuffleBlockResolver = shuffleBlockResolver; + this.shuffleWriteSupport = shuffleWriteSupport; } @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); - if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); - return; - } - final SerializerInstance serInstance = serializer.newInstance(); - final long openStartTime = System.nanoTime(); - partitionWriters = new DiskBlockObjectWriter[numPartitions]; - partitionWriterSegments = new FileSegment[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - final Tuple2 tempShuffleBlockIdPlusFile = - blockManager.diskBlockManager().createTempShuffleBlock(); - final File file = tempShuffleBlockIdPlusFile._2(); - final BlockId blockId = tempShuffleBlockIdPlusFile._1(); - partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, and can take a long time in aggregate when we open many files, so should be - // included in the shuffle write time. - writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - - while (records.hasNext()) { - final Product2 record = records.next(); - final K key = record._1(); - partitionWriters[partitioner.getPartition(key)].write(key, record._2()); - } + ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport + .createMapOutputWriter(shuffleId, mapId, numPartitions); + try { + if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + Optional blockLocs = mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + blockLocs.orNull(), + partitionLengths); + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new DiskBlockObjectWriter[numPartitions]; + partitionWriterSegments = new FileSegment[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. 
+ writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - for (int i = 0; i < numPartitions; i++) { - try (DiskBlockObjectWriter writer = partitionWriters[i]) { - partitionWriterSegments[i] = writer.commitAndGet(); + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - } - File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); - File tmp = Utils.tempFileWith(output); - try { - partitionLengths = writePartitionedFile(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); - } finally { - if (tmp.exists() && !tmp.delete()) { - logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + for (int i = 0; i < numPartitions; i++) { + try (DiskBlockObjectWriter writer = partitionWriters[i]) { + partitionWriterSegments[i] = writer.commitAndGet(); + } } + + partitionLengths = writePartitionedData(mapOutputWriter); + Optional mapLocations = mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + mapLocations.orNull(), + partitionLengths); + } catch (Exception e) { + try { + mapOutputWriter.abort(e); + } catch (Exception e2) { + logger.error("Failed to abort the writer after failing to write map output.", e2); + } + throw e; } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting @@ -179,37 +196,57 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). */ - private long[] writePartitionedFile(File outputFile) throws IOException { + private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator return lengths; } - - final FileOutputStream out = new FileOutputStream(outputFile, true); final long writeStartTime = System.nanoTime(); - boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); if (file.exists()) { - final FileInputStream in = new FileInputStream(file); boolean copyThrewException = true; - try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); + if (transferToEnabled) { + FileInputStream in = new FileInputStream(file); + TransferrableWritableByteChannel outputChannel = null; + try (FileChannel inputChannel = in.getChannel()) { + if (writer instanceof SupportsTransferTo) { + outputChannel = ((SupportsTransferTo) writer).openTransferrableChannel(); + } else { + // Use default transferrable writable channel anyways in order to have parity with + // UnsafeShuffleWriter. 
+ outputChannel = new DefaultTransferrableWritableByteChannel( + Channels.newChannel(writer.openStream())); + } + outputChannel.transferFrom(inputChannel, 0L, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + Closeables.close(outputChannel, copyThrewException); + } + } else { + FileInputStream in = new FileInputStream(file); + OutputStream outputStream = null; + try { + outputStream = writer.openStream(); + Utils.copyStream(in, outputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + Closeables.close(outputStream, copyThrewException); + } } if (!file.delete()) { logger.error("Unable to delete file for partition {}", i); } } + lengths[i] = writer.getNumBytesWritten(); } - threwException = false; } finally { - Closeables.close(out, threwException); writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java new file mode 100644 index 0000000000000..ffd97c0f26605 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; + +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.ShuffleLocation; +import org.apache.spark.storage.BlockManagerId; + +import java.util.Objects; + +public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation { + + /** + * We borrow the cache size from the BlockManagerId's cache - around 1MB, which should be + * feasible. 
+ */ + private static final LoadingCache + DEFAULT_SHUFFLE_LOCATIONS_CACHE = + CacheBuilder.newBuilder() + .maximumSize(BlockManagerId.blockManagerIdCacheSize()) + .build(new CacheLoader() { + @Override + public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) { + return new DefaultMapShuffleLocations(blockManagerId); + } + }); + + private final BlockManagerId location; + + public DefaultMapShuffleLocations(BlockManagerId blockManagerId) { + this.location = blockManagerId; + } + + public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) { + return DEFAULT_SHUFFLE_LOCATIONS_CACHE.getUnchecked(blockManagerId); + } + + @Override + public ShuffleLocation getLocationForBlock(int reduceId) { + return this; + } + + public BlockManagerId getBlockManagerId() { + return location; + } + + @Override + public boolean equals(Object other) { + return other instanceof DefaultMapShuffleLocations + && Objects.equals(((DefaultMapShuffleLocations) other).location, location); + } + + @Override + public int hashCode() { + return Objects.hashCode(location); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java new file mode 100644 index 0000000000000..64ce851e392d2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.util.Utils; + +/** + * This is used when transferTo is enabled but the shuffle plugin hasn't implemented + * {@link org.apache.spark.api.shuffle.SupportsTransferTo}. + *
+ * This default implementation exists as a convenience to the unsafe shuffle writer and + * the bypass merge sort shuffle writers. + */ +public class DefaultTransferrableWritableByteChannel implements TransferrableWritableByteChannel { + + private final WritableByteChannel delegate; + + public DefaultTransferrableWritableByteChannel(WritableByteChannel delegate) { + this.delegate = delegate; + } + + @Override + public void transferFrom( + FileChannel source, long transferStartPosition, long numBytesToTransfer) { + Utils.copyFileStreamNIO(source, delegate, transferStartPosition, numBytesToTransfer); + } + + @Override + public void close() throws IOException { + delegate.close(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9d05f03613ce9..5dd0821e10f59 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.sort; +import java.nio.channels.Channels; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; @@ -31,18 +32,22 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; -import com.google.common.io.Files; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.*; import org.apache.spark.annotation.Private; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShufflePartitionWriter; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.api.shuffle.SupportsTransferTo; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; -import org.apache.commons.io.output.CloseShieldOutputStream; -import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -50,12 +55,9 @@ import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -65,15 +67,14 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); @VisibleForTesting - static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; - private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; private final 
ShuffleWriteMetricsReporter writeMetrics; + private final ShuffleWriteSupport shuffleWriteSupport; private final int shuffleId; private final int mapId; private final TaskContext taskContext; @@ -81,7 +82,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; private final int initialSortBufferSize; private final int inputBufferSizeInBytes; - private final int outputBufferSizeInBytes; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -103,27 +103,15 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream */ private boolean stopping = false; - private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream { - - CloseAndFlushShieldOutputStream(OutputStream outputStream) { - super(outputStream); - } - - @Override - public void flush() { - // do nothing - } - } - public UnsafeShuffleWriter( BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, SerializedShuffleHandle handle, int mapId, TaskContext taskContext, SparkConf sparkConf, - ShuffleWriteMetricsReporter writeMetrics) throws IOException { + ShuffleWriteMetricsReporter writeMetrics, + ShuffleWriteSupport shuffleWriteSupport) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -132,7 +120,6 @@ public UnsafeShuffleWriter( " reduce partitions"); } this.blockManager = blockManager; - this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); @@ -140,6 +127,7 @@ public UnsafeShuffleWriter( this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); this.writeMetrics = writeMetrics; + this.shuffleWriteSupport = shuffleWriteSupport; this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); @@ -147,8 +135,6 @@ public UnsafeShuffleWriter( (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()); this.inputBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; - this.outputBufferSizeInBytes = - (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; open(); } @@ -230,26 +216,33 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; + final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport + .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); final long[] partitionLengths; - final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); - final File tmp = Utils.tempFileWith(output); + Optional mapLocations; try { try { - partitionLengths = mergeSpills(spills, tmp); + partitionLengths = mergeSpills(spills, mapWriter); } finally { for (SpillInfo spill : spills) { - if (spill.file.exists() && ! 
spill.file.delete()) { + if (spill.file.exists() && !spill.file.delete()) { logger.error("Error while deleting spill file {}", spill.file.getPath()); } } } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); - } finally { - if (tmp.exists() && !tmp.delete()) { - logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + mapLocations = mapWriter.commitAllPartitions(); + } catch (Exception e) { + try { + mapWriter.abort(e); + } catch (Exception innerE) { + logger.error("Failed to abort the Map Output Writer", innerE); } + throw e; } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + mapLocations.orNull(), + partitionLengths); } @VisibleForTesting @@ -281,7 +274,8 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { + private long[] mergeSpills(SpillInfo[] spills, + ShuffleMapOutputWriter mapWriter) throws IOException { final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS()); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = @@ -289,17 +283,12 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); + final int numPartitions = partitioner.numPartitions(); + long[] partitionLengths = new long[numPartitions]; try { if (spills.length == 0) { - new FileOutputStream(outputFile).close(); // Create an empty file - return new long[partitioner.numPartitions()]; - } else if (spills.length == 1) { - // Here, we don't need to perform any metrics updates because the bytes written to this - // output file would have already been counted as shuffle bytes written. - Files.move(spills[0].file, outputFile); - return spills[0].partitionLengths; + return partitionLengths; } else { - final long[] partitionLengths; // There are multiple spills to merge, so none of these spill files' lengths were counted // towards our shuffle write count or shuffle write time. If we use the slow merge path, // then the final output file's size won't necessarily be equal to the sum of the spill @@ -316,14 +305,14 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // that doesn't need to interpret the spilled bytes. 
if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter); } else { logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null); } } else { logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); } // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has // in-memory records, we write out the in-memory records to a file but do not count that @@ -331,13 +320,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // to be counted as shuffle write, but this will lead to double-counting of the final // SpillInfo's bytes. writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - writeMetrics.incBytesWritten(outputFile.length()); return partitionLengths; } } catch (IOException e) { - if (outputFile.exists() && !outputFile.delete()) { - logger.error("Unable to delete output file {}", outputFile.getPath()); - } throw e; } } @@ -345,73 +330,71 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti /** * Merges spill files using Java FileStreams. This code path is typically slower than * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], - * File)}, and it's mostly used in cases where the IO compression codec does not support - * concatenation of compressed data, when encryption is enabled, or when users have - * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. + * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec + * does not support concatenation of compressed data, when encryption is enabled, or when + * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs. * This code path might also be faster in cases where individual partition size in a spill * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small * disk ios which is inefficient. In those case, Using large buffers for input and output * files helps reducing the number of disk ios, making the file merging faster. * * @param spills the spills to merge. - * @param outputFile the file to write the merged data to. + * @param mapWriter the map output writer to use for output. * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. * @return the partition lengths in the merged file. */ private long[] mergeSpillsWithFileStream( SpillInfo[] spills, - File outputFile, + ShuffleMapOutputWriter mapWriter, @Nullable CompressionCodec compressionCodec) throws IOException { - assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; - final OutputStream bos = new BufferedOutputStream( - new FileOutputStream(outputFile), - outputBufferSizeInBytes); - // Use a counting output stream to avoid having to close the underlying file and ask - // the file system for its size after each partition is written. 
- final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); - boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputStreams[i] = new NioBufferedFileInputStream( - spills[i].file, - inputBufferSizeInBytes); + spills[i].file, + inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = mergedFileOutputStream.getByteCount(); - // Shield the underlying output stream from close() and flush() calls, so that we can close - // the higher level streams to make sure all data is really flushed and internal state is - // cleaned. - OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( - new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); - partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); - if (compressionCodec != null) { - partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); - } - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], - partitionLengthInSpill, false); - try { - partitionInputStream = blockManager.serializerManager().wrapForEncryption( - partitionInputStream); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + boolean copyThrewExecption = true; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + OutputStream partitionOutput = null; + try { + partitionOutput = writer.openStream(); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); + if (compressionCodec != null) { + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); + } + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = null; + try { + partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream( + partitionInputStream); + } + ByteStreams.copy(partitionInputStream, partitionOutput); + } finally { + partitionInputStream.close(); } - ByteStreams.copy(partitionInputStream, partitionOutput); - } finally { - partitionInputStream.close(); } + copyThrewExecption = false; } + } finally { + Closeables.close(partitionOutput, copyThrewExecption); } - partitionOutput.flush(); - partitionOutput.close(); - partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); + long numBytesWritten = writer.getNumBytesWritten(); + partitionLengths[partition] = numBytesWritten; + writeMetrics.incBytesWritten(numBytesWritten); } threwException = false; } finally { @@ -420,7 +403,6 @@ private long[] mergeSpillsWithFileStream( for (InputStream stream : spillInputStreams) { Closeables.close(stream, threwException); } - Closeables.close(mergedFileOutputStream, threwException); } return partitionLengths; } @@ -430,54 +412,49 @@ private long[] mergeSpillsWithFileStream( * This is only safe when the IO compression codec and serializer support concatenation of * serialized streams. 
* + * @param spills the spills to merge. + * @param mapWriter the map output writer to use for output. * @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { - assert (spills.length >= 2); + private long[] mergeSpillsWithTransferTo( + SpillInfo[] spills, + ShuffleMapOutputWriter mapWriter) throws IOException { final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; - FileChannel mergedFileOutputChannel = null; boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); } - // This file needs to opened in append mode in order to work around a Linux kernel bug that - // affects transferTo; see SPARK-3948 for more details. - mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); - - long bytesWrittenToMergedFile = 0; for (int partition = 0; partition < numPartitions; partition++) { - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - final FileChannel spillInputChannel = spillInputChannels[i]; - final long writeStartTime = System.nanoTime(); - Utils.copyFileStreamNIO( - spillInputChannel, - mergedFileOutputChannel, - spillInputChannelPositions[i], - partitionLengthInSpill); - spillInputChannelPositions[i] += partitionLengthInSpill; - writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); - bytesWrittenToMergedFile += partitionLengthInSpill; - partitionLengths[partition] += partitionLengthInSpill; + boolean copyThrewExecption = true; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + TransferrableWritableByteChannel partitionChannel = null; + try { + partitionChannel = writer instanceof SupportsTransferTo ? + ((SupportsTransferTo) writer).openTransferrableChannel() + : new DefaultTransferrableWritableByteChannel( + Channels.newChannel(writer.openStream())); + for (int i = 0; i < spills.length; i++) { + long partitionLengthInSpill = 0L; + partitionLengthInSpill += spills[i].partitionLengths[partition]; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + partitionChannel.transferFrom( + spillInputChannel, spillInputChannelPositions[i], partitionLengthInSpill); + spillInputChannelPositions[i] += partitionLengthInSpill; + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + } + copyThrewExecption = false; + } finally { + Closeables.close(partitionChannel, copyThrewExecption); } - } - // Check the position after transferTo loop to see if it is in the right position and raise an - // exception if it is incorrect. The position will not be increased to the expected length - // after calling transferTo in kernel version 2.6.32. This issue is described at - // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. - if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { - throw new IOException( - "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + - "position " + bytesWrittenToMergedFile + " after transferTo. 
Please check your kernel" + - " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + - "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + - "to disable this NIO feature." - ); + long numBytes = writer.getNumBytesWritten(); + partitionLengths[partition] = numBytes; + writeMetrics.incBytesWritten(numBytes); } threwException = false; } finally { @@ -487,7 +464,6 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th assert(spillInputChannelPositions[i] == spills[i].file.length()); Closeables.close(spillInputChannels[i], threwException); } - Closeables.close(mergedFileOutputChannel, threwException); } return partitionLengths; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java new file mode 100644 index 0000000000000..7c124c1fe68bc --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.shuffle.ShuffleDriverComponents; +import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleDataIO; +import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents; + +public class DefaultShuffleDataIO implements ShuffleDataIO { + + private final SparkConf sparkConf; + + public DefaultShuffleDataIO(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @Override + public ShuffleExecutorComponents executor() { + return new DefaultShuffleExecutorComponents(sparkConf); + } + + @Override + public ShuffleDriverComponents driver() { + return new DefaultShuffleDriverComponents(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java new file mode 100644 index 0000000000000..3b5f9670d64d2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.MapOutputTracker; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleReadSupport; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport; +import org.apache.spark.storage.BlockManager; + +import java.util.Map; + +public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { + + private final SparkConf sparkConf; + private BlockManager blockManager; + private IndexShuffleBlockResolver blockResolver; + private MapOutputTracker mapOutputTracker; + private SerializerManager serializerManager; + + public DefaultShuffleExecutorComponents(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @Override + public void initializeExecutor(String appId, String execId, Map extraConfigs) { + blockManager = SparkEnv.get().blockManager(); + mapOutputTracker = SparkEnv.get().mapOutputTracker(); + serializerManager = SparkEnv.get().serializerManager(); + blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + } + + @Override + public ShuffleWriteSupport writes() { + checkInitialized(); + return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); + } + + @Override + public ShuffleReadSupport reads() { + checkInitialized(); + return new DefaultShuffleReadSupport(blockManager, + mapOutputTracker, + serializerManager, + sparkConf); + } + + private void checkInitialized() { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..e83db4e4bcef6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShufflePartitionWriter; +import org.apache.spark.api.shuffle.SupportsTransferTo; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.internal.config.package$; +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations; +import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.util.Utils; + +public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { + + private static final Logger log = + LoggerFactory.getLogger(DefaultShuffleMapOutputWriter.class); + + private final int shuffleId; + private final int mapId; + private final ShuffleWriteMetricsReporter metrics; + private final IndexShuffleBlockResolver blockResolver; + private final long[] partitionLengths; + private final int bufferSize; + private int lastPartitionId = -1; + private long currChannelPosition; + private final BlockManagerId shuffleServerId; + + private final File outputFile; + private File outputTempFile; + private FileOutputStream outputFileStream; + private FileChannel outputFileChannel; + private TimeTrackingOutputStream ts; + private BufferedOutputStream outputBufferedFileStream; + + public DefaultShuffleMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions, + BlockManagerId shuffleServerId, + ShuffleWriteMetricsReporter metrics, + IndexShuffleBlockResolver blockResolver, + SparkConf sparkConf) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.shuffleServerId = shuffleServerId; + this.metrics = metrics; + this.blockResolver = blockResolver; + this.bufferSize = + (int) (long) sparkConf.get( + package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.partitionLengths = new long[numPartitions]; + this.outputFile = blockResolver.getDataFile(shuffleId, mapId); + this.outputTempFile = null; + } + + @Override + public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException { + if (partitionId <= lastPartitionId) { + throw new IllegalArgumentException("Partitions should be requested in increasing order."); + } + lastPartitionId = partitionId; + if (outputTempFile == null) { + outputTempFile = Utils.tempFileWith(outputFile); + } + if (outputFileChannel != null) { + currChannelPosition = outputFileChannel.position(); + } else { + currChannelPosition = 0L; + } + return new DefaultShufflePartitionWriter(partitionId); + } + + @Override + public Optional commitAllPartitions() throws IOException { + cleanUp(); + File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + return Optional.of(DefaultMapShuffleLocations.get(shuffleServerId)); + } + + @Override + public void abort(Throwable error) { + try { + cleanUp(); + } catch (Exception e) { + log.error("Unable to close appropriate underlying file stream", e); + } + if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { + log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); + } + } + + private void cleanUp() throws IOException { + if (outputBufferedFileStream != null) { + outputBufferedFileStream.close(); + } + if (outputFileChannel != null) { + outputFileChannel.close(); + } + if (outputFileStream != null) { + outputFileStream.close(); + } + } + + private void initStream() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + ts = new TimeTrackingOutputStream(metrics, outputFileStream); + } + if (outputBufferedFileStream == null) { + outputBufferedFileStream = new BufferedOutputStream(ts, bufferSize); + } + } + + private void initChannel() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + } + if (outputFileChannel == null) { + outputFileChannel = outputFileStream.getChannel(); + } + } + + private class DefaultShufflePartitionWriter implements SupportsTransferTo { + + private final int partitionId; + private PartitionWriterStream partStream = null; + private PartitionWriterChannel partChannel = null; + + private DefaultShufflePartitionWriter(int partitionId) { + this.partitionId = partitionId; + } + + @Override + public OutputStream openStream() throws IOException { + if (partStream == null) { + if (outputFileChannel != null) { + throw new IllegalStateException("Requested an output channel for a previous write but" + + " now an output stream has been requested. Should not be using both channels" + + " and streams to write."); + } + initStream(); + partStream = new PartitionWriterStream(partitionId); + } + return partStream; + } + + @Override + public TransferrableWritableByteChannel openTransferrableChannel() throws IOException { + if (partChannel == null) { + if (partStream != null) { + throw new IllegalStateException("Requested an output stream for a previous write but" + + " now an output channel has been requested. 
Should not be using both channels" + + " and streams to write."); + } + initChannel(); + partChannel = new PartitionWriterChannel(partitionId); + } + return partChannel; + } + + @Override + public long getNumBytesWritten() { + if (partChannel != null) { + try { + return partChannel.getCount(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } else if (partStream != null) { + return partStream.getCount(); + } else { + // Assume an empty partition if stream and channel are never created + return 0; + } + } + } + + private class PartitionWriterStream extends OutputStream { + private final int partitionId; + private int count = 0; + private boolean isClosed = false; + + PartitionWriterStream(int partitionId) { + this.partitionId = partitionId; + } + + public int getCount() { + return count; + } + + @Override + public void write(int b) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(b); + count++; + } + + @Override + public void write(byte[] buf, int pos, int length) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(buf, pos, length); + count += length; + } + + @Override + public void close() { + isClosed = true; + partitionLengths[partitionId] = count; + } + + private void verifyNotClosed() { + if (isClosed) { + throw new IllegalStateException("Attempting to write to a closed block output stream."); + } + } + } + + private class PartitionWriterChannel extends DefaultTransferrableWritableByteChannel { + + private final int partitionId; + + PartitionWriterChannel(int partitionId) { + super(outputFileChannel); + this.partitionId = partitionId; + } + + public long getCount() throws IOException { + long writtenPosition = outputFileChannel.position(); + return writtenPosition - currChannelPosition; + } + + @Override + public void close() throws IOException { + partitionLengths[partitionId] = getCount(); + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java new file mode 100644 index 0000000000000..86f1583495689 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.BlockManagerId; + +public class DefaultShuffleWriteSupport implements ShuffleWriteSupport { + + private final SparkConf sparkConf; + private final IndexShuffleBlockResolver blockResolver; + private final BlockManagerId shuffleServerId; + + public DefaultShuffleWriteSupport( + SparkConf sparkConf, + IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId) { + this.sparkConf = sparkConf; + this.blockResolver = blockResolver; + this.shuffleServerId = shuffleServerId; + } + + @Override + public ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions) { + return new DefaultShuffleMapOutputWriter( + shuffleId, mapId, numPartitions, shuffleServerId, + TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java new file mode 100644 index 0000000000000..a3eddc8ec930e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.lifecycle; + +import com.google.common.collect.ImmutableMap; +import org.apache.spark.SparkEnv; +import org.apache.spark.api.shuffle.ShuffleDriverComponents; +import org.apache.spark.storage.BlockManagerMaster; + +import java.io.IOException; +import java.util.Map; + +public class DefaultShuffleDriverComponents implements ShuffleDriverComponents { + + private BlockManagerMaster blockManagerMaster; + + @Override + public Map initializeApplication() { + blockManagerMaster = SparkEnv.get().blockManager().master(); + return ImmutableMap.of(); + } + + @Override + public void cleanupApplication() { + // do nothing + } + + @Override + public void removeShuffleData(int shuffleId, boolean blocking) throws IOException { + checkInitialized(); + blockManagerMaster.removeShuffle(shuffleId, blocking); + } + + private void checkInitialized() { + if (blockManagerMaster == null) { + throw new IllegalStateException("Driver components must be initialized before using"); + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index a111a60d1d024..162216bd0c5a0 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Scheduled import scala.collection.JavaConverters._ +import org.apache.spark.api.shuffle.ShuffleDriverComponents import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -58,7 +59,9 @@ private class CleanupTaskWeakReference( * to be processed when the associated object goes out of scope of the application. Actual * cleanup is performed in a separate daemon thread. 
*/ -private[spark] class ContextCleaner(sc: SparkContext) extends Logging { +private[spark] class ContextCleaner( + sc: SparkContext, + shuffleDriverComponents: ShuffleDriverComponents) extends Logging { /** * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they @@ -221,7 +224,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) - blockManagerMaster.removeShuffle(shuffleId, blocking) + shuffleDriverComponents.removeShuffleData(shuffleId, blocking) listeners.asScala.foreach(_.shuffleCleaned(shuffleId)) logInfo("Cleaned shuffle " + shuffleId) } catch { @@ -269,7 +272,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index c6271d251970c..b7462c350796a 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal +import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation} import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -281,9 +282,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } // For testing - def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) + def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int) + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1) } /** @@ -295,8 +296,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * and the second item is a sequence of (shuffle block id, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -645,8 +646,8 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. 
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -682,12 +683,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private val fetching = new HashSet[Int] // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. - override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, statuses) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: @@ -871,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" @@ -883,8 +885,14 @@ private[spark] object MapOutputTracker extends Logging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + if (status.mapShuffleLocations == null) { + splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) + } else { + val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) + splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c16984259bc00..ed373d46f3198 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -40,6 +40,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, 
StreamInputFormat, WholeTextFileInputFormat} @@ -213,6 +214,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _shutdownHookRef: AnyRef = _ private var _statusStore: AppStatusStore = _ private var _heartbeater: Heartbeater = _ + private var _shuffleDriverComponents: ShuffleDriverComponents = _ private var _resources: scala.collection.immutable.Map[String, ResourceInformation] = _ /* ------------------------------------------------------------------------------------- * @@ -528,6 +530,14 @@ class SparkContext(config: SparkConf) extends Logging { executorEnvs ++= _conf.getExecutorEnv executorEnvs("SPARK_USER") = sparkUser + val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS) + val maybeIO = Utils.loadExtensions( + classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) + require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") + _shuffleDriverComponents = maybeIO.head.driver() + _shuffleDriverComponents.initializeApplication().asScala.foreach { + case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) } + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) _heartbeatReceiver = env.rpcEnv.setupEndpoint( @@ -596,7 +606,7 @@ class SparkContext(config: SparkConf) extends Logging { _cleaner = if (_conf.get(CLEANER_REFERENCE_TRACKING)) { - Some(new ContextCleaner(this)) + Some(new ContextCleaner(this, _shuffleDriverComponents)) } else { None } @@ -1952,6 +1962,7 @@ class SparkContext(config: SparkConf) extends Logging { } _heartbeater = null } + _shuffleDriverComponents.cleanupApplication() if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ea79c7310349d..df30fd5c7f679 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -56,6 +56,8 @@ class TaskMetrics private[spark] () extends Serializable { private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] + private var _decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics = + Predef.identity[TempShuffleReadMetrics] /** * Time taken on the executor to deserialize this task. @@ -187,11 +189,17 @@ class TaskMetrics private[spark] () extends Serializable { * be lost. */ private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized { - val readMetrics = new TempShuffleReadMetrics - tempShuffleReadMetrics += readMetrics + val tempShuffleMetrics = new TempShuffleReadMetrics + val readMetrics = _decorFunc(tempShuffleMetrics) + tempShuffleReadMetrics += tempShuffleMetrics readMetrics } + private[spark] def decorateTempShuffleReadMetrics( + decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics): Unit = synchronized { + _decorFunc = decorFunc + } + /** * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`. * This is expected to be called on executor heartbeat and at the end of a task. 
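The SparkContext wiring above loads the configured ShuffleDataIO implementation, calls driver().initializeApplication(), republishes the returned key/value pairs under the "spark.shuffle.plugin." prefix so executors can see them, routes shuffle cleanup from ContextCleaner through removeShuffleData, and invokes cleanupApplication() on stop. A minimal third-party plugin that hooks this driver lifecycle while reusing the built-in executor-side IO might look roughly like the sketch below. The class names, the com.example.shuffle package, and the "tracker.uri" key/value are hypothetical illustrations; only the ShuffleDataIO, ShuffleDriverComponents, and ShuffleExecutorComponents interfaces come from this change.

package com.example.shuffle;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.ShuffleDataIO;
import org.apache.spark.api.shuffle.ShuffleDriverComponents;
import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents;

/**
 * Hypothetical plugin: keeps Spark's default executor-side shuffle IO but hooks the
 * driver lifecycle, e.g. to register the application with an external metadata store.
 */
public class ExampleShuffleDataIO implements ShuffleDataIO {

  private final SparkConf sparkConf;

  // Utils.loadExtensions instantiates the plugin with a single SparkConf argument.
  public ExampleShuffleDataIO(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public ShuffleDriverComponents driver() {
    return new ExampleShuffleDriverComponents();
  }

  @Override
  public ShuffleExecutorComponents executor() {
    // Delegate reads and writes to the default local-disk implementation.
    return new DefaultShuffleExecutorComponents(sparkConf);
  }

  private static class ExampleShuffleDriverComponents implements ShuffleDriverComponents {

    @Override
    public Map<String, String> initializeApplication() {
      // Whatever is returned here is copied into the SparkConf by SparkContext under
      // the "spark.shuffle.plugin." prefix and later passed to initializeExecutor.
      // "tracker.uri" and its value are made-up examples.
      return Collections.singletonMap("tracker.uri", "http://shuffle-tracker:7788");
    }

    @Override
    public void cleanupApplication() {
      // Invoked from SparkContext.stop(); release application-level resources here.
    }

    @Override
    public void removeShuffleData(int shuffleId, boolean blocking) throws IOException {
      // Invoked by ContextCleaner when a shuffle goes out of scope; a real plugin
      // would remove the shuffle's blocks from its backing store.
    }
  }
}

Such a plugin would be selected by setting spark.shuffle.io.plugin.class to com.example.shuffle.ExampleShuffleDataIO (the key introduced in the config change below), and executors would receive the published entries via the extraConfigs map of initializeExecutor.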
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8d4910ff9f80e..becb4d5a90d0f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -24,6 +24,7 @@ import org.apache.spark.metrics.GarbageCollectionMetrics import org.apache.spark.network.shuffle.Constants import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} +import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -798,6 +799,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val SHUFFLE_IO_PLUGIN_CLASS = + ConfigBuilder("spark.shuffle.io.plugin.class") + .doc("Name of the class to use for shuffle IO.") + .stringConf + .createWithDefault(classOf[DefaultShuffleDataIO].getName) + private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247c..a61f9bd14ef2f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -24,7 +24,9 @@ import scala.collection.mutable import org.roaringbitmap.RoaringBitmap import org.apache.spark.SparkEnv +import org.apache.spark.api.shuffle.MapShuffleLocations import org.apache.spark.internal.config +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -33,7 +35,16 @@ import org.apache.spark.util.Utils * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. */ private[spark] sealed trait MapStatus { - /** Location where this task was run. */ + + /** + * Locations where this task stored shuffle blocks. + * + * May be null if the MapOutputTracker is not tracking the location of shuffle blocks, leaving it + * up to the implementation of shuffle plugins to do so. + */ + def mapShuffleLocations: MapShuffleLocations + + /** Location where the task was run. */ def location: BlockManagerId /** @@ -56,11 +67,31 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + // A temporary concession to the fact that we only expect implementations of shuffle provided by + // Spark to be storing shuffle locations in the driver, meaning we want to introduce as little + // serialization overhead as possible in such default cases. + // + // If more similar cases arise, consider adding a serialization API for these shuffle locations. + private val DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 0 + private val NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 1 + + /** + * Visible for testing. 
+ */ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + apply(loc, DefaultMapShuffleLocations.get(loc), uncompressedSizes) + } + + def apply( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus( + loc, mapShuffleLocs, uncompressedSizes) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus( + loc, mapShuffleLocs, uncompressedSizes) } } @@ -91,41 +122,89 @@ private[spark] object MapStatus { math.pow(LOG_BASE, compressedSize & 0xFF).toLong } } -} + def writeLocations( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + out: ObjectOutput): Unit = { + if (mapShuffleLocs != null) { + out.writeBoolean(true) + if (mapShuffleLocs.isInstanceOf[DefaultMapShuffleLocations] + && mapShuffleLocs.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId == loc) { + out.writeByte(MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) + } else { + out.writeByte(MapStatus.NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) + out.writeObject(mapShuffleLocs) + } + } else { + out.writeBoolean(false) + } + loc.writeExternal(out) + } + + def readLocations(in: ObjectInput): (BlockManagerId, MapShuffleLocations) = { + if (in.readBoolean()) { + val locId = in.readByte() + if (locId == MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) { + val blockManagerId = BlockManagerId(in) + (blockManagerId, DefaultMapShuffleLocations.get(blockManagerId)) + } else { + val mapShuffleLocations = in.readObject().asInstanceOf[MapShuffleLocations] + val blockManagerId = BlockManagerId(in) + (blockManagerId, mapShuffleLocations) + } + } else { + val blockManagerId = BlockManagerId(in) + (blockManagerId, null) + } + } +} /** * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is * represented using a single byte. * - * @param loc location where the task is being executed. + * @param loc Location were the task is being executed. + * @param mapShuffleLocs locations where the task stored its shuffle blocks - may be null. * @param compressedSizes size of the blocks, indexed by reduce partition id. 
*/ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, + private[this] var mapShuffleLocs: MapShuffleLocations, private[this] var compressedSizes: Array[Byte]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + // For deserialization only + protected def this() = this(null, null, null.asInstanceOf[Array[Byte]]) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this( + loc: BlockManagerId, + mapShuffleLocations: MapShuffleLocations, + uncompressedSizes: Array[Long]) { + this( + loc, + mapShuffleLocations, + uncompressedSizes.map(MapStatus.compressSize)) } override def location: BlockManagerId = loc + override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - loc.writeExternal(out) + MapStatus.writeLocations(loc, mapShuffleLocs, out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - loc = BlockManagerId(in) + val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in) + loc = deserializedLoc + mapShuffleLocs = deserializedMapShuffleLocs val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -138,6 +217,7 @@ private[spark] class CompressedMapStatus( * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed + * @param mapShuffleLocs location where the task stored shuffle blocks - may be null * @param numNonEmptyBlocks the number of non-empty blocks * @param emptyBlocks a bitmap tracking which blocks are empty * @param avgSize average size of the non-empty and non-huge blocks @@ -145,6 +225,7 @@ private[spark] class CompressedMapStatus( */ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, + private[this] var mapShuffleLocs: MapShuffleLocations, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, @@ -155,10 +236,12 @@ private[spark] class HighlyCompressedMapStatus private ( require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc + override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -172,7 +255,7 @@ private[spark] class HighlyCompressedMapStatus private ( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - loc.writeExternal(out) + MapStatus.writeLocations(loc, mapShuffleLocs, out) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -183,7 +266,9 @@ private[spark] class HighlyCompressedMapStatus private ( } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - loc = BlockManagerId(in) + val (deserializedLoc, deserializedMapShuffleLocs) = 
MapStatus.readLocations(in) + loc = deserializedLoc + mapShuffleLocs = deserializedMapShuffleLocs emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -199,7 +284,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -239,7 +327,12 @@ private[spark] object HighlyCompressedMapStatus { } emptyBlocks.trim() emptyBlocks.runOptimize() - new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes) + new HighlyCompressedMapStatus( + loc, + mapShuffleLocs, + numNonEmptyBlocks, + emptyBlocks, + avgSize, + hugeBlockSizes) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 39691069bf5f6..f886fe7d9e598 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -31,7 +31,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSe import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput} import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool} -import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} +import com.esotericsoftware.kryo.serializers.{ExternalizableSerializer, JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.RoaringBitmap @@ -151,6 +151,8 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[CompressedMapStatus], new ExternalizableSerializer()) + kryo.register(classOf[HighlyCompressedMapStatus], new ExternalizableSerializer()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) @@ -486,8 +488,6 @@ private[serializer] object KryoSerializer { private val toRegister: Seq[Class[_]] = Seq( ByteBuffer.allocate(1).getClass, classOf[StorageLevel], - classOf[CompressedMapStatus], - classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Boolean]], diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c7843710413dd..530c3694ad1ec 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,10 +17,18 @@ package org.apache.spark.shuffle +import java.io.InputStream + +import scala.collection.JavaConverters._ + import 
org.apache.spark._ +import org.apache.spark.api.java.Optional +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} import org.apache.spark.internal.{config, Logging} +import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.storage.{ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -34,34 +42,68 @@ private[spark] class BlockStoreShuffleReader[K, C]( endPartition: Int, context: TaskContext, readMetrics: ShuffleReadMetricsReporter, + shuffleReadSupport: ShuffleReadSupport, serializerManager: SerializerManager = SparkEnv.get.serializerManager, - blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + sparkConf: SparkConf = SparkEnv.get.conf) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency + private val compressionCodec = CompressionCodec.createCodec(sparkConf) + + private val compressShuffle = sparkConf.get(config.SHUFFLE_COMPRESS) + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, - SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), - SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), - readMetrics).toCompletionIterator + val streamsIterator = + shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] { + override def iterator: Iterator[ShuffleBlockInfo] = { + mapOutputTracker + .getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) + .flatMap { shuffleLocationInfo => + shuffleLocationInfo._2.map { blockInfo => + val block = blockInfo._1.asInstanceOf[ShuffleBlockId] + new ShuffleBlockInfo( + block.shuffleId, + block.mapId, + block.reduceId, + blockInfo._2, + Optional.ofNullable(shuffleLocationInfo._1.orNull)) + } + } + } + }.asJava).iterator() - val serializerInstance = dep.serializer.newInstance() + val retryingWrappedStreams = new Iterator[InputStream] { + override def hasNext: Boolean = streamsIterator.hasNext - // Create a key/value iterator for each stream - val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => + override def next(): InputStream = { + var returnStream: InputStream = null + while (streamsIterator.hasNext && returnStream == null) { + if (shuffleReadSupport.isInstanceOf[DefaultShuffleReadSupport]) { + // The default implementation checks for corrupt streams, so it will already have + // decompressed/decrypted the bytes + returnStream = streamsIterator.next() + } else { + val nextStream = streamsIterator.next() + returnStream = if (compressShuffle) { + 
compressionCodec.compressedInputStream( + serializerManager.wrapForEncryption(nextStream)) + } else { + serializerManager.wrapForEncryption(nextStream) + } + } + } + if (returnStream == null) { + throw new IllegalStateException("Expected shuffle reader iterator to return a stream") + } + returnStream + } + } + + val serializerInstance = dep.serializer.newInstance() + val recordIter = retryingWrappedStreams.flatMap { wrappedStream => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala new file mode 100644 index 0000000000000..9b9b8508e88aa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.io + +import java.io.InputStream + +import scala.collection.JavaConverters._ + +import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext} +import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport} +import org.apache.spark.internal.config +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} + +class DefaultShuffleReadSupport( + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker, + serializerManager: SerializerManager, + conf: SparkConf) extends ShuffleReadSupport { + + private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024 + private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT) + private val maxBlocksInFlightPerAddress = + conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS) + private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT) + + override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): + java.lang.Iterable[InputStream] = { + + val iterableToReturn = if (blockMetadata.asScala.isEmpty) { + Iterable.empty + } else { + val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId) + .foldLeft(Int.MaxValue, 0) { + case ((min, max), elem) => (math.min(min, elem), math.max(max, elem)) + } + val shuffleId = blockMetadata.asScala.head.getShuffleId + new ShuffleBlockFetcherIterable( + TaskContext.get(), + blockManager, + serializerManager, + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorrupt, + shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics(), + minReduceId, + maxReduceId, + shuffleId, + mapOutputTracker + ) + } + iterableToReturn.asJava + } +} + +private class ShuffleBlockFetcherIterable( + context: TaskContext, + blockManager: BlockManager, + serializerManager: SerializerManager, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + maxReqSizeShuffleToMem: Long, + detectCorruption: Boolean, + shuffleMetrics: ShuffleReadMetricsReporter, + minReduceId: Int, + maxReduceId: Int, + shuffleId: Int, + mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] { + + override def iterator: Iterator[InputStream] = { + new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, minReduceId, maxReduceId + 1) + .map { shuffleLocationInfo => + val defaultShuffleLocation = shuffleLocationInfo._1 + .get.asInstanceOf[DefaultMapShuffleLocations] + (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2) + }, + serializerManager.wrapStream, + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorruption, + shuffleMetrics).toCompletionIterator + } + +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b59fa8e8a3ccd..947753f6b40e8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,9 +19,13 @@ package 
org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ + import org.apache.spark._ -import org.apache.spark.internal.Logging +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ +import org.apache.spark.util.Utils /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -68,6 +72,8 @@ import org.apache.spark.shuffle._ */ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + import SortShuffleManager._ + if (!conf.getBoolean("spark.shuffle.spill", true)) { logWarning( "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + @@ -79,6 +85,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** @@ -118,7 +126,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, endPartition, context, metrics) + startPartition, + endPartition, + context, + metrics, + shuffleExecutorComponents.reads()) } /** Get a writer for a given partition. Called on executors by map tasks. */ @@ -134,23 +146,24 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], context.taskMemoryManager(), unsafeShuffleHandle, mapId, context, env.conf, - metrics) + metrics, + shuffleExecutorComponents.writes()) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, env.conf, - metrics) + metrics, + shuffleExecutorComponents.writes()) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) + new SortShuffleWriter( + shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes()) } } @@ -205,6 +218,21 @@ private[spark] object SortShuffleManager extends Logging { true } } + + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS) + val maybeIO = Utils.loadExtensions( + classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) + require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") + val executorComponents = maybeIO.head.executor() + val extraConfigs = conf.getAllWithPrefix(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX) + .toMap + executorComponents.initializeExecutor( + conf.getAppId, + SparkEnv.get.executorId, + extraConfigs.asJava) + executorComponents + } } /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 16058de8bf3ff..1fcae684b0052 100644 --- 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -18,18 +18,18 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} -import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( shuffleBlockResolver: IndexShuffleBlockResolver, handle: BaseShuffleHandle[K, V, C], mapId: Int, - context: TaskContext) + context: TaskContext, + writeSupport: ShuffleWriteSupport) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency @@ -64,18 +64,14 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). - val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) - val tmp = Utils.tempFileWith(output) - try { - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) - } finally { - if (tmp.exists() && !tmp.delete()) { - logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") - } - } + val mapOutputWriter = writeSupport.createMapOutputWriter( + dep.shuffleId, mapId, dep.partitioner.numPartitions) + val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + val mapLocations = mapOutputWriter.commitAllPartitions() + mapStatus = MapStatus( + blockManager.shuffleServerId, + mapLocations.orNull(), + partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index d188bdd912e5e..97b99e08d9ca9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -132,12 +132,14 @@ private[spark] object BlockManagerId { getCachedBlockManagerId(obj) } + val blockManagerIdCacheSize = 10000 + /** * The max cache size is hardcoded to 10000, since the size of a BlockManagerId * object is about 48B, the total memory cost should be below 1MB which is feasible. 
*/ val blockManagerIdCache = CacheBuilder.newBuilder() - .maximumSize(10000) + .maximumSize(blockManagerIdCacheSize) .build(new CacheLoader[BlockManagerId, BlockManagerId]() { override def load(id: BlockManagerId) = id }) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 17390f9c60e79..f9f4e3594e4f9 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.Utils +import org.apache.spark.util.collection.PairsWriter /** * A class for writing JVM objects directly to a file on disk. This class allows data to be appended @@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter( writeMetrics: ShuffleWriteMetricsReporter, val blockId: BlockId = null) extends OutputStream - with Logging { + with Logging + with PairsWriter { /** * Guards against close calls, e.g. from a wrapping stream. diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c89d5cc971d2a..22fc4da97a5b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -76,7 +76,7 @@ final class ShuffleBlockFetcherIterator( detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter) - extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { + extends Iterator[InputStream] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -399,7 +399,7 @@ final class ShuffleBlockFetcherIterator( * * Throws a FetchFailedException if the next block could not be fetched. */ - override def next(): (BlockId, InputStream) = { + override def next(): InputStream = { if (!hasNext) { throw new NoSuchElementException() } @@ -497,7 +497,6 @@ final class ShuffleBlockFetcherIterator( in.close() } } - case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) } @@ -510,6 +509,7 @@ final class ShuffleBlockFetcherIterator( throw new NoSuchElementException() } currentResult = result.asInstanceOf[SuccessFetchResult] - (currentResult.blockId, - new BufferReleasingInputStream( - input, - currentResult.blockId, - currentResult.address, - detectCorrupt && streamCompressedOrEncrypted)) + val blockId = currentResult.blockId.asInstanceOf[ShuffleBlockId] + new BufferReleasingInputStream(input, this) + } + + // for testing only + def getCurrentBlock(): ShuffleBlockId = { + currentResult.blockId.asInstanceOf[ShuffleBlockId] } - def toCompletionIterator: Iterator[(BlockId, InputStream)] = { - CompletionIterator[(BlockId, InputStream), this.type](this, + def toCompletionIterator: Iterator[InputStream] = { + CompletionIterator[InputStream, this.type](this, onCompleteCallback.onComplete(context)) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bed50865e7be4..87063819d571a 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.{Channels, FileChannel} +import java.nio.channels.{Channels, FileChannel, WritableByteChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.security.SecureRandom @@ -394,10 +394,14 @@ private[spark] object Utils extends Logging { def copyFileStreamNIO( input: FileChannel, - output: FileChannel, + output: WritableByteChannel, startPosition: Long, bytesToCopy: Long): Unit = { - val initialPos = output.position() + val outputInitialState = output match { + case outputFileChannel: FileChannel => + Some((outputFileChannel.position(), outputFileChannel)) + case _ => None + } var count = 0L // In case transferTo method transferred less data than we have required. while (count < bytesToCopy) { @@ -412,15 +416,17 @@ private[spark] object Utils extends Logging { // kernel version 2.6.32, this issue can be seen in // https://bugs.openjdk.java.net/browse/JDK-7052359 // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). - val finalPos = output.position() - val expectedPos = initialPos + bytesToCopy - assert(finalPos == expectedPos, - s""" - |Current position $finalPos do not equal to expected position $expectedPos - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. - |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) + outputInitialState.foreach { case (initialPos, outputFileChannel) => + val finalPos = outputFileChannel.position() + val expectedPos = initialPos + bytesToCopy + assert(finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature.
+ """.stripMargin) + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3f3b7d20eb169..13132c2801ed9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,10 +26,11 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -670,11 +671,9 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * Write all the data added into this ExternalSorter into a file in the disk store. This is - * called by the SortShuffleWriter. - * - * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project. + * We should figure out an alternative way to test that so that we can remove this otherwise + * unused code path. */ def writePartitionedFile( blockId: BlockId, @@ -718,6 +717,88 @@ private[spark] class ExternalSorter[K, V, C]( lengths } + /** + * Write all the data added into this ExternalSorter into a map output writer that pushes bytes + * to some arbitrary backing store. This is called by the SortShuffleWriter. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + def writePartitionedMapOutput( + shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = { + // Track location of each range in the map output + val lengths = new Array[Long](numPartitions) + if (spills.isEmpty) { + // Case where we only have in-memory data + val collection = if (aggregator.isDefined) map else buffer + val it = collection.destructiveSortedWritablePartitionedIterator(comparator) + while (it.hasNext()) { + val partitionId = it.nextPartition() + var partitionWriter: ShufflePartitionWriter = null + var partitionPairsWriter: ShufflePartitionPairsWriter = null + try { + partitionWriter = mapOutputWriter.getPartitionWriter(partitionId) + val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) + partitionPairsWriter = new ShufflePartitionPairsWriter( + partitionWriter, + serializerManager, + serInstance, + blockId, + context.taskMetrics().shuffleWriteMetrics) + while (it.hasNext && it.nextPartition() == partitionId) { + it.writeNext(partitionPairsWriter) + } + } finally { + if (partitionPairsWriter != null) { + partitionPairsWriter.close() + } + } + if (partitionWriter != null) { + lengths(partitionId) = partitionWriter.getNumBytesWritten + } + } + } else { + // We must perform merge-sort; get an iterator by partition and write everything directly. + for ((id, elements) <- this.partitionedIterator) { + // The contract for the plugin is that we will ask for a writer for every partition + // even if it's empty. 
However, the external sorter will return non-contiguous + // partition ids. So this loop "backfills" the empty partitions that form the gaps. + + // The algorithm as a whole is correct because the partition ids are returned by the + // iterator in ascending order. + val blockId = ShuffleBlockId(shuffleId, mapId, id) + var partitionWriter: ShufflePartitionWriter = null + var partitionPairsWriter: ShufflePartitionPairsWriter = null + try { + partitionWriter = mapOutputWriter.getPartitionWriter(id) + partitionPairsWriter = new ShufflePartitionPairsWriter( + partitionWriter, + serializerManager, + serInstance, + blockId, + context.taskMetrics().shuffleWriteMetrics) + if (elements.hasNext) { + for (elem <- elements) { + partitionPairsWriter.write(elem._1, elem._2) + } + } + } finally { + if (partitionPairsWriter!= null) { + partitionPairsWriter.close() + } + } + if (partitionWriter != null) { + lengths(id) = partitionWriter.getNumBytesWritten + } + } + } + + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) + + lengths + } + def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() @@ -781,7 +862,7 @@ private[spark] class ExternalSorter[K, V, C]( val inMemoryIterator = new WritablePartitionedIterator { private[this] var cur = if (upstream.hasNext) upstream.next() else null - def writeNext(writer: DiskBlockObjectWriter): Unit = { + def writeNext(writer: PairsWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (upstream.hasNext) upstream.next() else null } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala new file mode 100644 index 0000000000000..9d7c209f242e1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +private[spark] trait PairsWriter { + + def write(key: Any, value: Any): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala new file mode 100644 index 0000000000000..8538a78b377c8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.{Closeable, FilterOutputStream, OutputStream} + +import org.apache.spark.api.shuffle.ShufflePartitionWriter +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.storage.BlockId + +/** + * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an + * arbitrary partition writer instead of writing to local disk through the block manager. + */ +private[spark] class ShufflePartitionPairsWriter( + partitionWriter: ShufflePartitionWriter, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + blockId: BlockId, + writeMetrics: ShuffleWriteMetricsReporter) + extends PairsWriter with Closeable { + + private var isOpen = false + private var partitionStream: OutputStream = _ + private var wrappedStream: OutputStream = _ + private var objOut: SerializationStream = _ + private var numRecordsWritten = 0 + private var curNumBytesWritten = 0L + + override def write(key: Any, value: Any): Unit = { + if (!isOpen) { + open() + isOpen = true + } + objOut.writeKey(key) + objOut.writeValue(value) + recordWritten() + } + + private def open(): Unit = { + partitionStream = partitionWriter.openStream + wrappedStream = serializerManager.wrapStream(blockId, partitionStream) + objOut = serializerInstance.serializeStream(wrappedStream) + } + + override def close(): Unit = { + if (isOpen) { + objOut.close() + objOut = null + wrappedStream = null + partitionStream = null + isOpen = false + updateBytesWritten() + } + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write.
+ */ + private def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incRecordsWritten(1) + + if (numRecordsWritten % 16384 == 0) { + updateBytesWritten() + } + } + + private def updateBytesWritten(): Unit = { + val numBytesWritten = partitionWriter.getNumBytesWritten + val bytesWrittenDiff = numBytesWritten - curNumBytesWritten + writeMetrics.incBytesWritten(bytesWrittenDiff) + curNumBytesWritten = numBytesWritten + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index dd7f68fd038d2..da8d58d05b6b9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: DiskBlockObjectWriter): Unit = { + def writeNext(writer: PairsWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -89,7 +89,7 @@ private[spark] object WritablePartitionedPairCollection { * has an associated partition. */ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: DiskBlockObjectWriter): Unit + def writeNext(writer: PairsWriter): Unit def hasNext(): Boolean diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 88125a6b93ade..3c172a027ca0f 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -19,8 +19,10 @@ import java.io.*; import java.nio.ByteBuffer; +import java.nio.file.Files; import java.util.*; +import org.mockito.stubbing.Answer; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -39,6 +41,7 @@ import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.io.CompressionCodec$; @@ -53,6 +56,7 @@ import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -65,6 +69,7 @@ public class UnsafeShuffleWriterSuite { + static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; static final int NUM_PARTITITONS = 4; TestMemoryManager memoryManager; TaskMemoryManager taskMemoryManager; @@ -85,6 +90,7 @@ public class UnsafeShuffleWriterSuite { @After public void tearDown() { + TaskContext$.MODULE$.unset(); Utils.deleteRecursively(tempDir); final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); if (leakedMemory != 0) { @@ -132,14 +138,28 @@ public void setUp() throws IOException { }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - doAnswer(invocationOnMock -> { + + Answer renameTempAnswer = invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; File tmp = 
(File) invocationOnMock.getArguments()[3]; - mergedOutputFile.delete(); - tmp.renameTo(mergedOutputFile); + if (!mergedOutputFile.delete()) { + throw new RuntimeException("Failed to delete old merged output file."); + } + if (tmp != null) { + Files.move(tmp.toPath(), mergedOutputFile.toPath()); + } else if (!mergedOutputFile.createNewFile()) { + throw new RuntimeException("Failed to create empty merged output file."); + } return null; - }).when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); + }; + + doAnswer(renameTempAnswer) + .when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); + + doAnswer(renameTempAnswer) + .when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), eq(null)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); @@ -151,6 +171,11 @@ public void setUp() throws IOException { when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); + when(blockManager.shuffleServerId()).thenReturn(BlockManagerId.apply( + "0", "localhost", 9099, Option.empty())); + + TaskContext$.MODULE$.setTaskContext(taskContext); } private UnsafeShuffleWriter createWriter( @@ -158,14 +183,13 @@ private UnsafeShuffleWriter createWriter( conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter<>( blockManager, - shuffleBlockResolver, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, - taskContext.taskMetrics().shuffleWriteMetrics() - ); + taskContext.taskMetrics().shuffleWriteMetrics(), + new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); } private void assertSpillFilesWereCleanedUp() { @@ -444,10 +468,10 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() thro } private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { - memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16); + memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); - for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) { + for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) { dataToWrite.add(new Tuple2<>(i, i)); } writer.write(dataToWrite.iterator()); @@ -519,13 +543,13 @@ public void testPeakMemoryUsed() throws Exception { final UnsafeShuffleWriter writer = new UnsafeShuffleWriter<>( blockManager, - shuffleBlockResolver, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf, - taskContext.taskMetrics().shuffleWriteMetrics()); + taskContext.taskMetrics().shuffleWriteMetrics(), + new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. 
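// ---------------------------------------------------------------------------
// Editorial sketch (not part of the patch): a minimal, self-contained model of
// the write path introduced above. SortShuffleWriter now asks the write
// support for a map output writer, ExternalSorter#writePartitionedMapOutput
// obtains one partition writer per reduce partition (even for empty
// partitions), records per-partition lengths via getNumBytesWritten, and
// finally commits all partitions. The traits below are hypothetical stand-ins
// that mirror only the calls visible in this diff (getPartitionWriter,
// openStream, getNumBytesWritten, commitAllPartitions); the real
// org.apache.spark.api.shuffle interfaces may differ in shape and semantics.
// ---------------------------------------------------------------------------
import java.io.{ByteArrayOutputStream, OutputStream}

object MapOutputWriterSketch {

  // Hypothetical stand-in for the plugin's per-partition writer.
  trait PartitionWriterSketch {
    def openStream: OutputStream
    def getNumBytesWritten: Long
  }

  // Hypothetical stand-in for the plugin's per-map-task output writer.
  trait MapOutputWriterSketchApi {
    def getPartitionWriter(partitionId: Int): PartitionWriterSketch
    def commitAllPartitions(): Unit
  }

  // A toy in-memory implementation, standing in for the default local-disk one.
  class InMemoryMapOutputWriter(numPartitions: Int) extends MapOutputWriterSketchApi {
    val partitions: Array[ByteArrayOutputStream] =
      Array.fill(numPartitions)(new ByteArrayOutputStream())

    override def getPartitionWriter(partitionId: Int): PartitionWriterSketch =
      new PartitionWriterSketch {
        override def openStream: OutputStream = partitions(partitionId)
        override def getNumBytesWritten: Long = partitions(partitionId).size().toLong
      }

    // Nothing to flush for an in-memory buffer; a real plugin would publish
    // the committed partitions (e.g. write an index file) here.
    override def commitAllPartitions(): Unit = ()
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    // (partitionId, value) records already grouped by the sorter; partition 1 is empty.
    val records = Seq(0 -> "a", 0 -> "b", 2 -> "c")
    val mapOutputWriter = new InMemoryMapOutputWriter(numPartitions)
    val lengths = new Array[Long](numPartitions)

    // Mirrors the shape of writePartitionedMapOutput: request a writer for
    // every partition, write its records, then record the bytes written.
    (0 until numPartitions).foreach { partitionId =>
      val partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
      val stream = partitionWriter.openStream
      records.filter(_._1 == partitionId).foreach { case (_, value) =>
        stream.write(value.getBytes("UTF-8"))
      }
      lengths(partitionId) = partitionWriter.getNumBytesWritten
    }
    mapOutputWriter.commitAllPartitions()

    // The lengths array is what feeds MapStatus back to the driver.
    println(lengths.mkString("lengths = [", ", ", "]"))
  }
}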
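// ---------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the copyFileStreamNIO change above
// generalizes the output from FileChannel to WritableByteChannel, so the
// transferTo position check from SPARK-3948 can only run when the destination
// really is a file. This stand-alone illustration uses hypothetical temp-file
// names and a simplified copy loop; only the FileChannel-vs-generic-channel
// branching mirrors the patch.
// ---------------------------------------------------------------------------
import java.io.{File, FileInputStream, FileOutputStream}
import java.nio.channels.{FileChannel, WritableByteChannel}

object CopyStreamSketch {

  def copy(input: FileChannel, output: WritableByteChannel, bytesToCopy: Long): Unit = {
    // Remember the starting position only when the destination supports it.
    val outputInitialState = output match {
      case fc: FileChannel => Some((fc.position(), fc))
      case _ => None
    }
    var count = 0L
    // transferTo may copy less than requested, so loop until done.
    while (count < bytesToCopy) {
      count += input.transferTo(count, bytesToCopy - count, output)
    }
    // The kernel-bug sanity check only applies when the output is a file.
    outputInitialState.foreach { case (initialPos, fc) =>
      assert(fc.position() == initialPos + bytesToCopy,
        s"transferTo copied ${fc.position() - initialPos} bytes, expected $bytesToCopy")
    }
  }

  def main(args: Array[String]): Unit = {
    val src = File.createTempFile("copy-sketch-src", ".bin")
    val dst = File.createTempFile("copy-sketch-dst", ".bin")
    val out = new FileOutputStream(src)
    out.write(Array.fill[Byte](1024)(1.toByte))
    out.close()
    val in = new FileInputStream(src).getChannel
    val target = new FileOutputStream(dst).getChannel
    copy(in, target, 1024L)
    in.close()
    target.close()
    println(s"copied ${dst.length()} bytes")
  }
}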
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 62824a5bec9d1..28cbeeda7a88d 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -210,7 +210,8 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { /** * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup. */ - private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) { + private class SaveAccumContextCleaner(sc: SparkContext) extends + ContextCleaner(sc, null) { private val accumsRegistered = new ArrayBuffer[Long] override def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index d86975964b558..8fcbc845d1a7b 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MA import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { @@ -67,10 +68,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L))) - val statuses = tracker.getMapSizesByExecutorId(10, 0) + val statuses = tracker.getMapSizesByShuffleLocation(10, 0) assert(statuses.toSet === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), - (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) + Seq( + (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (Some(DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000))), + ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() @@ -90,11 +94,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) + assert(tracker.getMapSizesByShuffleLocation(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) + assert(tracker.getMapSizesByShuffleLocation(10, 0).isEmpty) tracker.stop() rpcEnv.shutdown() @@ -121,7 +125,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. 
- intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) } + intercept[FetchFailedException] { tracker.getMapSizesByShuffleLocation(10, 1) } tracker.stop() rpcEnv.shutdown() @@ -143,24 +147,26 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) slaveTracker.updateEpoch(masterTracker.getEpoch) // This is expected to fail because no outputs have been registered for the shuffle. - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) + assert(slaveTracker.getMapSizesByShuffleLocation(10, 0).toSeq === + Seq( + (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } // failure should be cached - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.stop() @@ -261,8 +267,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // being sent. 
masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => + val bmId = BlockManagerId("999", "mps", 1000) masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + bmId, + DefaultMapShuffleLocations.get(bmId), + Array.fill[Long](4000000)(0))) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -315,12 +324,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + assert(tracker.getMapSizesByShuffleLocation(10, 0, 4) + .map(x => (x._1.get, x._2)).toSeq === Seq( - (BlockManagerId("a", "hostA", 1000), - Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), - (BlockManagerId("b", "hostB", 1000), - Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))), + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))) ) ) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 8b1084a8edc76..1d2713151f505 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} -import org.apache.spark.util.{MutablePair, Utils} +import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -73,7 +73,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // All blocks must have non-zero size (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0))) } } @@ -112,7 +112,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -137,7 +137,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -368,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with 
LocalSparkC new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem) val writer1 = manager.getWriter[Int, Int]( shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics) - val data1 = (1 to 10).map { x => x -> x} + val data1 = (1 to 10).map { x => x -> x } // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently @@ -383,13 +383,17 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // simultaneously, and everything is still OK def writeAndClose( - writer: ShuffleWriter[Int, Int])( + writer: ShuffleWriter[Int, Int], + taskContext: TaskContext)( iter: Iterator[(Int, Int)]): Option[MapStatus] = { + TaskContext.setTaskContext(taskContext) val files = writer.write(iter) - writer.stop(true) + val status = writer.stop(true) + TaskContext.unset + status } val interleaver = new InterleaveIterators( - data1, writeAndClose(writer1), data2, writeAndClose(writer2)) + data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2)) val (mapOutput1, mapOutput2) = interleaver.run() // check that we can read the map output and it has the right data @@ -405,12 +409,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val taskContext = new TaskContextImpl( 1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem) + TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) manager.unregisterShuffle(0) + TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d58ee4e651e19..51a65cc2a6ba8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -29,12 +29,14 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.api.shuffle.MapShuffleLocations import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} @@ -713,8 +715,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) 
assertDataStructuresEmpty() @@ -740,8 +742,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -779,11 +783,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(ExecutorLost("exec-hostA", event)) if (expectFileLoss) { intercept[MetadataFetchFailedException] { - mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0) + mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0) } } else { - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) } } } @@ -1076,8 +1080,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( @@ -1206,10 +1212,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 1) + .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. 
runEvent(makeCompletionEvent( @@ -1399,8 +1409,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi Success, makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1554,7 +1564,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi reduceIdx <- reduceIdxs } { // this would throw an exception if the map status hadn't been registered - val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + val statuses = mapOutputTracker.getMapSizesByShuffleLocation(stage, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) @@ -1606,7 +1616,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // check that we have all the map output for stage 0 (0 until reduceRdd.partitions.length).foreach { reduceIdx => - val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + val statuses = mapOutputTracker.getMapSizesByShuffleLocation(0, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) @@ -1805,8 +1815,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) // Make sure that the reduce stage was now submitted. 
assert(taskSets.size === 3) @@ -2068,8 +2078,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -2114,8 +2124,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostA"))) // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) @@ -2278,8 +2288,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting the second stage, show a fetch failure @@ -2294,8 +2304,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) + assert(listener2.results.size === 0) // Second stage listener should still not have a result // Stage 1 should now be running as task set 3; make its first task succeed @@ -2303,8 +2314,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(3), Seq( (Success, makeMapStatus("hostB", rdd2.partitions.length)), (Success, makeMapStatus("hostD", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep2.shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostD"))) assert(listener2.results.size === 1) // Finally, the reduce job should be running as task set 4; make it see a fetch failure, @@ -2342,8 +2353,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) - 
assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting stage1, trigger a fetch failure. @@ -2368,8 +2379,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) // After stage0 is finished, stage1 will be submitted and found there is no missing // partitions in it. Then listener got triggered. @@ -2982,6 +2993,14 @@ object DAGSchedulerSuite { def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) + + def makeShuffleLocation(host: String): MapShuffleLocations = { + DefaultMapShuffleLocations.get(makeBlockManagerId(host)) + } + + def makeMaybeShuffleLocation(host: String): Option[MapShuffleLocations] = { + Some(DefaultMapShuffleLocations.get(makeBlockManagerId(host))) + } } object FailThisAttempt { diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index c1e7fb9a1db16..3c786c0927bc6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.internal.config import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { @@ -61,7 +62,11 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val bmId = BlockManagerId("a", "b", 10) + val status = MapStatus( + bmId, + DefaultMapShuffleLocations.get(bmId), + sizes) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -75,7 +80,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, null, sizes) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,11 +91,13 @@ class MapStatusSuite extends SparkFunSuite { test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) - val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val bmId = 
BlockManagerId("a", "b", 10) + val loc = DefaultMapShuffleLocations.get(bmId) + val status = MapStatus(bmId, loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) - assert(status1.location == loc) + assert(status1.location == loc.getBlockManagerId) + assert(status1.mapShuffleLocations == loc) for (i <- 0 until 3000) { val estimate = status1.getSizeForBlock(i) if (sizes(i) > 0) { @@ -108,11 +115,13 @@ class MapStatusSuite extends SparkFunSuite { val sizes = (0L to 3000L).toArray val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length - val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val bmId = BlockManagerId("a", "b", 10) + val loc = DefaultMapShuffleLocations.get(bmId) + val status = MapStatus(bmId, loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) - assert(status1.location == loc) + assert(status1.location === bmId) + assert(status1.mapShuffleLocations === loc) for (i <- 0 until threshold) { val estimate = status1.getSizeForBlock(i) if (sizes(i) > 0) { @@ -165,7 +174,8 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val bmId = BlockManagerId("exec-0", "host-0", 100) + val status1 = MapStatus(bmId, DefaultMapShuffleLocations.get(bmId), sizes) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 577d77e890d78..ec828db9391da 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -191,7 +191,8 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa shuffleId <- shuffleIds reduceIdx <- (0 until nParts) } { - val statuses = taskScheduler.mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceIdx) + val statuses = taskScheduler.mapOutputTracker.getMapSizesByShuffleLocation( + shuffleId, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 2442670b6d3f0..63f6942dee184 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -36,8 +36,9 @@ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Kryo._ import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ThreadUtils class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { 
conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer") @@ -350,8 +351,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val ser = new KryoSerializer(conf).newInstance() val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) + val bmId = BlockManagerId("exec-1", "host", 1234) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize(HighlyCompressedMapStatus( + bmId, DefaultMapShuffleLocations.get(bmId), blockSizes)) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 6d2ef17a7a790..6468914bf3185 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -21,12 +21,19 @@ import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.internal.config +import org.apache.spark.io.CompressionCodec import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.BlockId /** * Wrapper for a managed buffer that keeps track of how many times retain and release are called. @@ -78,11 +85,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Create a buffer with some randomly generated key-value pairs to use as the shuffle data // from each mappers (all mappers return the same shuffle data). val byteOutputStream = new ByteArrayOutputStream() - val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + val compressionCodec = CompressionCodec.createCodec(testConf) + val compressedOutputStream = compressionCodec.compressedOutputStream(byteOutputStream) + val serializationStream = serializer.newInstance().serializeStream(compressedOutputStream) (0 until keyValuePairsPerMap).foreach { i => serializationStream.writeKey(i) serializationStream.writeValue(2*i) } + compressedOutputStream.close() // Setup the mocked BlockManager to return RecordingManagedBuffers. val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) @@ -101,16 +111,20 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. - val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { - // Test a scenario where all data is local, to avoid creating a bunch of additional mocks - // for the code to read data over the network. 
- val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => - val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - (shuffleBlockId, byteOutputStream.size().toLong) - } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) } + val blocksToRetrieve = Seq( + (Option.apply(DefaultMapShuffleLocations.get(localBlockManagerId)), shuffleBlockIdsAndSizes)) + val mapOutputTracker = mock(classOf[MapOutputTracker]) + when(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)) + .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { + def answer(invocationOnMock: InvocationOnMock): + Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + blocksToRetrieve.iterator + } + }) // Create a mocked shuffle handle to pass into HashShuffleReader. val shuffleHandle = { @@ -124,19 +138,23 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val serializerManager = new SerializerManager( serializer, new SparkConf() - .set(config.SHUFFLE_COMPRESS, false) + .set(config.SHUFFLE_COMPRESS, true) .set(config.SHUFFLE_SPILL_COMPRESS, false)) val taskContext = TaskContext.empty() + TaskContext.setTaskContext(taskContext) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() + + val shuffleReadSupport = + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf) val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, taskContext, metrics, + shuffleReadSupport, serializerManager, - blockManager, mapOutputTracker) assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) @@ -147,5 +165,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext assert(buffer.callsToRetain === 1) assert(buffer.callsToRelease === 1) } + TaskContext.unset() } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala new file mode 100644 index 0000000000000..e8372c0458600 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import java.util + +import com.google.common.collect.ImmutableMap + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport} +import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport + +class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { + test(s"test serialization of shuffle initialization conf to executors") { + val testConf = new SparkConf() + .setAppName("testing") + .setMaster("local-cluster[2,1,1024]") + .set(SHUFFLE_IO_PLUGIN_CLASS, "org.apache.spark.shuffle.TestShuffleDataIO") + + sc = new SparkContext(testConf) + + sc.parallelize(Seq((1, "one"), (2, "two"), (3, "three")), 3) + .groupByKey() + .collect() + } +} + +class TestShuffleDriverComponents extends ShuffleDriverComponents { + override def initializeApplication(): util.Map[String, String] = + ImmutableMap.of("test-key", "test-value") + + override def cleanupApplication(): Unit = {} + + override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = {} +} + +class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { + override def driver(): ShuffleDriverComponents = new TestShuffleDriverComponents() + + override def executor(): ShuffleExecutorComponents = + new TestShuffleExecutorComponents(sparkConf) +} + +class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents { + override def initializeExecutor(appId: String, execId: String, + extraConfigs: util.Map[String, String]): Unit = { + assert(extraConfigs.get("test-key") == "test-value") + } + + override def writes(): ShuffleWriteSupport = { + val blockManager = SparkEnv.get.blockManager + val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager) + new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId) + } + + override def reads(): ShuffleReadSupport = { + val blockManager = SparkEnv.get.blockManager + val mapOutputTracker = SparkEnv.get.mapOutputTracker + val serializerManager = SparkEnv.get.serializerManager + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala new file mode 100644 index 0000000000000..4f5bb264170de --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import java.io.{File, FileOutputStream} + +import com.google.common.io.CountingOutputStream +import org.apache.commons.io.FileUtils +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.when +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import scala.util.Random + +import org.apache.spark.{Aggregator, MapOutputTracker, ShuffleDependency, SparkConf, SparkEnv, TaskContext} +import org.apache.spark.api.shuffle.ShuffleLocation +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.metrics.source.Source +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException} +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockId} +import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils} + +/** + * Benchmark to measure the performance of reading shuffle blocks with BlockStoreShuffleReader. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt".
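+ * For example, this reader benchmark alone can be run and its results file regenerated with: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.BlockStoreShuffleReaderBenchmark" + * (the same command dev/run-spark-25299-benchmarks.sh uses for this suite).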
+ * }}} + */ +object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { + + // this is only used to retrieve the aggregator/sorters/serializers, + // so it shouldn't affect the performance significantly + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: + ShuffleDependency[String, String, String] = _ + // only used to retrieve info about the maps at the beginning, doesn't affect perf + @Mock(answer = RETURNS_SMART_NULLS) private var mapOutputTracker: MapOutputTracker = _ + // this is only used when initializing the BlockManager, so doesn't affect perf + @Mock(answer = RETURNS_SMART_NULLS) private var blockManagerMaster: BlockManagerMaster = _ + // this is only used when initiating the BlockManager, for comms between master and executor + @Mock(answer = RETURNS_SMART_NULLS) private var rpcEnv: RpcEnv = _ + @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + + private var tempDir: File = _ + + private val NUM_MAPS = 5 + private val DEFAULT_DATA_STRING_SIZE = 5 + private val TEST_DATA_SIZE = 10000000 + private val SMALLER_DATA_SIZE = 2000000 + private val MIN_NUM_ITERS = 10 + + private val executorId = "0" + private val localPort = 17000 + private val remotePort = 17002 + + private val defaultConf = new SparkConf() + .set("spark.shuffle.compress", "false") + .set("spark.shuffle.spill.compress", "false") + .set("spark.authenticate", "false") + .set("spark.app.id", "test-app") + private val serializer = new KryoSerializer(defaultConf) + private val serializerManager = new SerializerManager(serializer, defaultConf) + private val execBlockManagerId = BlockManagerId(executorId, "localhost", localPort) + private val remoteBlockManagerId = BlockManagerId(executorId, "localhost", remotePort) + private val transportConf = SparkTransportConf.fromSparkConf(defaultConf, "shuffle") + private val securityManager = new org.apache.spark.SecurityManager(defaultConf) + protected val memoryManager = new TestMemoryManager(defaultConf) + + class TestBlockManager(transferService: BlockTransferService, + blockManagerMaster: BlockManagerMaster, + dataFile: File, + fileLength: Long, + offset: Long) extends BlockManager( + executorId, + rpcEnv, + blockManagerMaster, + serializerManager, + defaultConf, + memoryManager, + null, + null, + transferService, + null, + 1) { + blockManagerId = execBlockManagerId + + override def getBlockData(blockId: BlockId): ManagedBuffer = { + new FileSegmentManagedBuffer( + transportConf, + dataFile, + offset, + fileLength + ) + } + } + + private var blockManager : BlockManager = _ + private var externalBlockManager: BlockManager = _ + + def getTestBlockManager( + port: Int, + dataFile: File, + dataFileLength: Long, + offset: Long): TestBlockManager = { + val shuffleClient = new NettyBlockTransferService( + defaultConf, + securityManager, + "localhost", + "localhost", + port, + 1 + ) + new TestBlockManager(shuffleClient, + blockManagerMaster, + dataFile, + dataFileLength, + offset) + } + + def initializeServers(dataFile: File, dataFileLength: Long, readOffset: Long = 0): Unit = { + MockitoAnnotations.initMocks(this) + when(blockManagerMaster.registerBlockManager( + any[BlockManagerId], any[Long], any[Long], any[RpcEndpointRef])).thenReturn(null) + when(rpcEnv.setupEndpoint(any[String], any[RpcEndpoint])).thenReturn(rpcEndpointRef) + blockManager = getTestBlockManager(localPort, dataFile, dataFileLength, readOffset) + blockManager.initialize(defaultConf.getAppId) + externalBlockManager = getTestBlockManager(remotePort, dataFile, 
dataFileLength, readOffset) + externalBlockManager.initialize(defaultConf.getAppId) + } + + def stopServers(): Unit = { + blockManager.stop() + externalBlockManager.stop() + } + + def setupReader( + dataFile: File, + dataFileLength: Long, + fetchLocal: Boolean, + aggregator: Option[Aggregator[String, String, String]] = None, + sorter: Option[Ordering[String]] = None): BlockStoreShuffleReader[String, String] = { + SparkEnv.set(new SparkEnv( + "0", + null, + serializer, + null, + serializerManager, + mapOutputTracker, + null, + null, + blockManager, + null, + null, + null, + null, + defaultConf + )) + + val shuffleHandle = new BaseShuffleHandle( + shuffleId = 0, + numMaps = NUM_MAPS, + dependency = dependency) + + val taskContext = new TestTaskContext + TaskContext.setTaskContext(taskContext) + + var dataBlockId = execBlockManagerId + if (!fetchLocal) { + dataBlockId = remoteBlockManagerId + } + + when(mapOutputTracker.getMapSizesByShuffleLocation(0, 0, 1)) + .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { + def answer(invocationOnMock: InvocationOnMock): + Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => + val shuffleBlockId = ShuffleBlockId(0, mapId, 0) + (shuffleBlockId, dataFileLength) + } + Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) + .toIterator + } + }) + + when(dependency.serializer).thenReturn(serializer) + when(dependency.aggregator).thenReturn(aggregator) + when(dependency.keyOrdering).thenReturn(sorter) + + val readSupport = new DefaultShuffleReadSupport( + blockManager, + mapOutputTracker, + serializerManager, + defaultConf) + + new BlockStoreShuffleReader[String, String]( + shuffleHandle, + 0, + 1, + taskContext, + taskContext.taskMetrics().createTempShuffleReadMetrics(), + readSupport, + serializerManager, + mapOutputTracker + ) + } + + def generateDataOnDisk(size: Int, file: File, recordOffset: Int): (Long, Long) = { + // scalastyle:off println + println("Generating test data with num records: " + size) + + val dataOutput = new ManualCloseFileOutputStream(file) + val random = new Random(123) + val serializerInstance = serializer.newInstance() + + var countingOutput = new CountingOutputStream(dataOutput) + var serializedOutput = serializerInstance.serializeStream(countingOutput) + var readOffset = 0L + try { + (1 to size).foreach { i => { + if (i % 1000000 == 0) { + println("Wrote " + i + " test data points") + } + if (i == recordOffset) { + serializedOutput.close() + readOffset = countingOutput.getCount + countingOutput = new CountingOutputStream(dataOutput) + serializedOutput = serializerInstance.serializeStream(countingOutput) + } + val x = random.alphanumeric.take(DEFAULT_DATA_STRING_SIZE).mkString + serializedOutput.writeKey(x) + serializedOutput.writeValue(x) + }} + } finally { + serializedOutput.close() + dataOutput.manualClose() + } + (countingOutput.getCount, readOffset) + // scalastyle:off println + } + + class TestDataFile(file: File, length: Long, offset: Long) { + def getFile(): File = file + def getLength(): Long = length + def getOffset(): Long = offset + } + + def runWithTestDataFile(size: Int, readOffset: Int = 0)(func: TestDataFile => Unit): Unit = { + val tempDataFile = File.createTempFile("test-data", "", tempDir) + val dataFileLengthAndOffset = generateDataOnDisk(size, tempDataFile, readOffset) + initializeServers(tempDataFile, dataFileLengthAndOffset._1, dataFileLengthAndOffset._2) + func(new 
TestDataFile(tempDataFile, dataFileLengthAndOffset._1, dataFileLengthAndOffset._2)) + tempDataFile.delete() + stopServers() + } + + def addBenchmarkCase( + benchmark: Benchmark, + name: String, + shuffleReaderSupplier: => BlockStoreShuffleReader[String, String], + assertSize: Option[Int] = None): Unit = { + benchmark.addTimerCase(name) { timer => + val reader = shuffleReaderSupplier + timer.startTiming() + val numRead = reader.read().length + timer.stopTiming() + assertSize.foreach(size => assert(numRead == size)) + } + } + + def runLargeDatasetTests(): Unit = { + runWithTestDataFile(TEST_DATA_SIZE) { testDataFile => + val baseBenchmark = + new Benchmark("no aggregation or sorting", + TEST_DATA_SIZE, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase( + baseBenchmark, + "local fetch", + setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = true), + assertSize = Option.apply(TEST_DATA_SIZE * NUM_MAPS)) + addBenchmarkCase( + baseBenchmark, + "remote rpc fetch", + setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = false), + assertSize = Option.apply(TEST_DATA_SIZE * NUM_MAPS)) + baseBenchmark.run() + } + } + + def runSmallDatasetTests(): Unit = { + runWithTestDataFile(SMALLER_DATA_SIZE) { testDataFile => + def createCombiner(i: String): String = i + def mergeValue(i: String, j: String): String = if (Ordering.String.compare(i, j) > 0) i else j + def mergeCombiners(i: String, j: String): String = + if (Ordering.String.compare(i, j) > 0) i else j + val aggregator = + new Aggregator[String, String, String](createCombiner, mergeValue, mergeCombiners) + val aggregationBenchmark = + new Benchmark("with aggregation", + SMALLER_DATA_SIZE, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase( + aggregationBenchmark, + "local fetch", + setupReader( + testDataFile.getFile(), + testDataFile.getLength(), + fetchLocal = true, + aggregator = Some(aggregator))) + addBenchmarkCase( + aggregationBenchmark, + "remote rpc fetch", + setupReader( + testDataFile.getFile(), + testDataFile.getLength(), + fetchLocal = false, + aggregator = Some(aggregator))) + aggregationBenchmark.run() + + + val sortingBenchmark = + new Benchmark("with sorting", + SMALLER_DATA_SIZE, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase( + sortingBenchmark, + "local fetch", + setupReader( + testDataFile.getFile(), + testDataFile.getLength(), + fetchLocal = true, + sorter = Some(Ordering.String)), + assertSize = Option.apply(SMALLER_DATA_SIZE * NUM_MAPS)) + addBenchmarkCase( + sortingBenchmark, + "remote rpc fetch", + setupReader( + testDataFile.getFile(), + testDataFile.getLength(), + fetchLocal = false, + sorter = Some(Ordering.String)), + assertSize = Option.apply(SMALLER_DATA_SIZE * NUM_MAPS)) + sortingBenchmark.run() + } + } + + def runSeekTests(): Unit = { + runWithTestDataFile(SMALLER_DATA_SIZE, readOffset = SMALLER_DATA_SIZE) { testDataFile => + val seekBenchmark = + new Benchmark("with seek", + SMALLER_DATA_SIZE, + minNumIters = MIN_NUM_ITERS, + output = output) + + addBenchmarkCase( + seekBenchmark, + "seek to last record", + setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = false), + Option.apply(NUM_MAPS)) + seekBenchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + tempDir = Utils.createTempDir(null, "shuffle") + + runBenchmark("BlockStoreShuffleReader reader") { + 
runLargeDatasetTests() + runSmallDatasetTests() + runSeekTests() + } + + FileUtils.deleteDirectory(tempDir) + } + + // We cannot mock the TaskContext because it taskMetrics() gets called at every next() + // call on the reader, and Mockito will try to log all calls to taskMetrics(), thus OOM-ing + // the test + class TestTaskContext extends TaskContext { + private val metrics: TaskMetrics = new TaskMetrics + private val testMemManager = new TestMemoryManager(defaultConf) + private val taskMemManager = new TaskMemoryManager(testMemManager, 0) + testMemManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES) + override def isCompleted(): Boolean = false + override def isInterrupted(): Boolean = false + override def addTaskCompletionListener(listener: TaskCompletionListener): + TaskContext = { null } + override def addTaskFailureListener(listener: TaskFailureListener): TaskContext = { null } + override def stageId(): Int = 0 + override def stageAttemptNumber(): Int = 0 + override def partitionId(): Int = 0 + override def attemptNumber(): Int = 0 + override def taskAttemptId(): Long = 0 + override def getLocalProperty(key: String): String = "" + override def taskMetrics(): TaskMetrics = metrics + override def getMetricsSources(sourceName: String): Seq[Source] = Seq.empty + override private[spark] def killTaskIfInterrupted(): Unit = {} + override private[spark] def getKillReason() = None + override private[spark] def taskMemoryManager() = taskMemManager + override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {} + override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = {} + override private[spark] def markInterrupted(reason: String): Unit = {} + override private[spark] def markTaskFailed(error: Throwable): Unit = {} + override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = {} + override private[spark] def fetchFailed = None + override private[spark] def getLocalProperties = { null } + } + + class ManualCloseFileOutputStream(file: File) extends FileOutputStream(file, true) { + override def close(): Unit = { + flush() + } + + def manualClose(): Unit = { + flush() + super.close() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala new file mode 100644 index 0000000000000..dbd73f2688dfc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId + +/** + * Benchmark to measure performance for aggregate primitives. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} + */ +object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { + + private val shuffleHandle: BypassMergeSortShuffleHandle[String, String] = + new BypassMergeSortShuffleHandle[String, String]( + shuffleId = 0, + numMaps = 1, + dependency) + + private val MIN_NUM_ITERS = 10 + private val DATA_SIZE_SMALL = 1000 + private val DATA_SIZE_LARGE = + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/4/DEFAULT_DATA_STRING_SIZE + + def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = { + val conf = new SparkConf(loadDefaults = false) + val shuffleWriteSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 7090)) + conf.set("spark.file.transferTo", String.valueOf(transferTo)) + conf.set("spark.shuffle.file.buffer", "32k") + + val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( + blockManager, + shuffleHandle, + 0, + conf, + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleWriteSupport + ) + + shuffleWriter + } + + def writeBenchmarkWithLargeDataset(): Unit = { + val size = DATA_SIZE_LARGE + val benchmark = new Benchmark( + "BypassMergeSortShuffleWrite with spill", + size, + minNumIters = MIN_NUM_ITERS, + output = output) + + addBenchmarkCase(benchmark, "without transferTo", size, () => getWriter(false)) + addBenchmarkCase(benchmark, "with transferTo", size, () => getWriter(true)) + benchmark.run() + } + + def writeBenchmarkWithSmallDataset(): Unit = { + val size = DATA_SIZE_SMALL + val benchmark = new Benchmark("BypassMergeSortShuffleWrite without spill", + size, + minNumIters = MIN_NUM_ITERS, + output = output) + addBenchmarkCase(benchmark, "small dataset without disk spill", size, () => getWriter(false)) + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("BypassMergeSortShuffleWriter write") { + writeBenchmarkWithSmallDataset() + writeBenchmarkWithLargeDataset() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index fc1422dfaac75..b7ba9291fc23d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle.sort import java.io.File -import java.util.UUID +import java.util.{Properties, UUID} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -29,11 +29,15 @@ import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.scalatest.BeforeAndAfterEach +import scala.util.Random import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.executor.{ShuffleWriteMetrics, 
TaskMetrics} +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -48,7 +52,9 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ + private var writeSupport: ShuffleWriteSupport = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) + .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ @@ -104,12 +110,27 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte temporaryFilesCreated += file (blockId, file) }) - when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) => - blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) - } + val memoryManager = new TestMemoryManager(conf) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + stageAttemptNumber = 0, + partitionId = 0, + taskAttemptId = Random.nextInt(10000), + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + localProperties = new Properties, + metricsSystem = null, + taskMetrics = taskMetrics)) + + writeSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 7090)) } override def afterEach(): Unit = { + TaskContext.unset() try { Utils.deleteRecursively(tempDir) blockIdToFileMap.clear() @@ -122,11 +143,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) writer.write(Iterator.empty) writer.stop( /* success = */ true) @@ -142,15 +163,40 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte } test("write with some empty partitions") { + val transferConf = conf.clone.set("spark.file.transferTo", "false") + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + shuffleHandle, + 0, // MapId + transferConf, + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport + ) + writer.write(records) + writer.stop( /* success = */ true) + assert(temporaryFilesCreated.nonEmpty) + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + 
assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + test("write with some empty partitions with transferTo") { def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) writer.write(records) writer.stop( /* success = */ true) @@ -181,11 +227,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) intercept[SparkException] { @@ -203,11 +249,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { @@ -221,5 +267,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } - } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala new file mode 100644 index 0000000000000..26b92e5203b50 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort + +import java.io.File +import java.util.UUID + +import org.apache.commons.io.FileUtils +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.when +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.serializer.{KryoSerializer, Serializer, SerializerManager} +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter} +import org.apache.spark.storage.{BlockManager, DiskBlockManager, TempShuffleBlockId} +import org.apache.spark.util.Utils + +abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase { + + protected val DEFAULT_DATA_STRING_SIZE = 5 + + // This is only used in the writer constructors, so it's ok to mock + @Mock(answer = RETURNS_SMART_NULLS) protected var dependency: + ShuffleDependency[String, String, String] = _ + // This is only used in the stop() function, so we can safely mock this without affecting perf + @Mock(answer = RETURNS_SMART_NULLS) protected var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEnv: RpcEnv = _ + @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _ + + protected val defaultConf: SparkConf = new SparkConf(loadDefaults = false) + protected val serializer: Serializer = new KryoSerializer(defaultConf) + protected val partitioner: HashPartitioner = new HashPartitioner(10) + protected val serializerManager: SerializerManager = + new SerializerManager(serializer, defaultConf) + protected val shuffleMetrics: TaskMetrics = new TaskMetrics + + protected val tempFilesCreated: ArrayBuffer[File] = new ArrayBuffer[File] + protected val filenameToFile: mutable.Map[String, File] = new mutable.HashMap[String, File] + + class TestDiskBlockManager(tempDir: File) extends DiskBlockManager(defaultConf, false) { + override def getFile(filename: String): File = { + if (filenameToFile.contains(filename)) { + filenameToFile(filename) + } else { + val outputFile = File.createTempFile("shuffle", null, tempDir) + filenameToFile(filename) = outputFile + outputFile + } + } + + override def createTempShuffleBlock(): (TempShuffleBlockId, File) = { + var blockId = new TempShuffleBlockId(UUID.randomUUID()) + val file = getFile(blockId) + tempFilesCreated += file + (blockId, file) + } + } + + class TestBlockManager(tempDir: File, memoryManager: MemoryManager) extends BlockManager("0", + rpcEnv, + null, + serializerManager, + defaultConf, + memoryManager, + null, + null, + null, + null, + 1) { + override val diskBlockManager = new TestDiskBlockManager(tempDir) + override val remoteBlockTempFileManager = null + } + + protected var tempDir: File = _ + + protected var blockManager: BlockManager = _ + protected var blockResolver: IndexShuffleBlockResolver = _ + + protected var memoryManager: TestMemoryManager = _ + protected var taskMemoryManager: TaskMemoryManager = _ + + MockitoAnnotations.initMocks(this) + when(dependency.partitioner).thenReturn(partitioner) + when(dependency.serializer).thenReturn(serializer) + 
when(dependency.shuffleId).thenReturn(0) + when(taskContext.taskMetrics()).thenReturn(shuffleMetrics) + when(rpcEnv.setupEndpoint(any[String], any[RpcEndpoint])).thenReturn(rpcEndpointRef) + + def setup(): Unit = { + TaskContext.setTaskContext(taskContext) + memoryManager = new TestMemoryManager(defaultConf) + memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES) + taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + tempDir = Utils.createTempDir() + blockManager = new TestBlockManager(tempDir, memoryManager) + blockResolver = new IndexShuffleBlockResolver( + defaultConf, + blockManager) + } + + def addBenchmarkCase( + benchmark: Benchmark, + name: String, + size: Int, + writerSupplier: () => ShuffleWriter[String, String], + numSpillFiles: Option[Int] = Option.empty): Unit = { + benchmark.addTimerCase(name) { timer => + setup() + val writer = writerSupplier() + val dataIterator = createDataIterator(size) + try { + timer.startTiming() + writer.write(dataIterator) + timer.stopTiming() + if (numSpillFiles.isDefined) { + assert(tempFilesCreated.length == numSpillFiles.get) + } + } finally { + writer.stop(true) + } + teardown() + } + } + + def teardown(): Unit = { + FileUtils.deleteDirectory(tempDir) + tempFilesCreated.clear() + filenameToFile.clear() + } + + protected class DataIterator (size: Int) + extends Iterator[Product2[String, String]] { + val random = new Random(123) + var count = 0 + override def hasNext: Boolean = { + count < size + } + + override def next(): Product2[String, String] = { + count+=1 + val string = random.alphanumeric.take(DEFAULT_DATA_STRING_SIZE).mkString + (string, string) + } + } + + + def createDataIterator(size: Int): DataIterator = { + new DataIterator(size) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala new file mode 100644 index 0000000000000..b0ff15cb1f790 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import org.mockito.Mockito.when + +import org.apache.spark.{Aggregator, SparkEnv, TaskContext} +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId + +/** + * Benchmark to measure performance for aggregate primitives. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. 
generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} + */ +object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { + + private val shuffleHandle: BaseShuffleHandle[String, String, String] = + new BaseShuffleHandle( + shuffleId = 0, + numMaps = 1, + dependency = dependency) + + private val MIN_NUM_ITERS = 10 + private val DATA_SIZE_SMALL = 1000 + private val DATA_SIZE_LARGE = + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/4/DEFAULT_DATA_STRING_SIZE + + def getWriter(aggregator: Option[Aggregator[String, String, String]], + sorter: Option[Ordering[String]]): SortShuffleWriter[String, String, String] = { + // we need this since SortShuffleWriter uses SparkEnv to get lots of its private vars + SparkEnv.set(new SparkEnv( + "0", + null, + serializer, + null, + serializerManager, + null, + null, + null, + blockManager, + null, + null, + null, + null, + defaultConf + )) + + if (aggregator.isEmpty && sorter.isEmpty) { + when(dependency.mapSideCombine).thenReturn(false) + } else { + when(dependency.mapSideCombine).thenReturn(false) + when(dependency.aggregator).thenReturn(aggregator) + when(dependency.keyOrdering).thenReturn(sorter) + } + + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + TaskContext.setTaskContext(taskContext) + val writeSupport = new DefaultShuffleWriteSupport( + defaultConf, + blockResolver, + BlockManagerId("0", "localhost", 9099)) + + val shuffleWriter = new SortShuffleWriter[String, String, String]( + blockResolver, + shuffleHandle, + 0, + taskContext, + writeSupport) + shuffleWriter + } + + def writeBenchmarkWithSmallDataset(): Unit = { + val size = DATA_SIZE_SMALL + val benchmark = new Benchmark("SortShuffleWriter without spills", + size, + minNumIters = MIN_NUM_ITERS, + output = output) + addBenchmarkCase(benchmark, + "small dataset without spills", + size, + () => getWriter(Option.empty, Option.empty), + Some(0)) + benchmark.run() + } + + def writeBenchmarkWithSpill(): Unit = { + val size = DATA_SIZE_LARGE + val benchmark = new Benchmark("SortShuffleWriter with spills", + size, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase(benchmark, + "no map side combine", + size, + () => getWriter(Option.empty, Option.empty), + Some(7)) + + def createCombiner(i: String): String = i + def mergeValue(i: String, j: String): String = if (Ordering.String.compare(i, j) > 0) i else j + def mergeCombiners(i: String, j: String): String = + if (Ordering.String.compare(i, j) > 0) i else j + val aggregator = + new Aggregator[String, String, String](createCombiner, mergeValue, mergeCombiners) + addBenchmarkCase(benchmark, + "with map side aggregation", + size, + () => getWriter(Some(aggregator), Option.empty), + Some(7)) + + val sorter = Ordering.String + addBenchmarkCase(benchmark, + "with map side sort", + size, + () => getWriter(Option.empty, Some(sorter)), + Some(7)) + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("SortShuffleWriter writer") { + writeBenchmarkWithSmallDataset() + writeBenchmarkWithSpill() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala new file mode 100644 index 0000000000000..7066ba8fb44df --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ 
-0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.sort + +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId + +/** + * Benchmark to measure performance for aggregate primitives. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class + * 2. build/sbt "sql/test:runMain " + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/-results.txt". + * }}} + */ +object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { + + private val shuffleHandle: SerializedShuffleHandle[String, String] = + new SerializedShuffleHandle[String, String](0, 0, this.dependency) + + private val MIN_NUM_ITERS = 10 + private val DATA_SIZE_SMALL = 1000 + private val DATA_SIZE_LARGE = + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/2/DEFAULT_DATA_STRING_SIZE + + def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { + val conf = new SparkConf(loadDefaults = false) + conf.set("spark.file.transferTo", String.valueOf(transferTo)) + val shuffleWriteSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 9099)) + + TaskContext.setTaskContext(taskContext) + new UnsafeShuffleWriter[String, String]( + blockManager, + taskMemoryManager, + shuffleHandle, + 0, + taskContext, + conf, + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleWriteSupport) + } + + def writeBenchmarkWithSmallDataset(): Unit = { + val size = DATA_SIZE_SMALL + val benchmark = new Benchmark("UnsafeShuffleWriter without spills", + size, + minNumIters = MIN_NUM_ITERS, + output = output) + addBenchmarkCase(benchmark, + "small dataset without spills", + size, + () => getWriter(false), + Some(1)) // The single temp file is for the temp index file + benchmark.run() + } + + def writeBenchmarkWithSpill(): Unit = { + val size = DATA_SIZE_LARGE + val benchmark = new Benchmark("UnsafeShuffleWriter with spills", + size, + minNumIters = MIN_NUM_ITERS, + output = output, + outputPerIteration = true) + addBenchmarkCase(benchmark, "without transferTo", size, () => getWriter(false), Some(7)) + addBenchmarkCase(benchmark, "with transferTo", size, () => getWriter(true), Some(7)) + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("UnsafeShuffleWriter write") { + writeBenchmarkWithSmallDataset() + writeBenchmarkWithSpill() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala 
b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala new file mode 100644 index 0000000000000..1f4ef0f203994 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io + +import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} +import java.math.BigInteger +import java.nio.ByteBuffer +import java.nio.channels.{Channels, WritableByteChannel} + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} +import org.mockito.Mock +import org.mockito.Mockito.{doAnswer, doNothing, when} +import org.mockito.MockitoAnnotations +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.api.shuffle.SupportsTransferTo +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils + +class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleWriteMetrics: ShuffleWriteMetrics = _ + + private val NUM_PARTITIONS = 4 + private val D_LEN = 10 + private val data: Array[Array[Int]] = (0 until NUM_PARTITIONS).map { + p => (1 to D_LEN).map(_ + p).toArray }.toArray + + private var tempFile: File = _ + private var mergedOutputFile: File = _ + private var tempDir: File = _ + private var partitionSizesInMergedFile: Array[Long] = _ + private var conf: SparkConf = _ + private var mapOutputWriter: DefaultShuffleMapOutputWriter = _ + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def beforeEach(): Unit = { + MockitoAnnotations.initMocks(this) + tempDir = Utils.createTempDir(null, "test") + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) + tempFile = File.createTempFile("tempfile", "", tempDir) + partitionSizesInMergedFile = null + conf = new SparkConf() + .set("spark.app.id", "example.spark.app") + .set("spark.shuffle.unsafe.file.output.buffer", "16k") + when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + + doNothing().when(shuffleWriteMetrics).incWriteTime(anyLong) + + doAnswer(new Answer[Void] { + def 
answer(invocationOnMock: InvocationOnMock): Void = { + partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + mergedOutputFile.delete + tmp.renameTo(mergedOutputFile) + } + null + } + }).when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) + mapOutputWriter = new DefaultShuffleMapOutputWriter( + 0, + 0, + NUM_PARTITIONS, + BlockManagerId("0", "localhost", 9099), + shuffleWriteMetrics, + blockResolver, + conf) + } + + private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = { + var startOffset = 0L + val result = new Array[Array[Int]](NUM_PARTITIONS) + (0 until NUM_PARTITIONS).foreach { p => + val partitionSize = partitionSizesInMergedFile(p).toInt + lazy val inner = new Array[Int](partitionSize) + lazy val innerBytebuffer = ByteBuffer.allocate(partitionSize) + if (partitionSize > 0) { + val in = new FileInputStream(mergedOutputFile) + in.getChannel.position(startOffset) + val lin = new LimitedInputStream(in, partitionSize) + var nonEmpty = true + var count = 0 + while (nonEmpty) { + try { + val readBit = lin.read() + if (fromByte) { + innerBytebuffer.put(readBit.toByte) + } else { + inner(count) = readBit + } + count += 1 + } catch { + case _: Exception => + nonEmpty = false + } + } + in.close() + } + if (fromByte) { + result(p) = innerBytebuffer.array().sliding(4, 4).map { b => + new BigInteger(b).intValue() + }.toArray + } else { + result(p) = inner + } + startOffset += partitionSize + } + result + } + + test("writing to an outputstream") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getPartitionWriter(p) + val stream = writer.openStream() + data(p).foreach { i => stream.write(i)} + stream.close() + intercept[IllegalStateException] { + stream.write(p) + } + assert(writer.getNumBytesWritten() == D_LEN) + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(false)) + } + + test("writing to a channel") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getPartitionWriter(p) + val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + val numBytes = byteBuffer.remaining() + val outputTempFile = File.createTempFile("channelTemp", "", tempDir) + val outputTempFileStream = new FileOutputStream(outputTempFile) + Utils.copyStream( + new ByteBufferInputStream(byteBuffer), + outputTempFileStream, + closeStreams = true) + val tempFileInput = new FileInputStream(outputTempFile) + channel.transferFrom(tempFileInput.getChannel, 0L, numBytes) + // Bytes require * 4 + channel.close() + tempFileInput.close() + assert(writer.getNumBytesWritten == D_LEN * 4) + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } + + test("copyStreams with an outputstream") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getPartitionWriter(p) 
+ val stream = writer.openStream() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + val in = new ByteArrayInputStream(byteBuffer.array()) + Utils.copyStream(in, stream, false, false) + in.close() + stream.close() + assert(writer.getNumBytesWritten == D_LEN * 4) + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } + + test("copyStreamsWithNIO with a channel") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getPartitionWriter(p) + val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + val out = new FileOutputStream(tempFile) + out.write(byteBuffer.array()) + out.close() + val in = new FileInputStream(tempFile) + channel.transferFrom(in.getChannel, 0L, byteBuffer.remaining()) + channel.close() + assert(writer.getNumBytesWritten == D_LEN * 4) + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 3ab2f0bf4596f..63cfc0b70e7b7 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -22,13 +22,12 @@ import java.nio.ByteBuffer import java.util.UUID import java.util.concurrent.Semaphore -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future - import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock import org.scalatest.PrivateMethodTester +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ @@ -124,7 +123,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, inputStream) = iterator.next() + val inputStream = iterator.next() + val blockId = iterator.getCurrentBlock() // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) @@ -198,11 +198,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() - iterator.next()._2.close() // close() first block's input stream + iterator.next().close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next()._2 + val subIter = iterator.next() // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -416,7 +416,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - val (id1, _) = iterator.next() + iterator.next() + val id1 = iterator.getCurrentBlock() assert(id1 === ShuffleBlockId(0, 0, 0)) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) @@ -434,6 +435,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } sem.acquire() + intercept[FetchFailedException] { iterator.next() } } @@ -545,16 +547,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, true, taskContext.taskMetrics.createTempShuffleReadMetrics()) - val (id, st) = iterator.next() - // Check that the test setup is correct -- make sure we have a concatenated stream. - assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) - - val dst = new DataInputStream(st) - for (i <- 1 to 2500) { - assert(i === dst.readInt()) - } - assert(dst.read() === -1) - dst.close() + // Blocks should be returned without exceptions. + iterator.next() + val blockId1 = iterator.getCurrentBlock() + iterator.next() + val blockId2 = iterator.getCurrentBlock() + assert(Set(blockId1, blockId2) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) } test("retry corrupt blocks (disabled)") { @@ -611,11 +609,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT sem.acquire() // The first block should be returned without an exception - val (id1, _) = iterator.next() + iterator.next() + val id1 = iterator.getCurrentBlock() assert(id1 === ShuffleBlockId(0, 0, 0)) - val (id2, _) = iterator.next() + iterator.next() + val id2 = iterator.getCurrentBlock() assert(id2 === ShuffleBlockId(0, 1, 0)) - val (id3, _) = iterator.next() + iterator.next() + val id3 = iterator.getCurrentBlock() assert(id3 === ShuffleBlockId(0, 2, 0)) } diff --git a/dev/run-spark-25299-benchmarks.sh b/dev/run-spark-25299-benchmarks.sh new file mode 100755 index 0000000000000..2d60f9d5a06ec --- /dev/null +++ b/dev/run-spark-25299-benchmarks.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Script to run the SPARK-25299 shuffle reader/writer benchmarks. +# Each benchmark's results file is copied to /tmp/artifacts; with -u, the collected +# results are also posted as a review comment on the corresponding pull request. + +set -oue pipefail + + +function usage { + echo "Usage: $(basename $0) [-h] [-u]" + echo "" + echo "Runs the perf tests and optionally uploads the results as a comment to a PR" + echo "" + echo " -h help" + echo " -u Upload the perf results as a comment" + # Exit as error for nesting scripts + exit 1 +} + +UPLOAD=false +while getopts "hu" opt; do + case $opt in + h) + usage + exit 0;; + u) + UPLOAD=true;; + esac +done + +echo "Running SPARK-25299 benchmarks" + +SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.BlockStoreShuffleReaderBenchmark" +SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriterBenchmark" +SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.SortShuffleWriterBenchmark" +SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.UnsafeShuffleWriterBenchmark" + +SPARK_DIR=`pwd` + +mkdir -p /tmp/artifacts +cp $SPARK_DIR/sql/core/benchmarks/BlockStoreShuffleReaderBenchmark-results.txt /tmp/artifacts/ +cp $SPARK_DIR/sql/core/benchmarks/BypassMergeSortShuffleWriterBenchmark-results.txt /tmp/artifacts/ +cp $SPARK_DIR/sql/core/benchmarks/SortShuffleWriterBenchmark-results.txt /tmp/artifacts/ +cp $SPARK_DIR/sql/core/benchmarks/UnsafeShuffleWriterBenchmark-results.txt /tmp/artifacts/ + +if [ "$UPLOAD" = false ]; then + exit 0 +fi + +IFS= +RESULTS="" +for benchmark_file in /tmp/artifacts/*.txt; do + RESULTS+=$(cat $benchmark_file) + RESULTS+=$'\n\n' +done + +echo $RESULTS +# Get last git message, filter out empty lines, get the last number of the first line.
This is the PR number +PULL_REQUEST_NUM=$(git log -1 --pretty=%B | awk NF | awk '{print $NF}' | head -1 | sed 's/(//g' | sed 's/)//g' | sed 's/#//g') + + +USERNAME=svc-spark-25299 +PASSWORD=$SVC_SPARK_25299_PASSWORD +message='{"body": "```' +message+=$'\n' +message+=$RESULTS +message+=$'\n' +json_message=$(echo $message | awk '{printf "%s\\n", $0}') +json_message+='```", "event":"COMMENT"}' +echo "$json_message" > benchmark_results.json + +echo "Sending benchmark requests to PR $PULL_REQUEST_NUM" +curl -XPOST https://${USERNAME}:${PASSWORD}@api.github.com/repos/palantir/spark/pulls/${PULL_REQUEST_NUM}/reviews -d @benchmark_results.json +rm benchmark_results.json diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 079ff25fcb67e..22cfbf506c645 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -156,10 +156,11 @@ class ShuffledRowRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] - val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, // as well as the `tempMetrics` for basic shuffle metrics. - val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) + context.taskMetrics().decorateTempShuffleReadMetrics( + tempMetrics => new SQLShuffleReadMetricsReporter(tempMetrics, metrics)) + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. val reader = @@ -168,7 +169,7 @@ class ShuffledRowRDD( shuffledRowPartition.startPreShufflePartitionIndex, shuffledRowPartition.endPreShufflePartitionIndex, context, - sqlMetricsReporter) + tempMetrics) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) }
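As a usage sketch of the driver-to-executor handshake exercised above, the following is a minimal custom ShuffleDriverComponents modeled on TestShuffleDriverComponents from ShuffleDriverComponentsSuite; the config key and endpoint value are illustrative placeholders and are not part of the patch.

package org.apache.spark.shuffle

import java.util

import com.google.common.collect.ImmutableMap

import org.apache.spark.api.shuffle.ShuffleDriverComponents

// Sketch only: publishes a hypothetical backend endpoint to every executor. The map returned by
// initializeApplication() is what ShuffleExecutorComponents.initializeExecutor() receives as
// extraConfigs, as exercised by ShuffleDriverComponentsSuite.
class ExampleShuffleDriverComponents extends ShuffleDriverComponents {
  override def initializeApplication(): util.Map[String, String] =
    ImmutableMap.of("example.shuffle.backend.endpoint", "backend.example.com:7337")

  override def cleanupApplication(): Unit = {}

  override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = {}
}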