diff --git a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
new file mode 100644
index 0000000000000..b0aed4d08d387
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.api.shuffle;
+
+import org.apache.spark.annotation.Experimental;
+
+import java.io.Serializable;
+
+/**
+ * Represents metadata about where shuffle blocks were written in a single map task.
+ *
+ * This is optionally returned by shuffle writers. The inner shuffle locations may
+ * be accessed by shuffle readers. Shuffle locations are only necessary when the
+ * location of shuffle blocks needs to be managed by the driver; shuffle plugins
+ * may choose to use an external database or other metadata management systems to
+ * track the locations of shuffle blocks instead.
+ */
+@Experimental
+public interface MapShuffleLocations extends Serializable {
+
+ /**
+ * Get the location for a given shuffle block written by this map task.
+ */
+ ShuffleLocation getLocationForBlock(int reduceId);
+}
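
For illustration only (not part of this patch): a minimal sketch of a plugin-provided MapShuffleLocations, assuming a hypothetical plugin that keeps every block written by a map task under a single remote storage prefix. All class and package names below are invented for the example.

    package com.example.shuffle; // hypothetical plugin package, not in this patch

    import java.io.Serializable;

    import org.apache.spark.api.shuffle.MapShuffleLocations;
    import org.apache.spark.api.shuffle.ShuffleLocation;

    public final class RemoteStorageMapShuffleLocations implements MapShuffleLocations {

      // ShuffleLocation is a marker interface; readers cast it down to this concrete type.
      public static final class RemoteStorageLocation implements ShuffleLocation, Serializable {
        private final String pathPrefix;

        public RemoteStorageLocation(String pathPrefix) {
          this.pathPrefix = pathPrefix;
        }

        public String pathPrefix() {
          return pathPrefix;
        }
      }

      private final RemoteStorageLocation location;

      public RemoteStorageMapShuffleLocations(String pathPrefix) {
        this.location = new RemoteStorageLocation(pathPrefix);
      }

      @Override
      public ShuffleLocation getLocationForBlock(int reduceId) {
        // In this sketch every reduce partition of the map output shares one location.
        return location;
      }
    }
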
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
new file mode 100644
index 0000000000000..a312831cb6282
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import org.apache.spark.api.java.Optional;
+
+import java.util.Objects;
+
+/**
+ * :: Experimental ::
+ * An object that identifies a shuffle block and carries the length metadata associated with it.
+ * @since 3.0.0
+ */
+public class ShuffleBlockInfo {
+ private final int shuffleId;
+ private final int mapId;
+ private final int reduceId;
+ private final long length;
+ private final Optional<ShuffleLocation> shuffleLocation;
+
+ public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length,
+ Optional<ShuffleLocation> shuffleLocation) {
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.reduceId = reduceId;
+ this.length = length;
+ this.shuffleLocation = shuffleLocation;
+ }
+
+ public int getShuffleId() {
+ return shuffleId;
+ }
+
+ public int getMapId() {
+ return mapId;
+ }
+
+ public int getReduceId() {
+ return reduceId;
+ }
+
+ public long getLength() {
+ return length;
+ }
+
+ public Optional<ShuffleLocation> getShuffleLocation() {
+ return shuffleLocation;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ return other instanceof ShuffleBlockInfo
+ && shuffleId == ((ShuffleBlockInfo) other).shuffleId
+ && mapId == ((ShuffleBlockInfo) other).mapId
+ && reduceId == ((ShuffleBlockInfo) other).reduceId
+ && length == ((ShuffleBlockInfo) other).length
+ && Objects.equals(shuffleLocation, ((ShuffleBlockInfo) other).shuffleLocation);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(shuffleId, mapId, reduceId, length, shuffleLocation);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java
new file mode 100644
index 0000000000000..dd7c0ac7320cb
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * :: Experimental ::
+ * An interface for launching shuffle-related components.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface ShuffleDataIO {
+ String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin.";
+
+ ShuffleDriverComponents driver();
+ ShuffleExecutorComponents executor();
+}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
new file mode 100644
index 0000000000000..6a0ec8d44fd4f
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.IOException;
+import java.util.Map;
+
+public interface ShuffleDriverComponents {
+
+ /**
+ * @return additional SparkConf values necessary for the executors.
+ */
+ Map<String, String> initializeApplication();
+
+ void cleanupApplication() throws IOException;
+
+ void removeShuffleData(int shuffleId, boolean blocking) throws IOException;
+}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java
new file mode 100644
index 0000000000000..a5fa032bf651d
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import org.apache.spark.annotation.Experimental;
+
+import java.util.Map;
+
+/**
+ * :: Experimental ::
+ * An interface for building shuffle support for executors.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface ShuffleExecutorComponents {
+ void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs);
+
+ ShuffleWriteSupport writes();
+
+ ShuffleReadSupport reads();
+}
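
For illustration only (not part of this patch): a skeleton plugin entry point wiring ShuffleDataIO, ShuffleDriverComponents, and ShuffleExecutorComponents together, using SHUFFLE_SPARK_CONF_PREFIX to hand a value from the driver to the executors. The class name, config key, and endpoint value are invented; taking a SparkConf in the constructor follows the DefaultShuffleDataIO convention that appears later in this patch, and the actual instantiation mechanism is up to the loading code.

    package com.example.shuffle; // hypothetical plugin package, not in this patch

    import java.io.IOException;
    import java.util.Collections;
    import java.util.Map;

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.shuffle.ShuffleDataIO;
    import org.apache.spark.api.shuffle.ShuffleDriverComponents;
    import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
    import org.apache.spark.api.shuffle.ShuffleReadSupport;
    import org.apache.spark.api.shuffle.ShuffleWriteSupport;

    public final class ExampleShuffleDataIO implements ShuffleDataIO {

      // Kept for symmetry with DefaultShuffleDataIO; a real plugin would read its settings here.
      private final SparkConf sparkConf;

      public ExampleShuffleDataIO(SparkConf sparkConf) {
        this.sparkConf = sparkConf;
      }

      @Override
      public ShuffleDriverComponents driver() {
        return new ShuffleDriverComponents() {
          @Override
          public Map<String, String> initializeApplication() {
            // Handed to every executor's initializeExecutor() call as extraConfigs;
            // namespaced under the plugin prefix by convention.
            return Collections.singletonMap(
                ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + "example.endpoint", "storage-host:7788");
          }

          @Override
          public void cleanupApplication() throws IOException {
            // Tear down any application-scoped shuffle storage here.
          }

          @Override
          public void removeShuffleData(int shuffleId, boolean blocking) throws IOException {
            // Delete the data written for a single shuffle.
          }
        };
      }

      @Override
      public ShuffleExecutorComponents executor() {
        return new ShuffleExecutorComponents() {
          @Override
          public void initializeExecutor(
              String appId, String execId, Map<String, String> extraConfigs) {
            // Connect to the endpoint advertised by the driver via extraConfigs.
          }

          @Override
          public ShuffleWriteSupport writes() {
            throw new UnsupportedOperationException("Write support omitted from this sketch");
          }

          @Override
          public ShuffleReadSupport reads() {
            throw new UnsupportedOperationException("Read support omitted from this sketch");
          }
        };
      }
    }
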
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
new file mode 100644
index 0000000000000..d06c11b3c01ee
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+/**
+ * Marker interface representing a location of a shuffle block. Implementations of shuffle readers
+ * and writers are expected to cast this down to an implementation-specific representation.
+ */
+public interface ShuffleLocation {}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
new file mode 100644
index 0000000000000..062cf4ff0fba9
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.IOException;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.api.java.Optional;
+
+/**
+ * :: Experimental ::
+ * An interface for creating and managing shuffle partition writers.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface ShuffleMapOutputWriter {
+ ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException;
+
+ Optional<MapShuffleLocations> commitAllPartitions() throws IOException;
+
+ void abort(Throwable error) throws IOException;
+}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java
new file mode 100644
index 0000000000000..74c928b0b9c8f
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * :: Experimental ::
+ * An interface that provides streams or channels for writing shuffle partition data.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface ShufflePartitionWriter {
+
+ /**
+ * Opens and returns an {@link OutputStream} that writes bytes to the underlying
+ * data store.
+ */
+ OutputStream openStream() throws IOException;
+
+ /**
+ * Returns the number of bytes written by the stream returned from {@link #openStream()}.
+ */
+ long getNumBytesWritten();
+}
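
For illustration only (not part of this patch): a hypothetical ShufflePartitionWriter that streams one reduce partition into a local file and counts the bytes passing through the stream, so that getNumBytesWritten() can be answered once the stream is closed.

    package com.example.shuffle; // hypothetical plugin package, not in this patch

    import java.io.BufferedOutputStream;
    import java.io.File;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.OutputStream;

    import org.apache.spark.api.shuffle.ShufflePartitionWriter;

    public final class LocalFilePartitionWriter implements ShufflePartitionWriter {

      private final File partitionFile;
      private long bytesWritten = 0L;

      public LocalFilePartitionWriter(File partitionFile) {
        this.partitionFile = partitionFile;
      }

      @Override
      public OutputStream openStream() throws IOException {
        OutputStream target = new BufferedOutputStream(new FileOutputStream(partitionFile));
        // Wrap the real stream so every byte that goes through it is counted.
        return new OutputStream() {
          @Override
          public void write(int b) throws IOException {
            target.write(b);
            bytesWritten++;
          }

          @Override
          public void write(byte[] b, int off, int len) throws IOException {
            target.write(b, off, len);
            bytesWritten += len;
          }

          @Override
          public void flush() throws IOException {
            target.flush();
          }

          @Override
          public void close() throws IOException {
            target.close();
          }
        };
      }

      @Override
      public long getNumBytesWritten() {
        return bytesWritten;
      }
    }
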
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
new file mode 100644
index 0000000000000..9cd8fde09064b
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import org.apache.spark.annotation.Experimental;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * :: Experimental ::
+ * An interface for reading shuffle records.
+ * @since 3.0.0
+ */
+@Experimental
+public interface ShuffleReadSupport {
+ /**
+ * Returns an {@link Iterable} of input streams for reading the shuffle data,
+ * given an iterable of the shuffle blocks to fetch.
+ */
+ Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
+ throws IOException;
+}
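
For illustration only (not part of this patch): a hypothetical ShuffleReadSupport that maps each requested ShuffleBlockInfo onto a file under an invented shuffleId/mapId/reduceId directory layout. It opens the streams eagerly to keep the sketch short; a real plugin would more likely open them lazily as the returned iterable is consumed.

    package com.example.shuffle; // hypothetical plugin package, not in this patch

    import java.io.File;
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.util.ArrayList;
    import java.util.List;

    import org.apache.spark.api.shuffle.ShuffleBlockInfo;
    import org.apache.spark.api.shuffle.ShuffleReadSupport;

    public final class LocalFileShuffleReadSupport implements ShuffleReadSupport {

      private final File rootDir;

      public LocalFileShuffleReadSupport(File rootDir) {
        this.rootDir = rootDir;
      }

      @Override
      public Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
          throws IOException {
        List<InputStream> readers = new ArrayList<>();
        for (ShuffleBlockInfo block : blockMetadata) {
          // Invented layout: <root>/<shuffleId>/<mapId>/<reduceId>
          File blockFile = new File(rootDir,
              block.getShuffleId() + "/" + block.getMapId() + "/" + block.getReduceId());
          readers.add(new FileInputStream(blockFile));
        }
        return readers;
      }
    }
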
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java
new file mode 100644
index 0000000000000..7e2b6cf4133fd
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.IOException;
+
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * :: Experimental ::
+ * An interface for creating a shuffle map output writer for a single map task.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface ShuffleWriteSupport {
+ ShuffleMapOutputWriter createMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ int numPartitions) throws IOException;
+}
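
For illustration only (not part of this patch): how a caller is expected to drive the write-side interfaces: create one map output writer per map task, fetch a partition writer per reduce partition, write the bytes, then commit, aborting on failure. This mirrors the pattern the patch applies to BypassMergeSortShuffleWriter and UnsafeShuffleWriter below; the helper class and the pre-partitioned byte arrays are invented for the example.

    package com.example.shuffle; // hypothetical caller, not in this patch

    import java.io.IOException;
    import java.io.OutputStream;

    import org.apache.spark.api.java.Optional;
    import org.apache.spark.api.shuffle.MapShuffleLocations;
    import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
    import org.apache.spark.api.shuffle.ShufflePartitionWriter;
    import org.apache.spark.api.shuffle.ShuffleWriteSupport;

    public final class WriteSupportUsage {

      // Writes one pre-partitioned payload per reduce partition and commits the map output.
      static Optional<MapShuffleLocations> writeMapOutput(
          ShuffleWriteSupport writeSupport,
          int shuffleId,
          int mapId,
          byte[][] bytesPerPartition) throws IOException {
        ShuffleMapOutputWriter mapWriter =
            writeSupport.createMapOutputWriter(shuffleId, mapId, bytesPerPartition.length);
        try {
          for (int p = 0; p < bytesPerPartition.length; p++) {
            ShufflePartitionWriter partitionWriter = mapWriter.getPartitionWriter(p);
            try (OutputStream out = partitionWriter.openStream()) {
              out.write(bytesPerPartition[p]);
            }
          }
          // Whatever locations the plugin reports here end up in the task's MapStatus.
          return mapWriter.commitAllPartitions();
        } catch (Exception e) {
          mapWriter.abort(e);
          throw e;
        }
      }

      private WriteSupportUsage() {}
    }
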
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java
new file mode 100644
index 0000000000000..866b61d0bafd9
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.IOException;
+
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * :: Experimental ::
+ * Indicates that partition writers can transfer bytes directly from input byte channels to
+ * output channels that stream data to the underlying shuffle partition storage medium.
+ *
+ * This API is separated out for advanced users because it only needs to be used for
+ * specific low-level optimizations. The idea is that the returned channel can transfer bytes
+ * from the input file channel out to the backing storage system without copying data into
+ * memory.
+ *
+ * Most shuffle plugin implementations should use {@link ShufflePartitionWriter} instead.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface SupportsTransferTo extends ShufflePartitionWriter {
+
+ /**
+ * Opens and returns a {@link TransferrableWritableByteChannel} for transferring bytes from
+ * input byte channels to the underlying shuffle data store.
+ */
+ TransferrableWritableByteChannel openTransferrableChannel() throws IOException;
+
+ /**
+ * Returns the number of bytes written either by this writer's output stream opened by
+ * {@link #openStream()} or the byte channel opened by {@link #openTransferrableChannel()}.
+ */
+ @Override
+ long getNumBytesWritten();
+}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java
new file mode 100644
index 0000000000000..18234d7c4c944
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * :: Experimental ::
+ * Represents an output byte channel that can copy bytes from input file channels to some
+ * arbitrary storage system.
+ *
+ * This API is provided for advanced users who can transfer bytes from a file channel to
+ * some output sink without copying data into memory. Most users should not need to use
+ * this functionality; this is primarily provided for the built-in shuffle storage backends
+ * that persist shuffle files on local disk.
+ *
+ * For a simpler alternative, see {@link ShufflePartitionWriter}.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface TransferrableWritableByteChannel extends Closeable {
+
+ /**
+ * Copy all bytes from the source readable byte channel into this byte channel.
+ *
+ * @param source File channel to transfer bytes from. Do not call anything on this channel other than
+ * {@link FileChannel#transferTo(long, long, WritableByteChannel)}.
+ * @param transferStartPosition Start position of the input file to transfer from.
+ * @param numBytesToTransfer Number of bytes to transfer from the given source.
+ */
+ void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer)
+ throws IOException;
+}
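
For illustration only (not part of this patch): a sketch of a TransferrableWritableByteChannel backed by a plain WritableByteChannel, copying with FileChannel#transferTo and looping because a single call may transfer fewer bytes than requested. The DefaultTransferrableWritableByteChannel added later in this patch does the equivalent work through Utils.copyFileStreamNIO.

    package com.example.shuffle; // hypothetical plugin package, not in this patch

    import java.io.IOException;
    import java.nio.channels.FileChannel;
    import java.nio.channels.WritableByteChannel;

    import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;

    public final class TransferToWritableByteChannel implements TransferrableWritableByteChannel {

      private final WritableByteChannel sink;

      public TransferToWritableByteChannel(WritableByteChannel sink) {
        this.sink = sink;
      }

      @Override
      public void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer)
          throws IOException {
        long transferred = 0L;
        while (transferred < numBytesToTransfer) {
          long copied = source.transferTo(
              transferStartPosition + transferred, numBytesToTransfer - transferred, sink);
          if (copied <= 0) {
            throw new IOException("Unexpected end of input after " + transferred + " bytes");
          }
          transferred += copied;
        }
      }

      @Override
      public void close() throws IOException {
        sink.close();
      }
    }
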
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 32b446785a9f0..128b90429209e 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -19,8 +19,10 @@
import java.io.File;
import java.io.FileInputStream;
-import java.io.FileOutputStream;
import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.Channels;
+import java.nio.channels.FileChannel;
import javax.annotation.Nullable;
import scala.None$;
@@ -34,16 +36,22 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.spark.internal.config.package$;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
+import org.apache.spark.api.shuffle.SupportsTransferTo;
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShufflePartitionWriter;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
+import org.apache.spark.internal.config.package$;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -82,7 +90,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final int shuffleId;
private final int mapId;
private final Serializer serializer;
- private final IndexShuffleBlockResolver shuffleBlockResolver;
+ private final ShuffleWriteSupport shuffleWriteSupport;
/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
@@ -99,11 +107,11 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
BypassMergeSortShuffleWriter(
BlockManager blockManager,
- IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle<K, V> handle,
int mapId,
SparkConf conf,
- ShuffleWriteMetricsReporter writeMetrics) {
+ ShuffleWriteMetricsReporter writeMetrics,
+ ShuffleWriteSupport shuffleWriteSupport) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -115,58 +123,67 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
- this.shuffleBlockResolver = shuffleBlockResolver;
+ this.shuffleWriteSupport = shuffleWriteSupport;
}
@Override
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
- if (!records.hasNext()) {
- partitionLengths = new long[numPartitions];
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
- return;
- }
- final SerializerInstance serInstance = serializer.newInstance();
- final long openStartTime = System.nanoTime();
- partitionWriters = new DiskBlockObjectWriter[numPartitions];
- partitionWriterSegments = new FileSegment[numPartitions];
- for (int i = 0; i < numPartitions; i++) {
- final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
- blockManager.diskBlockManager().createTempShuffleBlock();
- final File file = tempShuffleBlockIdPlusFile._2();
- final BlockId blockId = tempShuffleBlockIdPlusFile._1();
- partitionWriters[i] =
- blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
- }
- // Creating the file to write to and creating a disk writer both involve interacting with
- // the disk, and can take a long time in aggregate when we open many files, so should be
- // included in the shuffle write time.
- writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
-
- while (records.hasNext()) {
- final Product2<K, V> record = records.next();
- final K key = record._1();
- partitionWriters[partitioner.getPartition(key)].write(key, record._2());
- }
+ ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport
+ .createMapOutputWriter(shuffleId, mapId, numPartitions);
+ try {
+ if (!records.hasNext()) {
+ partitionLengths = new long[numPartitions];
+ Optional<MapShuffleLocations> blockLocs = mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ blockLocs.orNull(),
+ partitionLengths);
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new DiskBlockObjectWriter[numPartitions];
+ partitionWriterSegments = new FileSegment[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
- for (int i = 0; i < numPartitions; i++) {
- try (DiskBlockObjectWriter writer = partitionWriters[i]) {
- partitionWriterSegments[i] = writer.commitAndGet();
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
- }
- File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
- File tmp = Utils.tempFileWith(output);
- try {
- partitionLengths = writePartitionedFile(tmp);
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
- } finally {
- if (tmp.exists() && !tmp.delete()) {
- logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+ for (int i = 0; i < numPartitions; i++) {
+ try (DiskBlockObjectWriter writer = partitionWriters[i]) {
+ partitionWriterSegments[i] = writer.commitAndGet();
+ }
}
+
+ partitionLengths = writePartitionedData(mapOutputWriter);
+ Optional<MapShuffleLocations> mapLocations = mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ mapLocations.orNull(),
+ partitionLengths);
+ } catch (Exception e) {
+ try {
+ mapOutputWriter.abort(e);
+ } catch (Exception e2) {
+ logger.error("Failed to abort the writer after failing to write map output.", e2);
+ }
+ throw e;
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
@VisibleForTesting
@@ -179,37 +196,57 @@ long[] getPartitionLengths() {
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
- private long[] writePartitionedFile(File outputFile) throws IOException {
+ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
return lengths;
}
-
- final FileOutputStream out = new FileOutputStream(outputFile, true);
final long writeStartTime = System.nanoTime();
- boolean threwException = true;
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
+ ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
if (file.exists()) {
- final FileInputStream in = new FileInputStream(file);
boolean copyThrewException = true;
- try {
- lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
- copyThrewException = false;
- } finally {
- Closeables.close(in, copyThrewException);
+ if (transferToEnabled) {
+ FileInputStream in = new FileInputStream(file);
+ TransferrableWritableByteChannel outputChannel = null;
+ try (FileChannel inputChannel = in.getChannel()) {
+ if (writer instanceof SupportsTransferTo) {
+ outputChannel = ((SupportsTransferTo) writer).openTransferrableChannel();
+ } else {
+ // Use the default transferrable writable channel anyway in order to have parity with
+ // UnsafeShuffleWriter.
+ outputChannel = new DefaultTransferrableWritableByteChannel(
+ Channels.newChannel(writer.openStream()));
+ }
+ outputChannel.transferFrom(inputChannel, 0L, inputChannel.size());
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ Closeables.close(outputChannel, copyThrewException);
+ }
+ } else {
+ FileInputStream in = new FileInputStream(file);
+ OutputStream outputStream = null;
+ try {
+ outputStream = writer.openStream();
+ Utils.copyStream(in, outputStream, false, false);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ Closeables.close(outputStream, copyThrewException);
+ }
}
if (!file.delete()) {
logger.error("Unable to delete file for partition {}", i);
}
}
+ lengths[i] = writer.getNumBytesWritten();
}
- threwException = false;
} finally {
- Closeables.close(out, threwException);
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
new file mode 100644
index 0000000000000..ffd97c0f26605
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+
+import org.apache.spark.api.shuffle.MapShuffleLocations;
+import org.apache.spark.api.shuffle.ShuffleLocation;
+import org.apache.spark.storage.BlockManagerId;
+
+import java.util.Objects;
+
+public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation {
+
+ /**
+ * We borrow the cache size from the BlockManagerId's cache - around 1MB, which should be
+ * feasible.
+ */
+ private static final LoadingCache<BlockManagerId, DefaultMapShuffleLocations>
+ DEFAULT_SHUFFLE_LOCATIONS_CACHE =
+ CacheBuilder.newBuilder()
+ .maximumSize(BlockManagerId.blockManagerIdCacheSize())
+ .build(new CacheLoader<BlockManagerId, DefaultMapShuffleLocations>() {
+ @Override
+ public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) {
+ return new DefaultMapShuffleLocations(blockManagerId);
+ }
+ });
+
+ private final BlockManagerId location;
+
+ public DefaultMapShuffleLocations(BlockManagerId blockManagerId) {
+ this.location = blockManagerId;
+ }
+
+ public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) {
+ return DEFAULT_SHUFFLE_LOCATIONS_CACHE.getUnchecked(blockManagerId);
+ }
+
+ @Override
+ public ShuffleLocation getLocationForBlock(int reduceId) {
+ return this;
+ }
+
+ public BlockManagerId getBlockManagerId() {
+ return location;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ return other instanceof DefaultMapShuffleLocations
+ && Objects.equals(((DefaultMapShuffleLocations) other).location, location);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(location);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java
new file mode 100644
index 0000000000000..64ce851e392d2
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
+import org.apache.spark.util.Utils;
+
+/**
+ * This is used when transferTo is enabled but the shuffle plugin hasn't implemented
+ * {@link org.apache.spark.api.shuffle.SupportsTransferTo}.
+ *
+ * This default implementation exists as a convenience for the unsafe shuffle writer and
+ * the bypass merge sort shuffle writer.
+ */
+public class DefaultTransferrableWritableByteChannel implements TransferrableWritableByteChannel {
+
+ private final WritableByteChannel delegate;
+
+ public DefaultTransferrableWritableByteChannel(WritableByteChannel delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void transferFrom(
+ FileChannel source, long transferStartPosition, long numBytesToTransfer) {
+ Utils.copyFileStreamNIO(source, delegate, transferStartPosition, numBytesToTransfer);
+ }
+
+ @Override
+ public void close() throws IOException {
+ delegate.close();
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 9d05f03613ce9..5dd0821e10f59 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -17,6 +17,7 @@
package org.apache.spark.shuffle.sort;
+import java.nio.channels.Channels;
import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
@@ -31,18 +32,22 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
-import com.google.common.io.Files;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.*;
import org.apache.spark.annotation.Private;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
+import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShufflePartitionWriter;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.api.shuffle.SupportsTransferTo;
import org.apache.spark.internal.config.package$;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
-import org.apache.commons.io.output.CloseShieldOutputStream;
-import org.apache.commons.io.output.CountingOutputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
@@ -50,12 +55,9 @@
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.util.Utils;
@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -65,15 +67,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
@VisibleForTesting
- static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
private final BlockManager blockManager;
- private final IndexShuffleBlockResolver shuffleBlockResolver;
private final TaskMemoryManager memoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
private final ShuffleWriteMetricsReporter writeMetrics;
+ private final ShuffleWriteSupport shuffleWriteSupport;
private final int shuffleId;
private final int mapId;
private final TaskContext taskContext;
@@ -81,7 +82,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final boolean transferToEnabled;
private final int initialSortBufferSize;
private final int inputBufferSizeInBytes;
- private final int outputBufferSizeInBytes;
@Nullable private MapStatus mapStatus;
@Nullable private ShuffleExternalSorter sorter;
@@ -103,27 +103,15 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream
*/
private boolean stopping = false;
- private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream {
-
- CloseAndFlushShieldOutputStream(OutputStream outputStream) {
- super(outputStream);
- }
-
- @Override
- public void flush() {
- // do nothing
- }
- }
-
public UnsafeShuffleWriter(
BlockManager blockManager,
- IndexShuffleBlockResolver shuffleBlockResolver,
TaskMemoryManager memoryManager,
SerializedShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf sparkConf,
- ShuffleWriteMetricsReporter writeMetrics) throws IOException {
+ ShuffleWriteMetricsReporter writeMetrics,
+ ShuffleWriteSupport shuffleWriteSupport) throws IOException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -132,7 +120,6 @@ public UnsafeShuffleWriter(
" reduce partitions");
}
this.blockManager = blockManager;
- this.shuffleBlockResolver = shuffleBlockResolver;
this.memoryManager = memoryManager;
this.mapId = mapId;
final ShuffleDependency<K, V, V> dep = handle.dependency();
@@ -140,6 +127,7 @@ public UnsafeShuffleWriter(
this.serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = writeMetrics;
+ this.shuffleWriteSupport = shuffleWriteSupport;
this.taskContext = taskContext;
this.sparkConf = sparkConf;
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
@@ -147,8 +135,6 @@ public UnsafeShuffleWriter(
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
this.inputBufferSizeInBytes =
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
- this.outputBufferSizeInBytes =
- (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
open();
}
@@ -230,26 +216,33 @@ void closeAndWriteOutput() throws IOException {
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
+ final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport
+ .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
final long[] partitionLengths;
- final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
- final File tmp = Utils.tempFileWith(output);
+ Optional<MapShuffleLocations> mapLocations;
try {
try {
- partitionLengths = mergeSpills(spills, tmp);
+ partitionLengths = mergeSpills(spills, mapWriter);
} finally {
for (SpillInfo spill : spills) {
- if (spill.file.exists() && ! spill.file.delete()) {
+ if (spill.file.exists() && !spill.file.delete()) {
logger.error("Error while deleting spill file {}", spill.file.getPath());
}
}
}
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
- } finally {
- if (tmp.exists() && !tmp.delete()) {
- logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+ mapLocations = mapWriter.commitAllPartitions();
+ } catch (Exception e) {
+ try {
+ mapWriter.abort(e);
+ } catch (Exception innerE) {
+ logger.error("Failed to abort the Map Output Writer", innerE);
}
+ throw e;
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ mapLocations.orNull(),
+ partitionLengths);
}
@VisibleForTesting
@@ -281,7 +274,8 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
+ private long[] mergeSpills(SpillInfo[] spills,
+ ShuffleMapOutputWriter mapWriter) throws IOException {
final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
@@ -289,17 +283,12 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
final boolean fastMergeIsSupported = !compressionEnabled ||
CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
+ final int numPartitions = partitioner.numPartitions();
+ long[] partitionLengths = new long[numPartitions];
try {
if (spills.length == 0) {
- new FileOutputStream(outputFile).close(); // Create an empty file
- return new long[partitioner.numPartitions()];
- } else if (spills.length == 1) {
- // Here, we don't need to perform any metrics updates because the bytes written to this
- // output file would have already been counted as shuffle bytes written.
- Files.move(spills[0].file, outputFile);
- return spills[0].partitionLengths;
+ return partitionLengths;
} else {
- final long[] partitionLengths;
// There are multiple spills to merge, so none of these spill files' lengths were counted
// towards our shuffle write count or shuffle write time. If we use the slow merge path,
// then the final output file's size won't necessarily be equal to the sum of the spill
@@ -316,14 +305,14 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled && !encryptionEnabled) {
logger.debug("Using transferTo-based fast merge");
- partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+ partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter);
} else {
logger.debug("Using fileStream-based fast merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+ partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null);
}
} else {
logger.debug("Using slow merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
}
// When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
// in-memory records, we write out the in-memory records to a file but do not count that
@@ -331,13 +320,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
// to be counted as shuffle write, but this will lead to double-counting of the final
// SpillInfo's bytes.
writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
- writeMetrics.incBytesWritten(outputFile.length());
return partitionLengths;
}
} catch (IOException e) {
- if (outputFile.exists() && !outputFile.delete()) {
- logger.error("Unable to delete output file {}", outputFile.getPath());
- }
throw e;
}
}
@@ -345,73 +330,71 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
/**
* Merges spill files using Java FileStreams. This code path is typically slower than
* the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[],
- * File)}, and it's mostly used in cases where the IO compression codec does not support
- * concatenation of compressed data, when encryption is enabled, or when users have
- * explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
+ * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec
+ * does not support concatenation of compressed data, when encryption is enabled, or when
+ * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
* This code path might also be faster in cases where individual partition size in a spill
* is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small
* disk ios which is inefficient. In those case, Using large buffers for input and output
* files helps reducing the number of disk ios, making the file merging faster.
*
* @param spills the spills to merge.
- * @param outputFile the file to write the merged data to.
+ * @param mapWriter the map output writer to use for output.
* @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
* @return the partition lengths in the merged file.
*/
private long[] mergeSpillsWithFileStream(
SpillInfo[] spills,
- File outputFile,
+ ShuffleMapOutputWriter mapWriter,
@Nullable CompressionCodec compressionCodec) throws IOException {
- assert (spills.length >= 2);
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
final InputStream[] spillInputStreams = new InputStream[spills.length];
- final OutputStream bos = new BufferedOutputStream(
- new FileOutputStream(outputFile),
- outputBufferSizeInBytes);
- // Use a counting output stream to avoid having to close the underlying file and ask
- // the file system for its size after each partition is written.
- final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos);
-
boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputStreams[i] = new NioBufferedFileInputStream(
- spills[i].file,
- inputBufferSizeInBytes);
+ spills[i].file,
+ inputBufferSizeInBytes);
}
for (int partition = 0; partition < numPartitions; partition++) {
- final long initialFileLength = mergedFileOutputStream.getByteCount();
- // Shield the underlying output stream from close() and flush() calls, so that we can close
- // the higher level streams to make sure all data is really flushed and internal state is
- // cleaned.
- OutputStream partitionOutput = new CloseAndFlushShieldOutputStream(
- new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
- partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
- if (compressionCodec != null) {
- partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
- }
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
- if (partitionLengthInSpill > 0) {
- InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
- partitionLengthInSpill, false);
- try {
- partitionInputStream = blockManager.serializerManager().wrapForEncryption(
- partitionInputStream);
- if (compressionCodec != null) {
- partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ boolean copyThrewException = true;
+ ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
+ OutputStream partitionOutput = null;
+ try {
+ partitionOutput = writer.openStream();
+ partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
+ if (compressionCodec != null) {
+ partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
+ }
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+
+ if (partitionLengthInSpill > 0) {
+ InputStream partitionInputStream = null;
+ try {
+ partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+ partitionLengthInSpill, false);
+ partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+ partitionInputStream);
+ if (compressionCodec != null) {
+ partitionInputStream = compressionCodec.compressedInputStream(
+ partitionInputStream);
+ }
+ ByteStreams.copy(partitionInputStream, partitionOutput);
+ } finally {
+ partitionInputStream.close();
}
- ByteStreams.copy(partitionInputStream, partitionOutput);
- } finally {
- partitionInputStream.close();
}
+ copyThrewException = false;
}
+ } finally {
+ Closeables.close(partitionOutput, copyThrewException);
}
- partitionOutput.flush();
- partitionOutput.close();
- partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
+ long numBytesWritten = writer.getNumBytesWritten();
+ partitionLengths[partition] = numBytesWritten;
+ writeMetrics.incBytesWritten(numBytesWritten);
}
threwException = false;
} finally {
@@ -420,7 +403,6 @@ private long[] mergeSpillsWithFileStream(
for (InputStream stream : spillInputStreams) {
Closeables.close(stream, threwException);
}
- Closeables.close(mergedFileOutputStream, threwException);
}
return partitionLengths;
}
@@ -430,54 +412,49 @@ private long[] mergeSpillsWithFileStream(
* This is only safe when the IO compression codec and serializer support concatenation of
* serialized streams.
*
+ * @param spills the spills to merge.
+ * @param mapWriter the map output writer to use for output.
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
- assert (spills.length >= 2);
+ private long[] mergeSpillsWithTransferTo(
+ SpillInfo[] spills,
+ ShuffleMapOutputWriter mapWriter) throws IOException {
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
final FileChannel[] spillInputChannels = new FileChannel[spills.length];
final long[] spillInputChannelPositions = new long[spills.length];
- FileChannel mergedFileOutputChannel = null;
boolean threwException = true;
try {
for (int i = 0; i < spills.length; i++) {
spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
}
- // This file needs to opened in append mode in order to work around a Linux kernel bug that
- // affects transferTo; see SPARK-3948 for more details.
- mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
-
- long bytesWrittenToMergedFile = 0;
for (int partition = 0; partition < numPartitions; partition++) {
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
- final FileChannel spillInputChannel = spillInputChannels[i];
- final long writeStartTime = System.nanoTime();
- Utils.copyFileStreamNIO(
- spillInputChannel,
- mergedFileOutputChannel,
- spillInputChannelPositions[i],
- partitionLengthInSpill);
- spillInputChannelPositions[i] += partitionLengthInSpill;
- writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
- bytesWrittenToMergedFile += partitionLengthInSpill;
- partitionLengths[partition] += partitionLengthInSpill;
+ boolean copyThrewException = true;
+ ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
+ TransferrableWritableByteChannel partitionChannel = null;
+ try {
+ partitionChannel = writer instanceof SupportsTransferTo ?
+ ((SupportsTransferTo) writer).openTransferrableChannel()
+ : new DefaultTransferrableWritableByteChannel(
+ Channels.newChannel(writer.openStream()));
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ final FileChannel spillInputChannel = spillInputChannels[i];
+ final long writeStartTime = System.nanoTime();
+ partitionChannel.transferFrom(
+ spillInputChannel, spillInputChannelPositions[i], partitionLengthInSpill);
+ spillInputChannelPositions[i] += partitionLengthInSpill;
+ writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
+ }
+ copyThrewException = false;
+ } finally {
+ Closeables.close(partitionChannel, copyThrewException);
}
- }
- // Check the position after transferTo loop to see if it is in the right position and raise an
- // exception if it is incorrect. The position will not be increased to the expected length
- // after calling transferTo in kernel version 2.6.32. This issue is described at
- // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
- if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
- throw new IOException(
- "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
- "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
- " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
- "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
- "to disable this NIO feature."
- );
+ long numBytes = writer.getNumBytesWritten();
+ partitionLengths[partition] = numBytes;
+ writeMetrics.incBytesWritten(numBytes);
}
threwException = false;
} finally {
@@ -487,7 +464,6 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
assert(spillInputChannelPositions[i] == spills[i].file.length());
Closeables.close(spillInputChannels[i], threwException);
}
- Closeables.close(mergedFileOutputChannel, threwException);
}
return partitionLengths;
}
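Note on the writer contract used above: the zero-copy path is only taken when a plugin's partition writer opts in via SupportsTransferTo; otherwise the spill bytes are pushed through a channel wrapped around openStream(). Below is a minimal sketch of an opted-in writer, assuming only the interface surface visible in this patch (openStream(), openTransferrableChannel(), getNumBytesWritten(), and a DefaultTransferrableWritableByteChannel that wraps any WritableByteChannel); the class itself is illustrative and not part of the change.

    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.OutputStream;
    import java.nio.channels.Channels;

    import org.apache.spark.api.shuffle.SupportsTransferTo;
    import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
    import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel;

    // Illustrative partition writer that exposes its file channel so the merge loop above
    // can use transferFrom instead of copying through an intermediate buffer.
    public class FileBackedPartitionWriter implements SupportsTransferTo {

      private final FileOutputStream target;
      private final long startPosition;

      public FileBackedPartitionWriter(FileOutputStream target) throws IOException {
        this.target = target;
        this.startPosition = target.getChannel().position();
      }

      @Override
      public OutputStream openStream() {
        // Stream writes go through the same channel, so position-based accounting still works.
        return Channels.newOutputStream(target.getChannel());
      }

      @Override
      public TransferrableWritableByteChannel openTransferrableChannel() {
        // Keep the shared channel open across partitions, as the default writer does.
        return new DefaultTransferrableWritableByteChannel(target.getChannel()) {
          @Override
          public void close() {
            // Per-partition lengths come from the channel position; nothing to close here.
          }
        };
      }

      @Override
      public long getNumBytesWritten() {
        try {
          return target.getChannel().position() - startPosition;
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      }
    }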
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java
new file mode 100644
index 0000000000000..7c124c1fe68bc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.shuffle.ShuffleDriverComponents;
+import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
+import org.apache.spark.api.shuffle.ShuffleDataIO;
+import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents;
+
+public class DefaultShuffleDataIO implements ShuffleDataIO {
+
+ private final SparkConf sparkConf;
+
+ public DefaultShuffleDataIO(SparkConf sparkConf) {
+ this.sparkConf = sparkConf;
+ }
+
+ @Override
+ public ShuffleExecutorComponents executor() {
+ return new DefaultShuffleExecutorComponents(sparkConf);
+ }
+
+ @Override
+ public ShuffleDriverComponents driver() {
+ return new DefaultShuffleDriverComponents();
+ }
+}
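DefaultShuffleDataIO is also the template for external plugins: whichever class spark.shuffle.io.plugin.class names only needs to hand back driver-side and executor-side components. A hypothetical entry point might look like the sketch below; the com.example.shuffle names and the two referenced component classes are placeholders (they are sketched after the driver components and SortShuffleManager changes further down), not part of this patch.

    package com.example.shuffle;

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.shuffle.ShuffleDataIO;
    import org.apache.spark.api.shuffle.ShuffleDriverComponents;
    import org.apache.spark.api.shuffle.ShuffleExecutorComponents;

    public class MyShuffleDataIO implements ShuffleDataIO {

      private final SparkConf sparkConf;

      // Utils.loadExtensions instantiates the plugin with the SparkConf, just as it does
      // for DefaultShuffleDataIO above.
      public MyShuffleDataIO(SparkConf sparkConf) {
        this.sparkConf = sparkConf;
      }

      @Override
      public ShuffleExecutorComponents executor() {
        return new MyShuffleExecutorComponents();
      }

      @Override
      public ShuffleDriverComponents driver() {
        return new MyShuffleDriverComponents();
      }
    }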
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
new file mode 100644
index 0000000000000..3b5f9670d64d2
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import org.apache.spark.MapOutputTracker;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
+import org.apache.spark.api.shuffle.ShuffleReadSupport;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.serializer.SerializerManager;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.io.DefaultShuffleReadSupport;
+import org.apache.spark.storage.BlockManager;
+
+import java.util.Map;
+
+public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents {
+
+ private final SparkConf sparkConf;
+ private BlockManager blockManager;
+ private IndexShuffleBlockResolver blockResolver;
+ private MapOutputTracker mapOutputTracker;
+ private SerializerManager serializerManager;
+
+ public DefaultShuffleExecutorComponents(SparkConf sparkConf) {
+ this.sparkConf = sparkConf;
+ }
+
+ @Override
+ public void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs) {
+ blockManager = SparkEnv.get().blockManager();
+ mapOutputTracker = SparkEnv.get().mapOutputTracker();
+ serializerManager = SparkEnv.get().serializerManager();
+ blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
+ }
+
+ @Override
+ public ShuffleWriteSupport writes() {
+ checkInitialized();
+ return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId());
+ }
+
+ @Override
+ public ShuffleReadSupport reads() {
+ checkInitialized();
+ return new DefaultShuffleReadSupport(blockManager,
+ mapOutputTracker,
+ serializerManager,
+ sparkConf);
+ }
+
+ private void checkInitialized() {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.");
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
new file mode 100644
index 0000000000000..e83db4e4bcef6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShufflePartitionWriter;
+import org.apache.spark.api.shuffle.SupportsTransferTo;
+import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;
+import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.storage.BlockManagerId;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.util.Utils;
+
+public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter {
+
+ private static final Logger log =
+ LoggerFactory.getLogger(DefaultShuffleMapOutputWriter.class);
+
+ private final int shuffleId;
+ private final int mapId;
+ private final ShuffleWriteMetricsReporter metrics;
+ private final IndexShuffleBlockResolver blockResolver;
+ private final long[] partitionLengths;
+ private final int bufferSize;
+ private int lastPartitionId = -1;
+ private long currChannelPosition;
+ private final BlockManagerId shuffleServerId;
+
+ private final File outputFile;
+ private File outputTempFile;
+ private FileOutputStream outputFileStream;
+ private FileChannel outputFileChannel;
+ private TimeTrackingOutputStream ts;
+ private BufferedOutputStream outputBufferedFileStream;
+
+ public DefaultShuffleMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ int numPartitions,
+ BlockManagerId shuffleServerId,
+ ShuffleWriteMetricsReporter metrics,
+ IndexShuffleBlockResolver blockResolver,
+ SparkConf sparkConf) {
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.shuffleServerId = shuffleServerId;
+ this.metrics = metrics;
+ this.blockResolver = blockResolver;
+ this.bufferSize =
+ (int) (long) sparkConf.get(
+ package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
+ this.partitionLengths = new long[numPartitions];
+ this.outputFile = blockResolver.getDataFile(shuffleId, mapId);
+ this.outputTempFile = null;
+ }
+
+ @Override
+ public ShufflePartitionWriter getPartitionWriter(int partitionId) throws IOException {
+ if (partitionId <= lastPartitionId) {
+ throw new IllegalArgumentException("Partitions should be requested in increasing order.");
+ }
+ lastPartitionId = partitionId;
+ if (outputTempFile == null) {
+ outputTempFile = Utils.tempFileWith(outputFile);
+ }
+ if (outputFileChannel != null) {
+ currChannelPosition = outputFileChannel.position();
+ } else {
+ currChannelPosition = 0L;
+ }
+ return new DefaultShufflePartitionWriter(partitionId);
+ }
+
+ @Override
+ public Optional<MapShuffleLocations> commitAllPartitions() throws IOException {
+ cleanUp();
+ File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
+ blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
+ return Optional.of(DefaultMapShuffleLocations.get(shuffleServerId));
+ }
+
+ @Override
+ public void abort(Throwable error) {
+ try {
+ cleanUp();
+ } catch (Exception e) {
+ log.error("Unable to close appropriate underlying file stream", e);
+ }
+ if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) {
+ log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath());
+ }
+ }
+
+ private void cleanUp() throws IOException {
+ if (outputBufferedFileStream != null) {
+ outputBufferedFileStream.close();
+ }
+ if (outputFileChannel != null) {
+ outputFileChannel.close();
+ }
+ if (outputFileStream != null) {
+ outputFileStream.close();
+ }
+ }
+
+ private void initStream() throws IOException {
+ if (outputFileStream == null) {
+ outputFileStream = new FileOutputStream(outputTempFile, true);
+ ts = new TimeTrackingOutputStream(metrics, outputFileStream);
+ }
+ if (outputBufferedFileStream == null) {
+ outputBufferedFileStream = new BufferedOutputStream(ts, bufferSize);
+ }
+ }
+
+ private void initChannel() throws IOException {
+ if (outputFileStream == null) {
+ outputFileStream = new FileOutputStream(outputTempFile, true);
+ }
+ if (outputFileChannel == null) {
+ outputFileChannel = outputFileStream.getChannel();
+ }
+ }
+
+ private class DefaultShufflePartitionWriter implements SupportsTransferTo {
+
+ private final int partitionId;
+ private PartitionWriterStream partStream = null;
+ private PartitionWriterChannel partChannel = null;
+
+ private DefaultShufflePartitionWriter(int partitionId) {
+ this.partitionId = partitionId;
+ }
+
+ @Override
+ public OutputStream openStream() throws IOException {
+ if (partStream == null) {
+ if (outputFileChannel != null) {
+ throw new IllegalStateException("Requested an output channel for a previous write but" +
+ " now an output stream has been requested. Should not be using both channels" +
+ " and streams to write.");
+ }
+ initStream();
+ partStream = new PartitionWriterStream(partitionId);
+ }
+ return partStream;
+ }
+
+ @Override
+ public TransferrableWritableByteChannel openTransferrableChannel() throws IOException {
+ if (partChannel == null) {
+ if (partStream != null) {
+ throw new IllegalStateException("Requested an output stream for a previous write but" +
+ " now an output channel has been requested. Should not be using both channels" +
+ " and streams to write.");
+ }
+ initChannel();
+ partChannel = new PartitionWriterChannel(partitionId);
+ }
+ return partChannel;
+ }
+
+ @Override
+ public long getNumBytesWritten() {
+ if (partChannel != null) {
+ try {
+ return partChannel.getCount();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ } else if (partStream != null) {
+ return partStream.getCount();
+ } else {
+ // Assume an empty partition if stream and channel are never created
+ return 0;
+ }
+ }
+ }
+
+ private class PartitionWriterStream extends OutputStream {
+ private final int partitionId;
+ private int count = 0;
+ private boolean isClosed = false;
+
+ PartitionWriterStream(int partitionId) {
+ this.partitionId = partitionId;
+ }
+
+ public int getCount() {
+ return count;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ verifyNotClosed();
+ outputBufferedFileStream.write(b);
+ count++;
+ }
+
+ @Override
+ public void write(byte[] buf, int pos, int length) throws IOException {
+ verifyNotClosed();
+ outputBufferedFileStream.write(buf, pos, length);
+ count += length;
+ }
+
+ @Override
+ public void close() {
+ isClosed = true;
+ partitionLengths[partitionId] = count;
+ }
+
+ private void verifyNotClosed() {
+ if (isClosed) {
+ throw new IllegalStateException("Attempting to write to a closed block output stream.");
+ }
+ }
+ }
+
+ private class PartitionWriterChannel extends DefaultTransferrableWritableByteChannel {
+
+ private final int partitionId;
+
+ PartitionWriterChannel(int partitionId) {
+ super(outputFileChannel);
+ this.partitionId = partitionId;
+ }
+
+ public long getCount() throws IOException {
+ long writtenPosition = outputFileChannel.position();
+ return writtenPosition - currChannelPosition;
+ }
+
+ @Override
+ public void close() throws IOException {
+ partitionLengths[partitionId] = getCount();
+ }
+ }
+}
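Putting the pieces above together, the map-side contract is: request partition writers in increasing partition order, push each partition's bytes through openStream() (or the transferTo channel), then commit, aborting on failure. A compact sketch of that call sequence against any ShuffleMapOutputWriter, with pre-serialized per-partition bytes assumed as input and illustrative class and method names:

    import java.io.IOException;
    import java.io.OutputStream;

    import org.apache.spark.api.java.Optional;
    import org.apache.spark.api.shuffle.MapShuffleLocations;
    import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
    import org.apache.spark.api.shuffle.ShufflePartitionWriter;
    import org.apache.spark.api.shuffle.ShuffleWriteSupport;

    final class MapOutputWriterUsage {

      // Writes one pre-serialized byte[] per reduce partition and commits the map output.
      static Optional<MapShuffleLocations> writeMapOutput(
          ShuffleWriteSupport writeSupport,
          int shuffleId,
          int mapId,
          byte[][] partitionBytes) throws IOException {
        ShuffleMapOutputWriter mapWriter =
            writeSupport.createMapOutputWriter(shuffleId, mapId, partitionBytes.length);
        try {
          for (int p = 0; p < partitionBytes.length; p++) {
            // Partitions must be requested in increasing order.
            ShufflePartitionWriter writer = mapWriter.getPartitionWriter(p);
            try (OutputStream out = writer.openStream()) {
              out.write(partitionBytes[p]);
            }
          }
          // Commit writes the index file (in the default implementation) and may return
          // locations for the driver to track.
          return mapWriter.commitAllPartitions();
        } catch (IOException | RuntimeException e) {
          mapWriter.abort(e);
          throw e;
        }
      }
    }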
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java
new file mode 100644
index 0000000000000..86f1583495689
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.storage.BlockManagerId;
+
+public class DefaultShuffleWriteSupport implements ShuffleWriteSupport {
+
+ private final SparkConf sparkConf;
+ private final IndexShuffleBlockResolver blockResolver;
+ private final BlockManagerId shuffleServerId;
+
+ public DefaultShuffleWriteSupport(
+ SparkConf sparkConf,
+ IndexShuffleBlockResolver blockResolver,
+ BlockManagerId shuffleServerId) {
+ this.sparkConf = sparkConf;
+ this.blockResolver = blockResolver;
+ this.shuffleServerId = shuffleServerId;
+ }
+
+ @Override
+ public ShuffleMapOutputWriter createMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ int numPartitions) {
+ return new DefaultShuffleMapOutputWriter(
+ shuffleId, mapId, numPartitions, shuffleServerId,
+ TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
new file mode 100644
index 0000000000000..a3eddc8ec930e
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.lifecycle;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.api.shuffle.ShuffleDriverComponents;
+import org.apache.spark.storage.BlockManagerMaster;
+
+import java.io.IOException;
+import java.util.Map;
+
+public class DefaultShuffleDriverComponents implements ShuffleDriverComponents {
+
+ private BlockManagerMaster blockManagerMaster;
+
+ @Override
+ public Map<String, String> initializeApplication() {
+ blockManagerMaster = SparkEnv.get().blockManager().master();
+ return ImmutableMap.of();
+ }
+
+ @Override
+ public void cleanupApplication() {
+ // do nothing
+ }
+
+ @Override
+ public void removeShuffleData(int shuffleId, boolean blocking) throws IOException {
+ checkInitialized();
+ blockManagerMaster.removeShuffle(shuffleId, blocking);
+ }
+
+ private void checkInitialized() {
+ if (blockManagerMaster == null) {
+ throw new IllegalStateException("Driver components must be initialized before use.");
+ }
+ }
+}
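For a plugin that tracks shuffle data outside the driver's BlockManagerMaster, these same three hooks carry the whole driver-side lifecycle, and the map returned by initializeApplication() is how driver-discovered settings reach executors (SparkContext republishes it under ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX, as shown in the SparkContext change below). A hypothetical counterpart to the default components, with an invented endpoint key purely for illustration:

    package com.example.shuffle;

    import java.util.Collections;
    import java.util.Map;

    import org.apache.spark.api.shuffle.ShuffleDriverComponents;

    // Hypothetical driver components for a plugin that tracks shuffle data externally.
    public class MyShuffleDriverComponents implements ShuffleDriverComponents {

      @Override
      public Map<String, String> initializeApplication() {
        // SparkContext prefixes these keys with ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX and
        // sets them on the SparkConf, so executors can read them back in initializeExecutor.
        return Collections.singletonMap("metadata.endpoint", "shuffle-metadata.example.com:7788");
      }

      @Override
      public void cleanupApplication() {
        // Release application-scoped resources in the external system here.
      }

      @Override
      public void removeShuffleData(int shuffleId, boolean blocking) {
        // Delete or expire the external data for this shuffle instead of asking the
        // BlockManagerMaster to do it.
      }
    }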
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index a111a60d1d024..162216bd0c5a0 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Scheduled
import scala.collection.JavaConverters._
+import org.apache.spark.api.shuffle.ShuffleDriverComponents
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -58,7 +59,9 @@ private class CleanupTaskWeakReference(
* to be processed when the associated object goes out of scope of the application. Actual
* cleanup is performed in a separate daemon thread.
*/
-private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
+private[spark] class ContextCleaner(
+ sc: SparkContext,
+ shuffleDriverComponents: ShuffleDriverComponents) extends Logging {
/**
* A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
@@ -221,7 +224,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
try {
logDebug("Cleaning shuffle " + shuffleId)
mapOutputTrackerMaster.unregisterShuffle(shuffleId)
- blockManagerMaster.removeShuffle(shuffleId, blocking)
+ shuffleDriverComponents.removeShuffleData(shuffleId, blocking)
listeners.asScala.foreach(_.shuffleCleaned(shuffleId))
logInfo("Cleaned shuffle " + shuffleId)
} catch {
@@ -269,7 +272,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}
- private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index c6271d251970c..b7462c350796a 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal
+import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation}
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -281,9 +282,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
// For testing
- def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
- getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
+ def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int)
+ : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
+ getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)
}
/**
@@ -295,8 +296,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
* and the second item is a sequence of (shuffle block id, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
*/
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
+ def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]
/**
* Deletes map output status information for the specified shuffle stage.
@@ -645,8 +646,8 @@ private[spark] class MapOutputTrackerMaster(
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
@@ -682,12 +683,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
private val fetching = new HashSet[Int]
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
- override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
- MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+ MapOutputTracker.convertMapStatuses(
+ shuffleId, startPartition, endPartition, statuses)
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -871,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging {
shuffleId: Int,
startPartition: Int,
endPartition: Int,
- statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
assert (statuses != null)
- val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
+ val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]]
for ((status, mapId) <- statuses.iterator.zipWithIndex) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
@@ -883,8 +885,14 @@ private[spark] object MapOutputTracker extends Logging {
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
- splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
+ if (status.mapShuffleLocations == null) {
+ splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) +=
((ShuffleBlockId(shuffleId, mapId, part), size))
+ } else {
+ val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part)
+ splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapId, part), size))
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c16984259bc00..ed373d46f3198 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -40,6 +40,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat}
@@ -213,6 +214,7 @@ class SparkContext(config: SparkConf) extends Logging {
private var _shutdownHookRef: AnyRef = _
private var _statusStore: AppStatusStore = _
private var _heartbeater: Heartbeater = _
+ private var _shuffleDriverComponents: ShuffleDriverComponents = _
private var _resources: scala.collection.immutable.Map[String, ResourceInformation] = _
/* ------------------------------------------------------------------------------------- *
@@ -528,6 +530,14 @@ class SparkContext(config: SparkConf) extends Logging {
executorEnvs ++= _conf.getExecutorEnv
executorEnvs("SPARK_USER") = sparkUser
+ val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
+ val maybeIO = Utils.loadExtensions(
+ classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
+ require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
+ _shuffleDriverComponents = maybeIO.head.driver()
+ _shuffleDriverComponents.initializeApplication().asScala.foreach {
+ case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) }
+
// We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will
// retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640)
_heartbeatReceiver = env.rpcEnv.setupEndpoint(
@@ -596,7 +606,7 @@ class SparkContext(config: SparkConf) extends Logging {
_cleaner =
if (_conf.get(CLEANER_REFERENCE_TRACKING)) {
- Some(new ContextCleaner(this))
+ Some(new ContextCleaner(this, _shuffleDriverComponents))
} else {
None
}
@@ -1952,6 +1962,7 @@ class SparkContext(config: SparkConf) extends Logging {
}
_heartbeater = null
}
+ _shuffleDriverComponents.cleanupApplication()
if (env != null && _heartbeatReceiver != null) {
Utils.tryLogNonFatalError {
env.rpcEnv.stop(_heartbeatReceiver)
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index ea79c7310349d..df30fd5c7f679 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -56,6 +56,8 @@ class TaskMetrics private[spark] () extends Serializable {
private val _diskBytesSpilled = new LongAccumulator
private val _peakExecutionMemory = new LongAccumulator
private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)]
+ private var _decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics =
+ Predef.identity[TempShuffleReadMetrics]
/**
* Time taken on the executor to deserialize this task.
@@ -187,11 +189,17 @@ class TaskMetrics private[spark] () extends Serializable {
* be lost.
*/
private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized {
- val readMetrics = new TempShuffleReadMetrics
- tempShuffleReadMetrics += readMetrics
+ val tempShuffleMetrics = new TempShuffleReadMetrics
+ val readMetrics = _decorFunc(tempShuffleMetrics)
+ tempShuffleReadMetrics += tempShuffleMetrics
readMetrics
}
+ private[spark] def decorateTempShuffleReadMetrics(
+ decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics): Unit = synchronized {
+ _decorFunc = decorFunc
+ }
+
/**
* Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`.
* This is expected to be called on executor heartbeat and at the end of a task.
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 8d4910ff9f80e..becb4d5a90d0f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -24,6 +24,7 @@ import org.apache.spark.metrics.GarbageCollectionMetrics
import org.apache.spark.network.shuffle.Constants
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode}
+import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO
import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy}
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.util.Utils
@@ -798,6 +799,12 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val SHUFFLE_IO_PLUGIN_CLASS =
+ ConfigBuilder("spark.shuffle.io.plugin.class")
+ .doc("Name of the class to use for shuffle IO.")
+ .stringConf
+ .createWithDefault(classOf[DefaultShuffleDataIO].getName)
+
private[spark] val SHUFFLE_FILE_BUFFER_SIZE =
ConfigBuilder("spark.shuffle.file.buffer")
.doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 64f0a060a247c..a61f9bd14ef2f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -24,7 +24,9 @@ import scala.collection.mutable
import org.roaringbitmap.RoaringBitmap
import org.apache.spark.SparkEnv
+import org.apache.spark.api.shuffle.MapShuffleLocations
import org.apache.spark.internal.config
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -33,7 +35,16 @@ import org.apache.spark.util.Utils
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
*/
private[spark] sealed trait MapStatus {
- /** Location where this task was run. */
+
+ /**
+ * Locations where this task stored shuffle blocks.
+ *
+ * May be null if the MapOutputTracker is not tracking the location of shuffle blocks, leaving it
+ * up to the implementation of shuffle plugins to do so.
+ */
+ def mapShuffleLocations: MapShuffleLocations
+
+ /** Location where the task was run. */
def location: BlockManagerId
/**
@@ -56,11 +67,31 @@ private[spark] object MapStatus {
.map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
.getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
+ // A temporary concession to the fact that only the shuffle implementations provided by Spark
+ // are expected to store shuffle locations in the driver, so we want to introduce as little
+ // serialization overhead as possible in this default case.
+ //
+ // If more similar cases arise, consider adding a serialization API for these shuffle locations.
+ private val DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 0
+ private val NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 1
+
+ /**
+ * Visible for testing.
+ */
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
+ apply(loc, DefaultMapShuffleLocations.get(loc), uncompressedSizes)
+ }
+
+ def apply(
+ loc: BlockManagerId,
+ mapShuffleLocs: MapShuffleLocations,
+ uncompressedSizes: Array[Long]): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
- HighlyCompressedMapStatus(loc, uncompressedSizes)
+ HighlyCompressedMapStatus(
+ loc, mapShuffleLocs, uncompressedSizes)
} else {
- new CompressedMapStatus(loc, uncompressedSizes)
+ new CompressedMapStatus(
+ loc, mapShuffleLocs, uncompressedSizes)
}
}
@@ -91,41 +122,89 @@ private[spark] object MapStatus {
math.pow(LOG_BASE, compressedSize & 0xFF).toLong
}
}
-}
+ def writeLocations(
+ loc: BlockManagerId,
+ mapShuffleLocs: MapShuffleLocations,
+ out: ObjectOutput): Unit = {
+ if (mapShuffleLocs != null) {
+ out.writeBoolean(true)
+ if (mapShuffleLocs.isInstanceOf[DefaultMapShuffleLocations]
+ && mapShuffleLocs.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId == loc) {
+ out.writeByte(MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID)
+ } else {
+ out.writeByte(MapStatus.NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID)
+ out.writeObject(mapShuffleLocs)
+ }
+ } else {
+ out.writeBoolean(false)
+ }
+ loc.writeExternal(out)
+ }
+
+ def readLocations(in: ObjectInput): (BlockManagerId, MapShuffleLocations) = {
+ if (in.readBoolean()) {
+ val locId = in.readByte()
+ if (locId == MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) {
+ val blockManagerId = BlockManagerId(in)
+ (blockManagerId, DefaultMapShuffleLocations.get(blockManagerId))
+ } else {
+ val mapShuffleLocations = in.readObject().asInstanceOf[MapShuffleLocations]
+ val blockManagerId = BlockManagerId(in)
+ (blockManagerId, mapShuffleLocations)
+ }
+ } else {
+ val blockManagerId = BlockManagerId(in)
+ (blockManagerId, null)
+ }
+ }
+}
/**
* A [[MapStatus]] implementation that tracks the size of each block. Size for each block is
* represented using a single byte.
*
- * @param loc location where the task is being executed.
+ * @param loc location where the task is being executed.
+ * @param mapShuffleLocs locations where the task stored its shuffle blocks - may be null.
* @param compressedSizes size of the blocks, indexed by reduce partition id.
*/
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
+ private[this] var mapShuffleLocs: MapShuffleLocations,
private[this] var compressedSizes: Array[Byte])
extends MapStatus with Externalizable {
- protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+ // For deserialization only
+ protected def this() = this(null, null, null.asInstanceOf[Array[Byte]])
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.map(MapStatus.compressSize))
+ def this(
+ loc: BlockManagerId,
+ mapShuffleLocations: MapShuffleLocations,
+ uncompressedSizes: Array[Long]) {
+ this(
+ loc,
+ mapShuffleLocations,
+ uncompressedSizes.map(MapStatus.compressSize))
}
override def location: BlockManagerId = loc
+ override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs
+
override def getSizeForBlock(reduceId: Int): Long = {
MapStatus.decompressSize(compressedSizes(reduceId))
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- loc.writeExternal(out)
+ MapStatus.writeLocations(loc, mapShuffleLocs, out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- loc = BlockManagerId(in)
+ val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in)
+ loc = deserializedLoc
+ mapShuffleLocs = deserializedMapShuffleLocs
val len = in.readInt()
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
@@ -138,6 +217,7 @@ private[spark] class CompressedMapStatus(
* plus a bitmap for tracking which blocks are empty.
*
* @param loc location where the task is being executed
+ * @param mapShuffleLocs location where the task stored shuffle blocks - may be null
* @param numNonEmptyBlocks the number of non-empty blocks
* @param emptyBlocks a bitmap tracking which blocks are empty
* @param avgSize average size of the non-empty and non-huge blocks
@@ -145,6 +225,7 @@ private[spark] class CompressedMapStatus(
*/
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
+ private[this] var mapShuffleLocs: MapShuffleLocations,
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
@@ -155,10 +236,12 @@ private[spark] class HighlyCompressedMapStatus private (
require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1, null) // For deserialization only
+ protected def this() = this(null, null, -1, null, -1, null) // For deserialization only
override def location: BlockManagerId = loc
+ override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs
+
override def getSizeForBlock(reduceId: Int): Long = {
assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
@@ -172,7 +255,7 @@ private[spark] class HighlyCompressedMapStatus private (
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- loc.writeExternal(out)
+ MapStatus.writeLocations(loc, mapShuffleLocs, out)
emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
out.writeInt(hugeBlockSizes.size)
@@ -183,7 +266,9 @@ private[spark] class HighlyCompressedMapStatus private (
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- loc = BlockManagerId(in)
+ val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in)
+ loc = deserializedLoc
+ mapShuffleLocs = deserializedMapShuffleLocs
emptyBlocks = new RoaringBitmap()
emptyBlocks.readExternal(in)
avgSize = in.readLong()
@@ -199,7 +284,10 @@ private[spark] class HighlyCompressedMapStatus private (
}
private[spark] object HighlyCompressedMapStatus {
- def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ def apply(
+ loc: BlockManagerId,
+ mapShuffleLocs: MapShuffleLocations,
+ uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -239,7 +327,12 @@ private[spark] object HighlyCompressedMapStatus {
}
emptyBlocks.trim()
emptyBlocks.runOptimize()
- new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
- hugeBlockSizes)
+ new HighlyCompressedMapStatus(
+ loc,
+ mapShuffleLocs,
+ numNonEmptyBlocks,
+ emptyBlocks,
+ avgSize,
+ hugeBlockSizes)
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 39691069bf5f6..f886fe7d9e598 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -31,7 +31,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSe
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput}
import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool}
-import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
+import com.esotericsoftware.kryo.serializers.{ExternalizableSerializer, JavaSerializer => KryoJavaSerializer}
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.roaringbitmap.RoaringBitmap
@@ -151,6 +151,8 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer())
kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer())
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
+ kryo.register(classOf[CompressedMapStatus], new ExternalizableSerializer())
+ kryo.register(classOf[HighlyCompressedMapStatus], new ExternalizableSerializer())
kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
@@ -486,8 +488,6 @@ private[serializer] object KryoSerializer {
private val toRegister: Seq[Class[_]] = Seq(
ByteBuffer.allocate(1).getClass,
classOf[StorageLevel],
- classOf[CompressedMapStatus],
- classOf[HighlyCompressedMapStatus],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
classOf[Array[Boolean]],
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index c7843710413dd..530c3694ad1ec 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -17,10 +17,18 @@
package org.apache.spark.shuffle
+import java.io.InputStream
+
+import scala.collection.JavaConverters._
+
import org.apache.spark._
+import org.apache.spark.api.java.Optional
+import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
+import org.apache.spark.storage.{ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -34,34 +42,68 @@ private[spark] class BlockStoreShuffleReader[K, C](
endPartition: Int,
context: TaskContext,
readMetrics: ShuffleReadMetricsReporter,
+ shuffleReadSupport: ShuffleReadSupport,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
- blockManager: BlockManager = SparkEnv.get.blockManager,
- mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+ sparkConf: SparkConf = SparkEnv.get.conf)
extends ShuffleReader[K, C] with Logging {
private val dep = handle.dependency
+ private val compressionCodec = CompressionCodec.createCodec(sparkConf)
+
+ private val compressShuffle = sparkConf.get(config.SHUFFLE_COMPRESS)
+
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val wrappedStreams = new ShuffleBlockFetcherIterator(
- context,
- blockManager.shuffleClient,
- blockManager,
- mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
- serializerManager.wrapStream,
- // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
- SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
- SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
- SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
- SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
- SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
- SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
- readMetrics).toCompletionIterator
+ val streamsIterator =
+ shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] {
+ override def iterator: Iterator[ShuffleBlockInfo] = {
+ mapOutputTracker
+ .getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition)
+ .flatMap { shuffleLocationInfo =>
+ shuffleLocationInfo._2.map { blockInfo =>
+ val block = blockInfo._1.asInstanceOf[ShuffleBlockId]
+ new ShuffleBlockInfo(
+ block.shuffleId,
+ block.mapId,
+ block.reduceId,
+ blockInfo._2,
+ Optional.ofNullable(shuffleLocationInfo._1.orNull))
+ }
+ }
+ }
+ }.asJava).iterator()
- val serializerInstance = dep.serializer.newInstance()
+ val retryingWrappedStreams = new Iterator[InputStream] {
+ override def hasNext: Boolean = streamsIterator.hasNext
- // Create a key/value iterator for each stream
- val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
+ override def next(): InputStream = {
+ var returnStream: InputStream = null
+ while (streamsIterator.hasNext && returnStream == null) {
+ if (shuffleReadSupport.isInstanceOf[DefaultShuffleReadSupport]) {
+ // The default implementation checks for corrupt streams, so it will already have
+ // decompressed/decrypted the bytes
+ returnStream = streamsIterator.next()
+ } else {
+ val nextStream = streamsIterator.next()
+ returnStream = if (compressShuffle) {
+ compressionCodec.compressedInputStream(
+ serializerManager.wrapForEncryption(nextStream))
+ } else {
+ serializerManager.wrapForEncryption(nextStream)
+ }
+ }
+ }
+ if (returnStream == null) {
+ throw new IllegalStateException("Expected shuffle reader iterator to return a stream")
+ }
+ returnStream
+ }
+ }
+
+ val serializerInstance = dep.serializer.newInstance()
+ val recordIter = retryingWrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
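Because the reader above now handles decryption and decompression itself for any non-default ShuffleReadSupport, a plugin only has to return the stored bytes for each requested block. Below is a hypothetical read support backed by an external block store; the nested client interface is invented for the sketch, the ShuffleBlockInfo accessors follow those used by DefaultShuffleReadSupport, and a real implementation would fetch lazily rather than open every stream up front.

    package com.example.shuffle;

    import java.io.InputStream;
    import java.util.ArrayList;
    import java.util.List;

    import org.apache.spark.api.shuffle.ShuffleBlockInfo;
    import org.apache.spark.api.shuffle.ShuffleReadSupport;

    public class MyShuffleReadSupport implements ShuffleReadSupport {

      /** Invented client abstraction for wherever the plugin stored its shuffle blocks. */
      public interface ExternalBlockStoreClient {
        InputStream open(int shuffleId, int mapId, int reduceId);
      }

      private final ExternalBlockStoreClient client;

      public MyShuffleReadSupport(ExternalBlockStoreClient client) {
        this.client = client;
      }

      @Override
      public Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata) {
        // Streams are returned exactly as stored; BlockStoreShuffleReader wraps them for
        // decryption and decompression because this is not the DefaultShuffleReadSupport.
        List<InputStream> streams = new ArrayList<>();
        for (ShuffleBlockInfo block : blockMetadata) {
          streams.add(client.open(block.getShuffleId(), block.getMapId(), block.getReduceId()));
        }
        return streams;
      }
    }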
diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
new file mode 100644
index 0000000000000..9b9b8508e88aa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.io
+
+import java.io.InputStream
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{MapOutputTracker, SparkConf, TaskContext}
+import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
+import org.apache.spark.internal.config
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
+import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
+
+class DefaultShuffleReadSupport(
+ blockManager: BlockManager,
+ mapOutputTracker: MapOutputTracker,
+ serializerManager: SerializerManager,
+ conf: SparkConf) extends ShuffleReadSupport {
+
+ private val maxBytesInFlight = conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024
+ private val maxReqsInFlight = conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT)
+ private val maxBlocksInFlightPerAddress =
+ conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
+ private val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
+ private val detectCorrupt = conf.get(config.SHUFFLE_DETECT_CORRUPT)
+
+ override def getPartitionReaders(blockMetadata: java.lang.Iterable[ShuffleBlockInfo]):
+ java.lang.Iterable[InputStream] = {
+
+ val iterableToReturn = if (blockMetadata.asScala.isEmpty) {
+ Iterable.empty
+ } else {
+ val (minReduceId, maxReduceId) = blockMetadata.asScala.map(block => block.getReduceId)
+ .foldLeft((Int.MaxValue, 0)) {
+ case ((min, max), elem) => (math.min(min, elem), math.max(max, elem))
+ }
+ val shuffleId = blockMetadata.asScala.head.getShuffleId
+ new ShuffleBlockFetcherIterable(
+ TaskContext.get(),
+ blockManager,
+ serializerManager,
+ maxBytesInFlight,
+ maxReqsInFlight,
+ maxBlocksInFlightPerAddress,
+ maxReqSizeShuffleToMem,
+ detectCorrupt,
+ shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics(),
+ minReduceId,
+ maxReduceId,
+ shuffleId,
+ mapOutputTracker
+ )
+ }
+ iterableToReturn.asJava
+ }
+}
+
+private class ShuffleBlockFetcherIterable(
+ context: TaskContext,
+ blockManager: BlockManager,
+ serializerManager: SerializerManager,
+ maxBytesInFlight: Long,
+ maxReqsInFlight: Int,
+ maxBlocksInFlightPerAddress: Int,
+ maxReqSizeShuffleToMem: Long,
+ detectCorruption: Boolean,
+ shuffleMetrics: ShuffleReadMetricsReporter,
+ minReduceId: Int,
+ maxReduceId: Int,
+ shuffleId: Int,
+ mapOutputTracker: MapOutputTracker) extends Iterable[InputStream] {
+
+ override def iterator: Iterator[InputStream] = {
+ new ShuffleBlockFetcherIterator(
+ context,
+ blockManager.shuffleClient,
+ blockManager,
+ mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, minReduceId, maxReduceId + 1)
+ .map { shuffleLocationInfo =>
+ val defaultShuffleLocation = shuffleLocationInfo._1
+ .get.asInstanceOf[DefaultMapShuffleLocations]
+ (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2)
+ },
+ serializerManager.wrapStream,
+ maxBytesInFlight,
+ maxReqsInFlight,
+ maxBlocksInFlightPerAddress,
+ maxReqSizeShuffleToMem,
+ detectCorruption,
+ shuffleMetrics).toCompletionIterator
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b59fa8e8a3ccd..947753f6b40e8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -19,9 +19,13 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap
+import scala.collection.JavaConverters._
+
import org.apache.spark._
-import org.apache.spark.internal.Logging
+import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents}
+import org.apache.spark.internal.{config, Logging}
import org.apache.spark.shuffle._
+import org.apache.spark.util.Utils
/**
* In sort-based shuffle, incoming records are sorted according to their target partition ids, then
@@ -68,6 +72,8 @@ import org.apache.spark.shuffle._
*/
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+ import SortShuffleManager._
+
if (!conf.getBoolean("spark.shuffle.spill", true)) {
logWarning(
"spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." +
@@ -79,6 +85,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
*/
private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
+ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
/**
@@ -118,7 +126,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
- startPartition, endPartition, context, metrics)
+ startPartition,
+ endPartition,
+ context,
+ metrics,
+ shuffleExecutorComponents.reads())
}
/** Get a writer for a given partition. Called on executors by map tasks. */
@@ -134,23 +146,24 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
context,
env.conf,
- metrics)
+ metrics,
+ shuffleExecutorComponents.writes())
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
bypassMergeSortHandle,
mapId,
env.conf,
- metrics)
+ metrics,
+ shuffleExecutorComponents.writes())
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
- new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+ new SortShuffleWriter(
+ shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes())
}
}
@@ -205,6 +218,21 @@ private[spark] object SortShuffleManager extends Logging {
true
}
}
+
+ private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
+ val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS)
+ val maybeIO = Utils.loadExtensions(
+ classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
+ require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
+ val executorComponents = maybeIO.head.executor()
+ val extraConfigs = conf.getAllWithPrefix(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX)
+ .toMap
+ executorComponents.initializeExecutor(
+ conf.getAppId,
+ SparkEnv.get.executorId,
+ extraConfigs.asJava)
+ executorComponents
+ }
}
/**
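On the executor side, loadShuffleExecutorComponents passes each plugin the driver-published entries with ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX already stripped by getAllWithPrefix, which closes the loop with the hypothetical driver components sketched earlier. Their executor counterpart might look like this (read and write support bodies omitted; names are illustrative):

    package com.example.shuffle;

    import java.util.Map;

    import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
    import org.apache.spark.api.shuffle.ShuffleReadSupport;
    import org.apache.spark.api.shuffle.ShuffleWriteSupport;

    public class MyShuffleExecutorComponents implements ShuffleExecutorComponents {

      private String metadataEndpoint;

      @Override
      public void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs) {
        // The same key the driver components returned from initializeApplication(), with the
        // spark-conf prefix already stripped.
        metadataEndpoint = extraConfigs.get("metadata.endpoint");
      }

      @Override
      public ShuffleWriteSupport writes() {
        // A real plugin returns its write support here, wired to metadataEndpoint.
        throw new UnsupportedOperationException("Write support omitted from this sketch");
      }

      @Override
      public ShuffleReadSupport reads() {
        // A real plugin returns something like the MyShuffleReadSupport sketched earlier.
        throw new UnsupportedOperationException("Read support omitted from this sketch");
      }
    }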
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 16058de8bf3ff..1fcae684b0052 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -18,18 +18,18 @@
package org.apache.spark.shuffle.sort
import org.apache.spark._
+import org.apache.spark.api.shuffle.ShuffleWriteSupport
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
-import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter
private[spark] class SortShuffleWriter[K, V, C](
shuffleBlockResolver: IndexShuffleBlockResolver,
handle: BaseShuffleHandle[K, V, C],
mapId: Int,
- context: TaskContext)
+ context: TaskContext,
+ writeSupport: ShuffleWriteSupport)
extends ShuffleWriter[K, V] with Logging {
private val dep = handle.dependency
@@ -64,18 +64,14 @@ private[spark] class SortShuffleWriter[K, V, C](
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
- val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
- val tmp = Utils.tempFileWith(output)
- try {
- val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
- val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
- shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
- } finally {
- if (tmp.exists() && !tmp.delete()) {
- logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
- }
- }
+ val mapOutputWriter = writeSupport.createMapOutputWriter(
+ dep.shuffleId, mapId, dep.partitioner.numPartitions)
+ val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+ val mapLocations = mapOutputWriter.commitAllPartitions()
+ mapStatus = MapStatus(
+ blockManager.shuffleServerId,
+ mapLocations.orNull(),
+ partitionLengths)
}
/** Close this writer, passing along whether the map completed */
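For orientation, the sort-based write path above no longer touches index/data files directly; everything goes through the plugin's map output writer. The sketch below is not part of the patch: it spells out the per-map-task call sequence a write-support plugin observes, using only the methods exercised in this diff (record serialization via openStream is elided, and the helper name is made up).

```scala
import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShuffleWriteSupport}

// Illustrative only: the per-map-task call sequence against the write-side plugin API.
def sketchWritePath(
    writeSupport: ShuffleWriteSupport,
    shuffleId: Int,
    mapId: Int,
    numPartitions: Int): Array[Long] = {
  // One ShuffleMapOutputWriter per map task attempt.
  val mapOutputWriter: ShuffleMapOutputWriter =
    writeSupport.createMapOutputWriter(shuffleId, mapId, numPartitions)
  // One ShufflePartitionWriter per reduce partition, requested in ascending order;
  // records would be serialized into partitionWriter.openStream here.
  val partitionLengths = (0 until numPartitions).map { reduceId =>
    val partitionWriter = mapOutputWriter.getPartitionWriter(reduceId)
    partitionWriter.getNumBytesWritten
  }.toArray
  // A single commit per map task; it may return MapShuffleLocations for the MapStatus
  // (SortShuffleWriter above passes commitAllPartitions().orNull() into MapStatus).
  mapOutputWriter.commitAllPartitions()
  partitionLengths
}
```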
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index d188bdd912e5e..97b99e08d9ca9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -132,12 +132,14 @@ private[spark] object BlockManagerId {
getCachedBlockManagerId(obj)
}
+ val blockManagerIdCacheSize = 10000
+
/**
* The max cache size is hardcoded to 10000, since the size of a BlockManagerId
* object is about 48B, the total memory cost should be below 1MB which is feasible.
*/
val blockManagerIdCache = CacheBuilder.newBuilder()
- .maximumSize(10000)
+ .maximumSize(blockManagerIdCacheSize)
.build(new CacheLoader[BlockManagerId, BlockManagerId]() {
override def load(id: BlockManagerId) = id
})
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 17390f9c60e79..f9f4e3594e4f9 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.PairsWriter
/**
* A class for writing JVM objects directly to a file on disk. This class allows data to be appended
@@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter(
writeMetrics: ShuffleWriteMetricsReporter,
val blockId: BlockId = null)
extends OutputStream
- with Logging {
+ with Logging
+ with PairsWriter {
/**
* Guards against close calls, e.g. from a wrapping stream.
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index c89d5cc971d2a..22fc4da97a5b2 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -76,7 +76,7 @@ final class ShuffleBlockFetcherIterator(
detectCorrupt: Boolean,
detectCorruptUseExtraMemory: Boolean,
shuffleMetrics: ShuffleReadMetricsReporter)
- extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
+ extends Iterator[InputStream] with DownloadFileManager with Logging {
import ShuffleBlockFetcherIterator._
@@ -399,7 +399,7 @@ final class ShuffleBlockFetcherIterator(
*
* Throws a FetchFailedException if the next block could not be fetched.
*/
- override def next(): (BlockId, InputStream) = {
+ override def next(): InputStream = {
if (!hasNext) {
throw new NoSuchElementException()
}
@@ -497,7 +497,6 @@ final class ShuffleBlockFetcherIterator(
in.close()
}
}
-
case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)
}
@@ -510,6 +509,7 @@ final class ShuffleBlockFetcherIterator(
throw new NoSuchElementException()
}
currentResult = result.asInstanceOf[SuccessFetchResult]
- (currentResult.blockId,
- new BufferReleasingInputStream(
- input,
- this,
- currentResult.blockId,
- currentResult.address,
- detectCorrupt && streamCompressedOrEncrypted))
+ new BufferReleasingInputStream(input, this)
+ }
+
+ // for testing only
+ def getCurrentBlock(): ShuffleBlockId = {
+ currentResult.blockId.asInstanceOf[ShuffleBlockId]
}
- def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
- CompletionIterator[(BlockId, InputStream), this.type](this,
+ def toCompletionIterator: Iterator[InputStream] = {
+ CompletionIterator[InputStream, this.type](this,
onCompleteCallback.onComplete(context))
}
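With the iterator now yielding bare InputStreams, callers no longer destructure (BlockId, InputStream) pairs. Below is a reduced sketch of the consuming side, assuming only the serializer API and ignoring the decryption/decompression wrapping the real reader applies; the helper name is hypothetical.

```scala
import java.io.InputStream

import org.apache.spark.serializer.SerializerInstance

// Simplified sketch: turn each fetched stream into key-value records. The real
// BlockStoreShuffleReader also wraps streams via the SerializerManager first.
def recordsFromStreams(
    streams: Iterator[InputStream],
    serializer: SerializerInstance): Iterator[(Any, Any)] = {
  streams.flatMap { stream =>
    serializer.deserializeStream(stream).asKeyValueIterator
  }
}
```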
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index bed50865e7be4..87063819d571a 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
import java.net._
import java.nio.ByteBuffer
-import java.nio.channels.{Channels, FileChannel}
+import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.security.SecureRandom
@@ -394,10 +394,14 @@ private[spark] object Utils extends Logging {
def copyFileStreamNIO(
input: FileChannel,
- output: FileChannel,
+ output: WritableByteChannel,
startPosition: Long,
bytesToCopy: Long): Unit = {
- val initialPos = output.position()
+ val outputInitialState = output match {
+ case outputFileChannel: FileChannel =>
+ Some((outputFileChannel.position(), outputFileChannel))
+ case _ => None
+ }
var count = 0L
// In case transferTo method transferred less data than we have required.
while (count < bytesToCopy) {
@@ -412,15 +416,17 @@ private[spark] object Utils extends Logging {
// kernel version 2.6.32, this issue can be seen in
// https://bugs.openjdk.java.net/browse/JDK-7052359
// This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
- val finalPos = output.position()
- val expectedPos = initialPos + bytesToCopy
- assert(finalPos == expectedPos,
- s"""
- |Current position $finalPos do not equal to expected position $expectedPos
- |after transferTo, please check your kernel version to see if it is 2.6.32,
- |this is a kernel bug which will lead to unexpected behavior when using transferTo.
- |You can set spark.file.transferTo = false to disable this NIO feature.
- """.stripMargin)
+ outputInitialState.foreach { case (initialPos, outputFileChannel) =>
+ val finalPos = outputFileChannel.position()
+ val expectedPos = initialPos + bytesToCopy
+ assert(finalPos == expectedPos,
+ s"""
+ |Current position $finalPos do not equal to expected position $expectedPos
+ |after transferTo, please check your kernel version to see if it is 2.6.32,
+ |this is a kernel bug which will lead to unexpected behavior when using transferTo.
+ |You can set spark.file.transferTo = false to disable this NIO feature.
+ """.stripMargin)
+ }
}
/**
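With the destination widened to WritableByteChannel, the transferTo position check above only runs when the target really is a FileChannel; for any other channel the copy loop is trusted as-is. A minimal usage sketch under that assumption (Spark-internal code, since the method is private[spark]; the helper and its arguments are placeholders):

```scala
import java.io.{File, FileInputStream, OutputStream}
import java.nio.channels.Channels

import org.apache.spark.util.Utils

// Sketch only: copy an entire local file into an arbitrary OutputStream-backed channel.
def copyFileToStream(src: File, out: OutputStream): Unit = {
  val inChannel = new FileInputStream(src).getChannel
  val outChannel = Channels.newChannel(out)
  try {
    Utils.copyFileStreamNIO(inChannel, outChannel, 0L, src.length())
  } finally {
    inChannel.close()
    outChannel.close()
  }
}
```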
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3f3b7d20eb169..13132c2801ed9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -26,10 +26,11 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.io.ByteStreams
import org.apache.spark._
+import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer._
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -670,11 +671,9 @@ private[spark] class ExternalSorter[K, V, C](
}
/**
- * Write all the data added into this ExternalSorter into a file in the disk store. This is
- * called by the SortShuffleWriter.
- *
- * @param blockId block ID to write to. The index file will be blockId.name + ".index".
- * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+ * TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project.
+ * We should figure out an alternative way to test that so that we can remove this otherwise
+ * unused code path.
*/
def writePartitionedFile(
blockId: BlockId,
@@ -718,6 +717,88 @@ private[spark] class ExternalSorter[K, V, C](
lengths
}
+ /**
+ * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
+ * to some arbitrary backing store. This is called by the SortShuffleWriter.
+ *
+ * @return array of lengths, in bytes, of each partition of the map output (used by map output tracker)
+ */
+ def writePartitionedMapOutput(
+ shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
+ // Track location of each range in the map output
+ val lengths = new Array[Long](numPartitions)
+ if (spills.isEmpty) {
+ // Case where we only have in-memory data
+ val collection = if (aggregator.isDefined) map else buffer
+ val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+ while (it.hasNext()) {
+ val partitionId = it.nextPartition()
+ var partitionWriter: ShufflePartitionWriter = null
+ var partitionPairsWriter: ShufflePartitionPairsWriter = null
+ try {
+ partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
+ val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
+ partitionPairsWriter = new ShufflePartitionPairsWriter(
+ partitionWriter,
+ serializerManager,
+ serInstance,
+ blockId,
+ context.taskMetrics().shuffleWriteMetrics)
+ while (it.hasNext && it.nextPartition() == partitionId) {
+ it.writeNext(partitionPairsWriter)
+ }
+ } finally {
+ if (partitionPairsWriter != null) {
+ partitionPairsWriter.close()
+ }
+ }
+ if (partitionWriter != null) {
+ lengths(partitionId) = partitionWriter.getNumBytesWritten
+ }
+ }
+ } else {
+ // We must perform merge-sort; get an iterator by partition and write everything directly.
+ for ((id, elements) <- this.partitionedIterator) {
+ // The contract for the plugin is that we ask the map output writer for a partition
+ // writer for every partition, even if that partition turns out to be empty.
+ // partitionedIterator yields every partition id in ascending order, so the lengths
+ // array is filled in partition order and empty partitions still get their writer
+ // requested (and closed) before we move on to the next one.
+ val blockId = ShuffleBlockId(shuffleId, mapId, id)
+ var partitionWriter: ShufflePartitionWriter = null
+ var partitionPairsWriter: ShufflePartitionPairsWriter = null
+ try {
+ partitionWriter = mapOutputWriter.getPartitionWriter(id)
+ partitionPairsWriter = new ShufflePartitionPairsWriter(
+ partitionWriter,
+ serializerManager,
+ serInstance,
+ blockId,
+ context.taskMetrics().shuffleWriteMetrics)
+ if (elements.hasNext) {
+ for (elem <- elements) {
+ partitionPairsWriter.write(elem._1, elem._2)
+ }
+ }
+ } finally {
+ if (partitionPairsWriter != null) {
+ partitionPairsWriter.close()
+ }
+ }
+ if (partitionWriter != null) {
+ lengths(id) = partitionWriter.getNumBytesWritten
+ }
+ }
+ }
+
+ context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+ context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
+
+ lengths
+ }
+
def stop(): Unit = {
spills.foreach(s => s.file.delete())
spills.clear()
@@ -781,7 +862,7 @@ private[spark] class ExternalSorter[K, V, C](
val inMemoryIterator = new WritablePartitionedIterator {
private[this] var cur = if (upstream.hasNext) upstream.next() else null
- def writeNext(writer: DiskBlockObjectWriter): Unit = {
+ def writeNext(writer: PairsWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (upstream.hasNext) upstream.next() else null
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala
new file mode 100644
index 0000000000000..9d7c209f242e1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+private[spark] trait PairsWriter {
+
+ def write(key: Any, value: Any): Unit
+}
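To make the new abstraction concrete: anything that accepts key-value pairs can now sit behind WritablePartitionedIterator.writeNext, whether that is a DiskBlockObjectWriter or the plugin-backed ShufflePartitionPairsWriter below. A toy implementation, not part of the patch:

```scala
package org.apache.spark.util.collection // needed to see the private[spark] trait

// Toy example: a PairsWriter that just counts the pairs handed to it.
private[spark] class CountingPairsWriter extends PairsWriter {
  private var count = 0L

  override def write(key: Any, value: Any): Unit = {
    count += 1
  }

  def pairsWritten: Long = count
}
```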
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala
new file mode 100644
index 0000000000000..8538a78b377c8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.{Closeable, FilterOutputStream, OutputStream}
+
+import org.apache.spark.api.shuffle.ShufflePartitionWriter
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
+import org.apache.spark.storage.BlockId
+
+/**
+ * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an
+ * arbitrary partition writer instead of writing to local disk through the block manager.
+ */
+private[spark] class ShufflePartitionPairsWriter(
+ partitionWriter: ShufflePartitionWriter,
+ serializerManager: SerializerManager,
+ serializerInstance: SerializerInstance,
+ blockId: BlockId,
+ writeMetrics: ShuffleWriteMetricsReporter)
+ extends PairsWriter with Closeable {
+
+ private var isOpen = false
+ private var partitionStream: OutputStream = _
+ private var wrappedStream: OutputStream = _
+ private var objOut: SerializationStream = _
+ private var numRecordsWritten = 0
+ private var curNumBytesWritten = 0L
+
+ override def write(key: Any, value: Any): Unit = {
+ if (!isOpen) {
+ open()
+ isOpen = true
+ }
+ objOut.writeKey(key)
+ objOut.writeValue(value)
+ recordWritten()
+ }
+
+ private def open(): Unit = {
+ partitionStream = partitionWriter.openStream
+ wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
+ objOut = serializerInstance.serializeStream(wrappedStream)
+ }
+
+ override def close(): Unit = {
+ if (isOpen) {
+ objOut.close()
+ objOut = null
+ wrappedStream = null
+ partitionStream = null
+ isOpen = false
+ updateBytesWritten()
+ }
+ }
+
+ /**
+ * Notify the writer that a record worth of bytes has been written with OutputStream#write.
+ */
+ private def recordWritten(): Unit = {
+ numRecordsWritten += 1
+ writeMetrics.incRecordsWritten(1)
+
+ if (numRecordsWritten % 16384 == 0) {
+ updateBytesWritten()
+ }
+ }
+
+ private def updateBytesWritten(): Unit = {
+ val numBytesWritten = partitionWriter.getNumBytesWritten
+ val bytesWrittenDiff = numBytesWritten - curNumBytesWritten
+ writeMetrics.incBytesWritten(bytesWrittenDiff)
+ curNumBytesWritten = numBytesWritten
+ }
+}
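A reduced usage sketch (not part of the patch) of wiring this writer around a plugin-provided partition writer, mirroring ExternalSorter.writePartitionedMapOutput above. It assumes it runs as Spark-internal code inside a task, and for brevity takes the serializer from SparkEnv rather than from the shuffle dependency as the sorter does; the helper name is made up.

```scala
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.shuffle.ShufflePartitionWriter
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ShufflePartitionPairsWriter

// Sketch only: serialize a partition's pairs through a plugin-backed partition writer.
// Assumes TaskContext.get() is non-null, i.e. we are inside a running task.
def writePartitionPairs(
    partitionWriter: ShufflePartitionWriter,
    blockId: ShuffleBlockId,
    pairs: Iterator[(Any, Any)]): Long = {
  val env = SparkEnv.get
  val pairsWriter = new ShufflePartitionPairsWriter(
    partitionWriter,
    env.serializerManager,
    env.serializer.newInstance(),
    blockId,
    TaskContext.get().taskMetrics().shuffleWriteMetrics)
  try {
    pairs.foreach { case (k, v) => pairsWriter.write(k, v) }
  } finally {
    pairsWriter.close()
  }
  // Bytes are pulled from the underlying partition writer, as in updateBytesWritten above.
  partitionWriter.getNumBytesWritten
}
```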
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index dd7f68fd038d2..da8d58d05b6b9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null
- def writeNext(writer: DiskBlockObjectWriter): Unit = {
+ def writeNext(writer: PairsWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}
@@ -89,7 +89,7 @@ private[spark] object WritablePartitionedPairCollection {
* has an associated partition.
*/
private[spark] trait WritablePartitionedIterator {
- def writeNext(writer: DiskBlockObjectWriter): Unit
+ def writeNext(writer: PairsWriter): Unit
def hasNext(): Boolean
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 88125a6b93ade..3c172a027ca0f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -19,8 +19,10 @@
import java.io.*;
import java.nio.ByteBuffer;
+import java.nio.file.Files;
import java.util.*;
+import org.mockito.stubbing.Answer;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
@@ -39,6 +41,7 @@
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
+import org.apache.spark.TaskContext$;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.io.CompressionCodec$;
@@ -53,6 +56,7 @@
import org.apache.spark.security.CryptoStreamUtils;
import org.apache.spark.serializer.*;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -65,6 +69,7 @@
public class UnsafeShuffleWriterSuite {
+ static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
static final int NUM_PARTITITONS = 4;
TestMemoryManager memoryManager;
TaskMemoryManager taskMemoryManager;
@@ -85,6 +90,7 @@ public class UnsafeShuffleWriterSuite {
@After
public void tearDown() {
+ TaskContext$.MODULE$.unset();
Utils.deleteRecursively(tempDir);
final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
if (leakedMemory != 0) {
@@ -132,14 +138,28 @@ public void setUp() throws IOException {
});
when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
- doAnswer(invocationOnMock -> {
+
+ Answer<?> renameTempAnswer = invocationOnMock -> {
partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
File tmp = (File) invocationOnMock.getArguments()[3];
- mergedOutputFile.delete();
- tmp.renameTo(mergedOutputFile);
+ if (!mergedOutputFile.delete()) {
+ throw new RuntimeException("Failed to delete old merged output file.");
+ }
+ if (tmp != null) {
+ Files.move(tmp.toPath(), mergedOutputFile.toPath());
+ } else if (!mergedOutputFile.createNewFile()) {
+ throw new RuntimeException("Failed to create empty merged output file.");
+ }
return null;
- }).when(shuffleBlockResolver)
- .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
+ };
+
+ doAnswer(renameTempAnswer)
+ .when(shuffleBlockResolver)
+ .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
+
+ doAnswer(renameTempAnswer)
+ .when(shuffleBlockResolver)
+ .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), eq(null));
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
@@ -151,6 +171,11 @@ public void setUp() throws IOException {
when(taskContext.taskMetrics()).thenReturn(taskMetrics);
when(shuffleDep.serializer()).thenReturn(serializer);
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
+ when(blockManager.shuffleServerId()).thenReturn(BlockManagerId.apply(
+ "0", "localhost", 9099, Option.empty()));
+
+ TaskContext$.MODULE$.setTaskContext(taskContext);
}
private UnsafeShuffleWriter<Object, Object> createWriter(
@@ -158,14 +183,13 @@ private UnsafeShuffleWriter<Object, Object> createWriter(
conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
return new UnsafeShuffleWriter<>(
blockManager,
- shuffleBlockResolver,
- taskMemoryManager,
+ taskMemoryManager,
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics()
- );
+ taskContext.taskMetrics().shuffleWriteMetrics(),
+ new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId()));
}
private void assertSpillFilesWereCleanedUp() {
@@ -444,10 +468,10 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() thro
}
private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
- memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
+ memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
- for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
+ for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
dataToWrite.add(new Tuple2<>(i, i));
}
writer.write(dataToWrite.iterator());
@@ -519,13 +543,13 @@ public void testPeakMemoryUsed() throws Exception {
final UnsafeShuffleWriter<Object, Object> writer =
new UnsafeShuffleWriter<>(
blockManager,
- shuffleBlockResolver,
- taskMemoryManager,
+ taskMemoryManager,
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics());
+ taskContext.taskMetrics().shuffleWriteMetrics(),
+ new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId()));
// Peak memory should be monotonically increasing. More specifically, every time
// we allocate a new page it should increase by exactly the size of the page.
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index 62824a5bec9d1..28cbeeda7a88d 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -210,7 +210,8 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
/**
* A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
*/
- private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
+ private class SaveAccumContextCleaner(sc: SparkContext) extends
+ ContextCleaner(sc, null) {
private val accumsRegistered = new ArrayBuffer[Long]
override def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index d86975964b558..8fcbc845d1a7b 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MA
import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
class MapOutputTrackerSuite extends SparkFunSuite {
@@ -67,10 +68,13 @@ class MapOutputTrackerSuite extends SparkFunSuite {
Array(1000L, 10000L)))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(10000L, 1000L)))
- val statuses = tracker.getMapSizesByExecutorId(10, 0)
+ val statuses = tracker.getMapSizesByShuffleLocation(10, 0)
assert(statuses.toSet ===
- Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))),
- (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000))))
+ Seq(
+ (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))),
+ ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))),
+ (Some(DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000))),
+ ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000))))
.toSet)
assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.stop()
@@ -90,11 +94,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
assert(tracker.containsShuffle(10))
- assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
+ assert(tracker.getMapSizesByShuffleLocation(10, 0).nonEmpty)
assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
- assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
+ assert(tracker.getMapSizesByShuffleLocation(10, 0).isEmpty)
tracker.stop()
rpcEnv.shutdown()
@@ -121,7 +125,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
- intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) }
+ intercept[FetchFailedException] { tracker.getMapSizesByShuffleLocation(10, 1) }
tracker.stop()
rpcEnv.shutdown()
@@ -143,24 +147,26 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerShuffle(10, 1)
slaveTracker.updateEpoch(masterTracker.getEpoch)
// This is expected to fail because no outputs have been registered for the shuffle.
- intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) }
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("a", "hostA", 1000), Array(1000L)))
slaveTracker.updateEpoch(masterTracker.getEpoch)
- assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
+ assert(slaveTracker.getMapSizesByShuffleLocation(10, 0).toSeq ===
+ Seq(
+ (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))),
+ ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
assert(0 == masterTracker.getNumCachedSerializedBroadcast)
val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput)
slaveTracker.updateEpoch(masterTracker.getEpoch)
- intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) }
// failure should be cached
- intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) }
assert(0 == masterTracker.getNumCachedSerializedBroadcast)
masterTracker.stop()
@@ -261,8 +267,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// being sent.
masterTracker.registerShuffle(20, 100)
(0 until 100).foreach { i =>
+ val bmId = BlockManagerId("999", "mps", 1000)
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
- BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
+ bmId,
+ DefaultMapShuffleLocations.get(bmId),
+ Array.fill[Long](4000000)(0)))
}
val senderAddress = RpcAddress("localhost", 12345)
val rpcCallContext = mock(classOf[RpcCallContext])
@@ -315,12 +324,13 @@ class MapOutputTrackerSuite extends SparkFunSuite {
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(size10000, size0, size1000, size0)))
assert(tracker.containsShuffle(10))
- assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq ===
+ assert(tracker.getMapSizesByShuffleLocation(10, 0, 4)
+ .map(x => (x._1.get, x._2)).toSeq ===
Seq(
- (BlockManagerId("a", "hostA", 1000),
- Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))),
- (BlockManagerId("b", "hostB", 1000),
- Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000)))
+ (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)),
+ Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))),
+ (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)),
+ Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000)))
)
)
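The shape returned by getMapSizesByShuffleLocation throughout these tests is an iterator of (Option[ShuffleLocation], Seq[(BlockId, Long)]) pairs (see the mocked answer in BlockStoreShuffleReaderSuite below). A small sketch of consuming that shape, with a hypothetical helper; a None location just means no driver-tracked location is attached to those blocks.

```scala
import org.apache.spark.api.shuffle.ShuffleLocation
import org.apache.spark.storage.BlockId

// Sketch only: total bytes to fetch, grouped by optional shuffle location.
def bytesByLocation(
    statuses: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])])
  : Map[Option[ShuffleLocation], Long] = {
  statuses.toSeq
    .groupBy(_._1)
    .map { case (location, entries) =>
      location -> entries.flatMap(_._2).map(_._2).sum
    }
}
```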
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 8b1084a8edc76..1d2713151f505 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId}
-import org.apache.spark.util.{MutablePair, Utils}
+import org.apache.spark.util.MutablePair
abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext {
@@ -73,7 +73,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// All blocks must have non-zero size
(0 until NUM_BLOCKS).foreach { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id)
assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0)))
}
}
@@ -112,7 +112,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id)
statuses.flatMap(_._2.map(_._2))
}
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
@@ -137,7 +137,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
- val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id)
+ val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id)
statuses.flatMap(_._2.map(_._2))
}
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
@@ -368,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)
val writer1 = manager.getWriter[Int, Int](
shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics)
- val data1 = (1 to 10).map { x => x -> x}
+ val data1 = (1 to 10).map { x => x -> x }
// second attempt -- also successful. We'll write out different data,
// just to simulate the fact that the records may get written differently
@@ -383,13 +383,17 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// simultaneously, and everything is still OK
def writeAndClose(
- writer: ShuffleWriter[Int, Int])(
+ writer: ShuffleWriter[Int, Int],
+ taskContext: TaskContext)(
iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+ TaskContext.setTaskContext(taskContext)
val files = writer.write(iter)
- writer.stop(true)
+ val status = writer.stop(true)
+ TaskContext.unset()
+ status
}
val interleaver = new InterleaveIterators(
- data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+ data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2))
val (mapOutput1, mapOutput2) = interleaver.run()
// check that we can read the map output and it has the right data
@@ -405,12 +409,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
val taskContext = new TaskContextImpl(
1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)
+ TaskContext.setTaskContext(taskContext)
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics)
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
manager.unregisterShuffle(0)
+ TaskContext.unset()
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index d58ee4e651e19..51a65cc2a6ba8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -29,12 +29,14 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.api.shuffle.MapShuffleLocations
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.executor.ExecutorMetrics
import org.apache.spark.internal.config
import org.apache.spark.rdd.{DeterministicLevel, RDD}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException}
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils}
@@ -713,8 +715,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB")))
complete(taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty()
@@ -740,8 +742,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// have the 2nd attempt pass
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length))))
// we can see both result blocks now
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
- HashSet("hostA", "hostB"))
+ assert(mapOutputTracker
+ .getMapSizesByShuffleLocation(shuffleId, 0)
+ .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host)
+ .toSet === HashSet("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
assertDataStructuresEmpty()
@@ -779,11 +783,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(ExecutorLost("exec-hostA", event))
if (expectFileLoss) {
intercept[MetadataFetchFailedException] {
- mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0)
+ mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0)
}
} else {
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB")))
}
}
}
@@ -1076,8 +1080,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
(Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
(Success, makeMapStatus("hostB", reduceRdd.partitions.length))))
// The MapOutputTracker should know about both map output locations.
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
- HashSet("hostA", "hostB"))
+ assert(mapOutputTracker
+ .getMapSizesByShuffleLocation(shuffleId, 0)
+ .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host)
+ .toSet === HashSet("hostA", "hostB"))
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(makeCompletionEvent(
@@ -1206,10 +1212,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// The MapOutputTracker should know about both map output locations.
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet ===
- HashSet("hostA", "hostB"))
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet ===
- HashSet("hostA", "hostB"))
+ assert(mapOutputTracker
+ .getMapSizesByShuffleLocation(shuffleId, 0)
+ .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host)
+ .toSet === HashSet("hostA", "hostB"))
+ assert(mapOutputTracker
+ .getMapSizesByShuffleLocation(shuffleId, 1)
+ .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host)
+ .toSet === HashSet("hostA", "hostB"))
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(makeCompletionEvent(
@@ -1399,8 +1409,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
Success,
makeMapStatus("hostA", reduceRdd.partitions.size)))
assert(shuffleStage.numAvailableOutputs === 2)
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostA")))
// finish the next stage normally, which completes the job
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
@@ -1554,7 +1564,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
reduceIdx <- reduceIdxs
} {
// this would throw an exception if the map status hadn't been registered
- val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx)
+ val statuses = mapOutputTracker.getMapSizesByShuffleLocation(stage, reduceIdx)
// really we should have already thrown an exception rather than fail either of these
// asserts, but just to be extra defensive let's double check the statuses are OK
assert(statuses != null)
@@ -1606,7 +1616,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// check that we have all the map output for stage 0
(0 until reduceRdd.partitions.length).foreach { reduceIdx =>
- val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx)
+ val statuses = mapOutputTracker.getMapSizesByShuffleLocation(0, reduceIdx)
// really we should have already thrown an exception rather than fail either of these
// asserts, but just to be extra defensive let's double check the statuses are OK
assert(statuses != null)
@@ -1805,8 +1815,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// have hostC complete the resubmitted task
complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB")))
// Make sure that the reduce stage was now submitted.
assert(taskSets.size === 3)
@@ -2068,8 +2078,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
submit(reduceRdd, Array(0))
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1))))
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostA")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostA")))
// Reducer should run on the same host that map task ran
val reduceTaskSet = taskSets(1)
@@ -2114,8 +2124,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
submit(reduceRdd, Array(0))
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1))))
- assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostA")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostA")))
// Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep
val reduceTaskSet = taskSets(1)
@@ -2278,8 +2288,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", rdd1.partitions.length)),
(Success, makeMapStatus("hostB", rdd1.partitions.length))))
- assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB")))
assert(listener1.results.size === 1)
// When attempting the second stage, show a fetch failure
@@ -2294,8 +2304,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(taskSets(2).stageId === 0)
complete(taskSets(2), Seq(
(Success, makeMapStatus("hostC", rdd2.partitions.length))))
- assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB")))
+
assert(listener2.results.size === 0) // Second stage listener should still not have a result
// Stage 1 should now be running as task set 3; make its first task succeed
@@ -2303,8 +2314,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
complete(taskSets(3), Seq(
(Success, makeMapStatus("hostB", rdd2.partitions.length)),
(Success, makeMapStatus("hostD", rdd2.partitions.length))))
- assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(dep2.shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostD")))
assert(listener2.results.size === 1)
// Finally, the reduce job should be running as task set 4; make it see a fetch failure,
@@ -2342,8 +2353,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", rdd1.partitions.length)),
(Success, makeMapStatus("hostB", rdd1.partitions.length))))
- assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet ===
- HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB")))
assert(listener1.results.size === 1)
// When attempting stage1, trigger a fetch failure.
@@ -2368,8 +2379,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(taskSets(2).stageId === 0)
complete(taskSets(2), Seq(
(Success, makeMapStatus("hostC", rdd2.partitions.length))))
- assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet ===
- Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet ===
+ Set(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB")))
// After stage0 is finished, stage1 will be submitted and found there is no missing
// partitions in it. Then listener got triggered.
@@ -2982,6 +2993,14 @@ object DAGSchedulerSuite {
def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345)
+
+ def makeShuffleLocation(host: String): MapShuffleLocations = {
+ DefaultMapShuffleLocations.get(makeBlockManagerId(host))
+ }
+
+ def makeMaybeShuffleLocation(host: String): Option[MapShuffleLocations] = {
+ Some(DefaultMapShuffleLocations.get(makeBlockManagerId(host)))
+ }
}
object FailThisAttempt {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index c1e7fb9a1db16..3c786c0927bc6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite}
import org.apache.spark.LocalSparkContext._
import org.apache.spark.internal.config
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.BlockManagerId
class MapStatusSuite extends SparkFunSuite {
@@ -61,7 +62,11 @@ class MapStatusSuite extends SparkFunSuite {
stddev <- Seq(0.0, 0.01, 0.5, 1.0)
) {
val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean)
- val status = MapStatus(BlockManagerId("a", "b", 10), sizes)
+ val bmId = BlockManagerId("a", "b", 10)
+ val status = MapStatus(
+ bmId,
+ DefaultMapShuffleLocations.get(bmId),
+ sizes)
val status1 = compressAndDecompressMapStatus(status)
for (i <- 0 until numSizes) {
if (sizes(i) != 0) {
@@ -75,7 +80,7 @@ class MapStatusSuite extends SparkFunSuite {
test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
val sizes = Array.fill[Long](2001)(150L)
- val status = MapStatus(null, sizes)
+ val status = MapStatus(null, null, sizes)
assert(status.isInstanceOf[HighlyCompressedMapStatus])
assert(status.getSizeForBlock(10) === 150L)
assert(status.getSizeForBlock(50) === 150L)
@@ -86,11 +91,13 @@ class MapStatusSuite extends SparkFunSuite {
test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") {
val sizes = Array.tabulate[Long](3000) { i => i.toLong }
val avg = sizes.sum / sizes.count(_ != 0)
- val loc = BlockManagerId("a", "b", 10)
- val status = MapStatus(loc, sizes)
+ val bmId = BlockManagerId("a", "b", 10)
+ val loc = DefaultMapShuffleLocations.get(bmId)
+ val status = MapStatus(bmId, loc, sizes)
val status1 = compressAndDecompressMapStatus(status)
assert(status1.isInstanceOf[HighlyCompressedMapStatus])
- assert(status1.location == loc)
+ assert(status1.location == loc.getBlockManagerId)
+ assert(status1.mapShuffleLocations == loc)
for (i <- 0 until 3000) {
val estimate = status1.getSizeForBlock(i)
if (sizes(i) > 0) {
@@ -108,11 +115,13 @@ class MapStatusSuite extends SparkFunSuite {
val sizes = (0L to 3000L).toArray
val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold)
val avg = smallBlockSizes.sum / smallBlockSizes.length
- val loc = BlockManagerId("a", "b", 10)
- val status = MapStatus(loc, sizes)
+ val bmId = BlockManagerId("a", "b", 10)
+ val loc = DefaultMapShuffleLocations.get(bmId)
+ val status = MapStatus(bmId, loc, sizes)
val status1 = compressAndDecompressMapStatus(status)
assert(status1.isInstanceOf[HighlyCompressedMapStatus])
- assert(status1.location == loc)
+ assert(status1.location === bmId)
+ assert(status1.mapShuffleLocations === loc)
for (i <- 0 until threshold) {
val estimate = status1.getSizeForBlock(i)
if (sizes(i) > 0) {
@@ -165,7 +174,8 @@ class MapStatusSuite extends SparkFunSuite {
SparkEnv.set(env)
// Value of element in sizes is equal to the corresponding index.
val sizes = (0L to 2000L).toArray
- val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes)
+ val bmId = BlockManagerId("exec-0", "host-0", 100)
+ val status1 = MapStatus(bmId, DefaultMapShuffleLocations.get(bmId), sizes)
val arrayStream = new ByteArrayOutputStream(102400)
val objectOutputStream = new ObjectOutputStream(arrayStream)
assert(status1.isInstanceOf[HighlyCompressedMapStatus])
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
index 577d77e890d78..ec828db9391da 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
@@ -191,7 +191,8 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa
shuffleId <- shuffleIds
reduceIdx <- (0 until nParts)
} {
- val statuses = taskScheduler.mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceIdx)
+ val statuses = taskScheduler.mapOutputTracker.getMapSizesByShuffleLocation(
+ shuffleId, reduceIdx)
// really we should have already thrown an exception rather than fail either of these
// asserts, but just to be extra defensive let's double check the statuses are OK
assert(statuses != null)
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 2442670b6d3f0..63f6942dee184 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -36,8 +36,9 @@ import org.apache.spark.internal.config._
import org.apache.spark.internal.config.Kryo._
import org.apache.spark.scheduler.HighlyCompressedMapStatus
import org.apache.spark.serializer.KryoTest._
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.ThreadUtils
class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
@@ -350,8 +351,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
val ser = new KryoSerializer(conf).newInstance()
val denseBlockSizes = new Array[Long](5000)
val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L)
+ val bmId = BlockManagerId("exec-1", "host", 1234)
Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes =>
- ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes))
+ ser.serialize(HighlyCompressedMapStatus(
+ bmId, DefaultMapShuffleLocations.get(bmId), blockSizes))
}
}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 6d2ef17a7a790..6468914bf3185 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -21,12 +21,19 @@ import java.io.{ByteArrayOutputStream, InputStream}
import java.nio.ByteBuffer
import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
import org.apache.spark._
+import org.apache.spark.api.shuffle.ShuffleLocation
import org.apache.spark.internal.config
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
+import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.BlockId
/**
* Wrapper for a managed buffer that keeps track of how many times retain and release are called.
@@ -78,11 +85,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// Create a buffer with some randomly generated key-value pairs to use as the shuffle data
// from each mappers (all mappers return the same shuffle data).
val byteOutputStream = new ByteArrayOutputStream()
- val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
+ val compressionCodec = CompressionCodec.createCodec(testConf)
+ val compressedOutputStream = compressionCodec.compressedOutputStream(byteOutputStream)
+ val serializationStream = serializer.newInstance().serializeStream(compressedOutputStream)
(0 until keyValuePairsPerMap).foreach { i =>
serializationStream.writeKey(i)
serializationStream.writeValue(2*i)
}
+ compressedOutputStream.close()
// Setup the mocked BlockManager to return RecordingManagedBuffers.
val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
@@ -101,16 +111,20 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// Make a mocked MapOutputTracker for the shuffle reader to use to determine what
// shuffle data to read.
- val mapOutputTracker = mock(classOf[MapOutputTracker])
- when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn {
- // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
- // for the code to read data over the network.
- val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
- val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
- (shuffleBlockId, byteOutputStream.size().toLong)
- }
- Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator
+ val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
+ val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+ (shuffleBlockId, byteOutputStream.size().toLong)
}
+ val blocksToRetrieve = Seq(
+ (Option.apply(DefaultMapShuffleLocations.get(localBlockManagerId)), shuffleBlockIdsAndSizes))
+ val mapOutputTracker = mock(classOf[MapOutputTracker])
+ when(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1))
+ .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] {
+ def answer(invocationOnMock: InvocationOnMock):
+ Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
+ blocksToRetrieve.iterator
+ }
+ })
// Create a mocked shuffle handle to pass into HashShuffleReader.
val shuffleHandle = {
@@ -124,19 +138,23 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
val serializerManager = new SerializerManager(
serializer,
new SparkConf()
- .set(config.SHUFFLE_COMPRESS, false)
+ .set(config.SHUFFLE_COMPRESS, true)
.set(config.SHUFFLE_SPILL_COMPRESS, false))
val taskContext = TaskContext.empty()
+ TaskContext.setTaskContext(taskContext)
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
+
+ val shuffleReadSupport =
+ new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, testConf)
val shuffleReader = new BlockStoreShuffleReader(
shuffleHandle,
reduceId,
reduceId + 1,
taskContext,
metrics,
+ shuffleReadSupport,
serializerManager,
- blockManager,
mapOutputTracker)
assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
@@ -147,5 +165,6 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
assert(buffer.callsToRetain === 1)
assert(buffer.callsToRelease === 1)
}
+ TaskContext.unset()
}
}
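An illustrative sketch (not part of the diff) of the map-output shape the stubbed getMapSizesByShuffleLocation above hands back to the reader; the block sizes are made up, and the DefaultMapShuffleLocations import assumes the package it is given elsewhere in this change.

import org.apache.spark.api.shuffle.ShuffleLocation
import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

// Each entry pairs an optional ShuffleLocation with the shuffle blocks (and their sizes in
// bytes) fetchable from it; a None location would mean the plugin tracks block placement
// itself rather than through the driver's MapOutputTracker.
val localId = BlockManagerId("test-client", "test-client", 1)
val blocksByLocation: Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] =
  Iterator(
    (Some(DefaultMapShuffleLocations.get(localId)),
      Seq(
        (ShuffleBlockId(shuffleId = 0, mapId = 0, reduceId = 0), 1024L),
        (ShuffleBlockId(shuffleId = 0, mapId = 1, reduceId = 0), 2048L))))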
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala
new file mode 100644
index 0000000000000..e8372c0458600
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.util
+
+import com.google.common.collect.ImmutableMap
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport}
+import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS
+import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
+
+class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext {
+ test("test serialization of shuffle initialization conf to executors") {
+ val testConf = new SparkConf()
+ .setAppName("testing")
+ .setMaster("local-cluster[2,1,1024]")
+ .set(SHUFFLE_IO_PLUGIN_CLASS, "org.apache.spark.shuffle.TestShuffleDataIO")
+
+ sc = new SparkContext(testConf)
+
+ sc.parallelize(Seq((1, "one"), (2, "two"), (3, "three")), 3)
+ .groupByKey()
+ .collect()
+ }
+}
+
+class TestShuffleDriverComponents extends ShuffleDriverComponents {
+ override def initializeApplication(): util.Map[String, String] =
+ ImmutableMap.of("test-key", "test-value")
+
+ override def cleanupApplication(): Unit = {}
+
+ override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = {}
+}
+
+class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO {
+ override def driver(): ShuffleDriverComponents = new TestShuffleDriverComponents()
+
+ override def executor(): ShuffleExecutorComponents =
+ new TestShuffleExecutorComponents(sparkConf)
+}
+
+class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents {
+ override def initializeExecutor(appId: String, execId: String,
+ extraConfigs: util.Map[String, String]): Unit = {
+ assert(extraConfigs.get("test-key") == "test-value")
+ }
+
+ override def writes(): ShuffleWriteSupport = {
+ val blockManager = SparkEnv.get.blockManager
+ val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager)
+ new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId)
+ }
+
+ override def reads(): ShuffleReadSupport = {
+ val blockManager = SparkEnv.get.blockManager
+ val mapOutputTracker = SparkEnv.get.mapOutputTracker
+ val serializerManager = SparkEnv.get.serializerManager
+ new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf)
+ }
+}
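A hedged usage sketch of wiring a custom ShuffleDataIO plugin through the new config key, mirroring the suite above; the local-cluster master and the groupByKey job only exist to force a shuffle, and setting the ConfigEntry directly assumes the calling code lives under org.apache.spark as the suite does.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS

val conf = new SparkConf()
  .setAppName("shuffle-plugin-demo")
  .setMaster("local-cluster[2,1,1024]")
  // Fully qualified name of a ShuffleDataIO implementation with a (SparkConf) constructor,
  // e.g. the TestShuffleDataIO defined in the suite above.
  .set(SHUFFLE_IO_PLUGIN_CLASS, "org.apache.spark.shuffle.TestShuffleDataIO")

val sc = new SparkContext(conf)
try {
  // Any shuffle-producing job exercises both the driver components (initializeApplication)
  // and the executor components (initializeExecutor sees the driver's extra configs).
  sc.parallelize(Seq((1, "one"), (2, "two")), 2).groupByKey().collect()
} finally {
  sc.stop()
}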
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala
new file mode 100644
index 0000000000000..4f5bb264170de
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle.sort
+
+import java.io.{File, FileOutputStream}
+
+import com.google.common.io.CountingOutputStream
+import org.apache.commons.io.FileUtils
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.when
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import scala.util.Random
+
+import org.apache.spark.{Aggregator, MapOutputTracker, ShuffleDependency, SparkConf, SparkEnv, TaskContext}
+import org.apache.spark.api.shuffle.ShuffleLocation
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.metrics.source.Source
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf}
+import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
+import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, FetchFailedException}
+import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockManagerMaster, ShuffleBlockId}
+import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener, Utils}
+
+/**
+ * Benchmark to measure the performance of BlockStoreShuffleReader when reading shuffle data.
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt: bin/spark-submit --class
+ * 2. build/sbt "sql/test:runMain "
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain "
+ * Results will be written to "benchmarks/-results.txt".
+ * }}}
+ */
+object BlockStoreShuffleReaderBenchmark extends BenchmarkBase {
+
+ // this is only used to retrieve the aggregator/sorters/serializers,
+ // so it shouldn't affect the performance significantly
+ @Mock(answer = RETURNS_SMART_NULLS) private var dependency:
+ ShuffleDependency[String, String, String] = _
+ // only used to retrieve info about the maps at the beginning, doesn't affect perf
+ @Mock(answer = RETURNS_SMART_NULLS) private var mapOutputTracker: MapOutputTracker = _
+ // this is only used when initializing the BlockManager, so doesn't affect perf
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockManagerMaster: BlockManagerMaster = _
+ // this is only used when initializing the BlockManager, for comms between master and executor
+ @Mock(answer = RETURNS_SMART_NULLS) private var rpcEnv: RpcEnv = _
+ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _
+
+ private var tempDir: File = _
+
+ private val NUM_MAPS = 5
+ private val DEFAULT_DATA_STRING_SIZE = 5
+ private val TEST_DATA_SIZE = 10000000
+ private val SMALLER_DATA_SIZE = 2000000
+ private val MIN_NUM_ITERS = 10
+
+ private val executorId = "0"
+ private val localPort = 17000
+ private val remotePort = 17002
+
+ private val defaultConf = new SparkConf()
+ .set("spark.shuffle.compress", "false")
+ .set("spark.shuffle.spill.compress", "false")
+ .set("spark.authenticate", "false")
+ .set("spark.app.id", "test-app")
+ private val serializer = new KryoSerializer(defaultConf)
+ private val serializerManager = new SerializerManager(serializer, defaultConf)
+ private val execBlockManagerId = BlockManagerId(executorId, "localhost", localPort)
+ private val remoteBlockManagerId = BlockManagerId(executorId, "localhost", remotePort)
+ private val transportConf = SparkTransportConf.fromSparkConf(defaultConf, "shuffle")
+ private val securityManager = new org.apache.spark.SecurityManager(defaultConf)
+ protected val memoryManager = new TestMemoryManager(defaultConf)
+
+ class TestBlockManager(transferService: BlockTransferService,
+ blockManagerMaster: BlockManagerMaster,
+ dataFile: File,
+ fileLength: Long,
+ offset: Long) extends BlockManager(
+ executorId,
+ rpcEnv,
+ blockManagerMaster,
+ serializerManager,
+ defaultConf,
+ memoryManager,
+ null,
+ null,
+ transferService,
+ null,
+ 1) {
+ blockManagerId = execBlockManagerId
+
+ override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ new FileSegmentManagedBuffer(
+ transportConf,
+ dataFile,
+ offset,
+ fileLength
+ )
+ }
+ }
+
+ private var blockManager : BlockManager = _
+ private var externalBlockManager: BlockManager = _
+
+ def getTestBlockManager(
+ port: Int,
+ dataFile: File,
+ dataFileLength: Long,
+ offset: Long): TestBlockManager = {
+ val shuffleClient = new NettyBlockTransferService(
+ defaultConf,
+ securityManager,
+ "localhost",
+ "localhost",
+ port,
+ 1
+ )
+ new TestBlockManager(shuffleClient,
+ blockManagerMaster,
+ dataFile,
+ dataFileLength,
+ offset)
+ }
+
+ def initializeServers(dataFile: File, dataFileLength: Long, readOffset: Long = 0): Unit = {
+ MockitoAnnotations.initMocks(this)
+ when(blockManagerMaster.registerBlockManager(
+ any[BlockManagerId], any[Long], any[Long], any[RpcEndpointRef])).thenReturn(null)
+ when(rpcEnv.setupEndpoint(any[String], any[RpcEndpoint])).thenReturn(rpcEndpointRef)
+ blockManager = getTestBlockManager(localPort, dataFile, dataFileLength, readOffset)
+ blockManager.initialize(defaultConf.getAppId)
+ externalBlockManager = getTestBlockManager(remotePort, dataFile, dataFileLength, readOffset)
+ externalBlockManager.initialize(defaultConf.getAppId)
+ }
+
+ def stopServers(): Unit = {
+ blockManager.stop()
+ externalBlockManager.stop()
+ }
+
+ def setupReader(
+ dataFile: File,
+ dataFileLength: Long,
+ fetchLocal: Boolean,
+ aggregator: Option[Aggregator[String, String, String]] = None,
+ sorter: Option[Ordering[String]] = None): BlockStoreShuffleReader[String, String] = {
+ SparkEnv.set(new SparkEnv(
+ "0",
+ null,
+ serializer,
+ null,
+ serializerManager,
+ mapOutputTracker,
+ null,
+ null,
+ blockManager,
+ null,
+ null,
+ null,
+ null,
+ defaultConf
+ ))
+
+ val shuffleHandle = new BaseShuffleHandle(
+ shuffleId = 0,
+ numMaps = NUM_MAPS,
+ dependency = dependency)
+
+ val taskContext = new TestTaskContext
+ TaskContext.setTaskContext(taskContext)
+
+ var dataBlockId = execBlockManagerId
+ if (!fetchLocal) {
+ dataBlockId = remoteBlockManagerId
+ }
+
+ when(mapOutputTracker.getMapSizesByShuffleLocation(0, 0, 1))
+ .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] {
+ def answer(invocationOnMock: InvocationOnMock):
+ Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
+ val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId =>
+ val shuffleBlockId = ShuffleBlockId(0, mapId, 0)
+ (shuffleBlockId, dataFileLength)
+ }
+ Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes))
+ .toIterator
+ }
+ })
+
+ when(dependency.serializer).thenReturn(serializer)
+ when(dependency.aggregator).thenReturn(aggregator)
+ when(dependency.keyOrdering).thenReturn(sorter)
+
+ val readSupport = new DefaultShuffleReadSupport(
+ blockManager,
+ mapOutputTracker,
+ serializerManager,
+ defaultConf)
+
+ new BlockStoreShuffleReader[String, String](
+ shuffleHandle,
+ 0,
+ 1,
+ taskContext,
+ taskContext.taskMetrics().createTempShuffleReadMetrics(),
+ readSupport,
+ serializerManager,
+ mapOutputTracker
+ )
+ }
+
+ def generateDataOnDisk(size: Int, file: File, recordOffset: Int): (Long, Long) = {
+ // scalastyle:off println
+ println("Generating test data with num records: " + size)
+
+ val dataOutput = new ManualCloseFileOutputStream(file)
+ val random = new Random(123)
+ val serializerInstance = serializer.newInstance()
+
+ var countingOutput = new CountingOutputStream(dataOutput)
+ var serializedOutput = serializerInstance.serializeStream(countingOutput)
+ var readOffset = 0L
+ try {
+ (1 to size).foreach { i => {
+ if (i % 1000000 == 0) {
+ println("Wrote " + i + " test data points")
+ }
+ if (i == recordOffset) {
+ serializedOutput.close()
+ readOffset = countingOutput.getCount
+ countingOutput = new CountingOutputStream(dataOutput)
+ serializedOutput = serializerInstance.serializeStream(countingOutput)
+ }
+ val x = random.alphanumeric.take(DEFAULT_DATA_STRING_SIZE).mkString
+ serializedOutput.writeKey(x)
+ serializedOutput.writeValue(x)
+ }}
+ } finally {
+ serializedOutput.close()
+ dataOutput.manualClose()
+ }
+ (countingOutput.getCount, readOffset)
+ // scalastyle:on println
+ }
+
+ class TestDataFile(file: File, length: Long, offset: Long) {
+ def getFile(): File = file
+ def getLength(): Long = length
+ def getOffset(): Long = offset
+ }
+
+ def runWithTestDataFile(size: Int, readOffset: Int = 0)(func: TestDataFile => Unit): Unit = {
+ val tempDataFile = File.createTempFile("test-data", "", tempDir)
+ val dataFileLengthAndOffset = generateDataOnDisk(size, tempDataFile, readOffset)
+ initializeServers(tempDataFile, dataFileLengthAndOffset._1, dataFileLengthAndOffset._2)
+ func(new TestDataFile(tempDataFile, dataFileLengthAndOffset._1, dataFileLengthAndOffset._2))
+ tempDataFile.delete()
+ stopServers()
+ }
+
+ def addBenchmarkCase(
+ benchmark: Benchmark,
+ name: String,
+ shuffleReaderSupplier: => BlockStoreShuffleReader[String, String],
+ assertSize: Option[Int] = None): Unit = {
+ benchmark.addTimerCase(name) { timer =>
+ val reader = shuffleReaderSupplier
+ timer.startTiming()
+ val numRead = reader.read().length
+ timer.stopTiming()
+ assertSize.foreach(size => assert(numRead == size))
+ }
+ }
+
+ def runLargeDatasetTests(): Unit = {
+ runWithTestDataFile(TEST_DATA_SIZE) { testDataFile =>
+ val baseBenchmark =
+ new Benchmark("no aggregation or sorting",
+ TEST_DATA_SIZE,
+ minNumIters = MIN_NUM_ITERS,
+ output = output,
+ outputPerIteration = true)
+ addBenchmarkCase(
+ baseBenchmark,
+ "local fetch",
+ setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = true),
+ assertSize = Option.apply(TEST_DATA_SIZE * NUM_MAPS))
+ addBenchmarkCase(
+ baseBenchmark,
+ "remote rpc fetch",
+ setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = false),
+ assertSize = Option.apply(TEST_DATA_SIZE * NUM_MAPS))
+ baseBenchmark.run()
+ }
+ }
+
+ def runSmallDatasetTests(): Unit = {
+ runWithTestDataFile(SMALLER_DATA_SIZE) { testDataFile =>
+ def createCombiner(i: String): String = i
+ def mergeValue(i: String, j: String): String = if (Ordering.String.compare(i, j) > 0) i else j
+ def mergeCombiners(i: String, j: String): String =
+ if (Ordering.String.compare(i, j) > 0) i else j
+ val aggregator =
+ new Aggregator[String, String, String](createCombiner, mergeValue, mergeCombiners)
+ val aggregationBenchmark =
+ new Benchmark("with aggregation",
+ SMALLER_DATA_SIZE,
+ minNumIters = MIN_NUM_ITERS,
+ output = output,
+ outputPerIteration = true)
+ addBenchmarkCase(
+ aggregationBenchmark,
+ "local fetch",
+ setupReader(
+ testDataFile.getFile(),
+ testDataFile.getLength(),
+ fetchLocal = true,
+ aggregator = Some(aggregator)))
+ addBenchmarkCase(
+ aggregationBenchmark,
+ "remote rpc fetch",
+ setupReader(
+ testDataFile.getFile(),
+ testDataFile.getLength(),
+ fetchLocal = false,
+ aggregator = Some(aggregator)))
+ aggregationBenchmark.run()
+
+
+ val sortingBenchmark =
+ new Benchmark("with sorting",
+ SMALLER_DATA_SIZE,
+ minNumIters = MIN_NUM_ITERS,
+ output = output,
+ outputPerIteration = true)
+ addBenchmarkCase(
+ sortingBenchmark,
+ "local fetch",
+ setupReader(
+ testDataFile.getFile(),
+ testDataFile.getLength(),
+ fetchLocal = true,
+ sorter = Some(Ordering.String)),
+ assertSize = Option.apply(SMALLER_DATA_SIZE * NUM_MAPS))
+ addBenchmarkCase(
+ sortingBenchmark,
+ "remote rpc fetch",
+ setupReader(
+ testDataFile.getFile(),
+ testDataFile.getLength(),
+ fetchLocal = false,
+ sorter = Some(Ordering.String)),
+ assertSize = Option.apply(SMALLER_DATA_SIZE * NUM_MAPS))
+ sortingBenchmark.run()
+ }
+ }
+
+ def runSeekTests(): Unit = {
+ runWithTestDataFile(SMALLER_DATA_SIZE, readOffset = SMALLER_DATA_SIZE) { testDataFile =>
+ val seekBenchmark =
+ new Benchmark("with seek",
+ SMALLER_DATA_SIZE,
+ minNumIters = MIN_NUM_ITERS,
+ output = output)
+
+ addBenchmarkCase(
+ seekBenchmark,
+ "seek to last record",
+ setupReader(testDataFile.getFile(), testDataFile.getLength(), fetchLocal = false),
+ Option.apply(NUM_MAPS))
+ seekBenchmark.run()
+ }
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ tempDir = Utils.createTempDir(null, "shuffle")
+
+ runBenchmark("BlockStoreShuffleReader reader") {
+ runLargeDatasetTests()
+ runSmallDatasetTests()
+ runSeekTests()
+ }
+
+ FileUtils.deleteDirectory(tempDir)
+ }
+
+ // We cannot mock the TaskContext because taskMetrics() is called on every next() call on
+ // the reader, and Mockito would record each of those calls to taskMetrics(), eventually
+ // OOM-ing the test.
+ class TestTaskContext extends TaskContext {
+ private val metrics: TaskMetrics = new TaskMetrics
+ private val testMemManager = new TestMemoryManager(defaultConf)
+ private val taskMemManager = new TaskMemoryManager(testMemManager, 0)
+ testMemManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES)
+ override def isCompleted(): Boolean = false
+ override def isInterrupted(): Boolean = false
+ override def addTaskCompletionListener(listener: TaskCompletionListener):
+ TaskContext = { null }
+ override def addTaskFailureListener(listener: TaskFailureListener): TaskContext = { null }
+ override def stageId(): Int = 0
+ override def stageAttemptNumber(): Int = 0
+ override def partitionId(): Int = 0
+ override def attemptNumber(): Int = 0
+ override def taskAttemptId(): Long = 0
+ override def getLocalProperty(key: String): String = ""
+ override def taskMetrics(): TaskMetrics = metrics
+ override def getMetricsSources(sourceName: String): Seq[Source] = Seq.empty
+ override private[spark] def killTaskIfInterrupted(): Unit = {}
+ override private[spark] def getKillReason() = None
+ override private[spark] def taskMemoryManager() = taskMemManager
+ override private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {}
+ override private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit = {}
+ override private[spark] def markInterrupted(reason: String): Unit = {}
+ override private[spark] def markTaskFailed(error: Throwable): Unit = {}
+ override private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = {}
+ override private[spark] def fetchFailed = None
+ override private[spark] def getLocalProperties = { null }
+ }
+
+ class ManualCloseFileOutputStream(file: File) extends FileOutputStream(file, true) {
+ override def close(): Unit = {
+ flush()
+ }
+
+ def manualClose(): Unit = {
+ flush()
+ super.close()
+ }
+ }
+}
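The ManualCloseFileOutputStream above exists so that generateDataOnDisk can close a serialization stream mid-file (to record the seek offset) without closing the underlying file. A stripped-down sketch of that pattern, using Guava's CountingOutputStream the same way the benchmark does:

import java.io.{File, FileOutputStream}

import com.google.common.io.CountingOutputStream

// close() only flushes, so wrapping streams can be "closed" while the file stays open.
class ManualCloseFileOutputStream(file: File) extends FileOutputStream(file, true) {
  override def close(): Unit = flush()
  def manualClose(): Unit = {
    flush()
    super.close()
  }
}

val file = File.createTempFile("example-shuffle-data", ".bin")
val out = new ManualCloseFileOutputStream(file)

val firstSegment = new CountingOutputStream(out)
firstSegment.write(Array.fill(16)(1.toByte))
firstSegment.close() // flushes and "closes" the wrapper, but `out` is still writable
val secondSegmentOffset = firstSegment.getCount // byte offset where the next segment starts

val secondSegment = new CountingOutputStream(out)
secondSegment.write(Array.fill(8)(2.toByte))
secondSegment.close()
out.manualClose() // actually closes the file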
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala
new file mode 100644
index 0000000000000..dbd73f2688dfc
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.apache.spark.SparkConf
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
+import org.apache.spark.storage.BlockManagerId
+
+/**
+ * Benchmark to measure the write performance of BypassMergeSortShuffleWriter.
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt: bin/spark-submit --class
+ * 2. build/sbt "sql/test:runMain "
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain "
+ * Results will be written to "benchmarks/-results.txt".
+ * }}}
+ */
+object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase {
+
+ private val shuffleHandle: BypassMergeSortShuffleHandle[String, String] =
+ new BypassMergeSortShuffleHandle[String, String](
+ shuffleId = 0,
+ numMaps = 1,
+ dependency)
+
+ private val MIN_NUM_ITERS = 10
+ private val DATA_SIZE_SMALL = 1000
+ private val DATA_SIZE_LARGE =
+ PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/4/DEFAULT_DATA_STRING_SIZE
+
+ def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = {
+ val conf = new SparkConf(loadDefaults = false)
+ val shuffleWriteSupport = new DefaultShuffleWriteSupport(
+ conf, blockResolver, BlockManagerId("0", "localhost", 7090))
+ conf.set("spark.file.transferTo", String.valueOf(transferTo))
+ conf.set("spark.shuffle.file.buffer", "32k")
+
+ val shuffleWriter = new BypassMergeSortShuffleWriter[String, String](
+ blockManager,
+ shuffleHandle,
+ 0,
+ conf,
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ shuffleWriteSupport
+ )
+
+ shuffleWriter
+ }
+
+ def writeBenchmarkWithLargeDataset(): Unit = {
+ val size = DATA_SIZE_LARGE
+ val benchmark = new Benchmark(
+ "BypassMergeSortShuffleWrite with spill",
+ size,
+ minNumIters = MIN_NUM_ITERS,
+ output = output)
+
+ addBenchmarkCase(benchmark, "without transferTo", size, () => getWriter(false))
+ addBenchmarkCase(benchmark, "with transferTo", size, () => getWriter(true))
+ benchmark.run()
+ }
+
+ def writeBenchmarkWithSmallDataset(): Unit = {
+ val size = DATA_SIZE_SMALL
+ val benchmark = new Benchmark("BypassMergeSortShuffleWrite without spill",
+ size,
+ minNumIters = MIN_NUM_ITERS,
+ output = output)
+ addBenchmarkCase(benchmark, "small dataset without disk spill", size, () => getWriter(false))
+ benchmark.run()
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runBenchmark("BypassMergeSortShuffleWriter write") {
+ writeBenchmarkWithSmallDataset()
+ writeBenchmarkWithLargeDataset()
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index fc1422dfaac75..b7ba9291fc23d 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.shuffle.sort
import java.io.File
-import java.util.UUID
+import java.util.{Properties, UUID}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -29,11 +29,15 @@ import org.mockito.ArgumentMatchers.{any, anyInt}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.scalatest.BeforeAndAfterEach
+import scala.util.Random
import org.apache.spark._
+import org.apache.spark.api.shuffle.ShuffleWriteSupport
import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -48,7 +52,9 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
private var taskMetrics: TaskMetrics = _
private var tempDir: File = _
private var outputFile: File = _
+ private var writeSupport: ShuffleWriteSupport = _
private val conf: SparkConf = new SparkConf(loadDefaults = false)
+ .set("spark.app.id", "sampleApp")
private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
@@ -104,12 +110,27 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
temporaryFilesCreated += file
(blockId, file)
})
- when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) =>
- blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId])
- }
+ val memoryManager = new TestMemoryManager(conf)
+ val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
+
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ stageAttemptNumber = 0,
+ partitionId = 0,
+ taskAttemptId = Random.nextInt(10000),
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
+ metricsSystem = null,
+ taskMetrics = taskMetrics))
+
+ writeSupport = new DefaultShuffleWriteSupport(
+ conf, blockResolver, BlockManagerId("0", "localhost", 7090))
}
override def afterEach(): Unit = {
+ TaskContext.unset()
try {
Utils.deleteRecursively(tempDir)
blockIdToFileMap.clear()
@@ -122,11 +143,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("write empty iterator") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
- blockResolver,
shuffleHandle,
0, // MapId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
)
writer.write(Iterator.empty)
writer.stop( /* success = */ true)
@@ -142,15 +163,40 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
}
test("write with some empty partitions") {
+ val transferConf = conf.clone.set("spark.file.transferTo", "false")
+ def records: Iterator[(Int, Int)] =
+ Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ blockManager,
+ shuffleHandle,
+ 0, // MapId
+ transferConf,
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
+ )
+ writer.write(records)
+ writer.stop( /* success = */ true)
+ assert(temporaryFilesCreated.nonEmpty)
+ assert(writer.getPartitionLengths.sum === outputFile.length())
+ assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files
+ assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
+ assert(shuffleWriteMetrics.bytesWritten === outputFile.length())
+ assert(shuffleWriteMetrics.recordsWritten === records.length)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
+
+ test("write with some empty partitions with transferTo") {
def records: Iterator[(Int, Int)] =
Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
- blockResolver,
shuffleHandle,
0, // MapId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
)
writer.write(records)
writer.stop( /* success = */ true)
@@ -181,11 +227,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
- blockResolver,
shuffleHandle,
0, // MapId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
)
intercept[SparkException] {
@@ -203,11 +249,11 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("cleanup of intermediate files after errors") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
- blockResolver,
shuffleHandle,
0, // MapId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
)
intercept[SparkException] {
writer.write((0 until 100000).iterator.map(i => {
@@ -221,5 +267,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0)
}
-
}
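A condensed sketch of the writer wiring these tests now exercise: the block resolver sits behind a DefaultShuffleWriteSupport rather than being passed to BypassMergeSortShuffleWriter directly. It assumes conf, blockManager, blockResolver, shuffleHandle and taskContext are set up as in the suite's beforeEach(), and the BlockManagerId is arbitrary.

import org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter
import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
import org.apache.spark.storage.BlockManagerId

val writeSupport = new DefaultShuffleWriteSupport(
  conf, blockResolver, BlockManagerId("0", "localhost", 7090))

val writer = new BypassMergeSortShuffleWriter[Int, Int](
  blockManager,
  shuffleHandle,
  0, // mapId
  conf,
  taskContext.taskMetrics().shuffleWriteMetrics,
  writeSupport)

writer.write(Iterator((1, 1), (2, 2)))
writer.stop( /* success = */ true)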
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala
new file mode 100644
index 0000000000000..26b92e5203b50
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleWriterBenchmarkBase.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.File
+import java.util.UUID
+
+import org.apache.commons.io.FileUtils
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.when
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkConf, TaskContext}
+import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.serializer.{KryoSerializer, Serializer, SerializerManager}
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter}
+import org.apache.spark.storage.{BlockManager, DiskBlockManager, TempShuffleBlockId}
+import org.apache.spark.util.Utils
+
+abstract class ShuffleWriterBenchmarkBase extends BenchmarkBase {
+
+ protected val DEFAULT_DATA_STRING_SIZE = 5
+
+ // This is only used in the writer constructors, so it's ok to mock
+ @Mock(answer = RETURNS_SMART_NULLS) protected var dependency:
+ ShuffleDependency[String, String, String] = _
+ // This is only used in the stop() function, so we can safely mock this without affecting perf
+ @Mock(answer = RETURNS_SMART_NULLS) protected var taskContext: TaskContext = _
+ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEnv: RpcEnv = _
+ @Mock(answer = RETURNS_SMART_NULLS) protected var rpcEndpointRef: RpcEndpointRef = _
+
+ protected val defaultConf: SparkConf = new SparkConf(loadDefaults = false)
+ protected val serializer: Serializer = new KryoSerializer(defaultConf)
+ protected val partitioner: HashPartitioner = new HashPartitioner(10)
+ protected val serializerManager: SerializerManager =
+ new SerializerManager(serializer, defaultConf)
+ protected val shuffleMetrics: TaskMetrics = new TaskMetrics
+
+ protected val tempFilesCreated: ArrayBuffer[File] = new ArrayBuffer[File]
+ protected val filenameToFile: mutable.Map[String, File] = new mutable.HashMap[String, File]
+
+ class TestDiskBlockManager(tempDir: File) extends DiskBlockManager(defaultConf, false) {
+ override def getFile(filename: String): File = {
+ if (filenameToFile.contains(filename)) {
+ filenameToFile(filename)
+ } else {
+ val outputFile = File.createTempFile("shuffle", null, tempDir)
+ filenameToFile(filename) = outputFile
+ outputFile
+ }
+ }
+
+ override def createTempShuffleBlock(): (TempShuffleBlockId, File) = {
+ val blockId = new TempShuffleBlockId(UUID.randomUUID())
+ val file = getFile(blockId)
+ tempFilesCreated += file
+ (blockId, file)
+ }
+ }
+
+ class TestBlockManager(tempDir: File, memoryManager: MemoryManager) extends BlockManager("0",
+ rpcEnv,
+ null,
+ serializerManager,
+ defaultConf,
+ memoryManager,
+ null,
+ null,
+ null,
+ null,
+ 1) {
+ override val diskBlockManager = new TestDiskBlockManager(tempDir)
+ override val remoteBlockTempFileManager = null
+ }
+
+ protected var tempDir: File = _
+
+ protected var blockManager: BlockManager = _
+ protected var blockResolver: IndexShuffleBlockResolver = _
+
+ protected var memoryManager: TestMemoryManager = _
+ protected var taskMemoryManager: TaskMemoryManager = _
+
+ MockitoAnnotations.initMocks(this)
+ when(dependency.partitioner).thenReturn(partitioner)
+ when(dependency.serializer).thenReturn(serializer)
+ when(dependency.shuffleId).thenReturn(0)
+ when(taskContext.taskMetrics()).thenReturn(shuffleMetrics)
+ when(rpcEnv.setupEndpoint(any[String], any[RpcEndpoint])).thenReturn(rpcEndpointRef)
+
+ def setup(): Unit = {
+ TaskContext.setTaskContext(taskContext)
+ memoryManager = new TestMemoryManager(defaultConf)
+ memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES)
+ taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+ tempDir = Utils.createTempDir()
+ blockManager = new TestBlockManager(tempDir, memoryManager)
+ blockResolver = new IndexShuffleBlockResolver(
+ defaultConf,
+ blockManager)
+ }
+
+ def addBenchmarkCase(
+ benchmark: Benchmark,
+ name: String,
+ size: Int,
+ writerSupplier: () => ShuffleWriter[String, String],
+ numSpillFiles: Option[Int] = Option.empty): Unit = {
+ benchmark.addTimerCase(name) { timer =>
+ setup()
+ val writer = writerSupplier()
+ val dataIterator = createDataIterator(size)
+ try {
+ timer.startTiming()
+ writer.write(dataIterator)
+ timer.stopTiming()
+ if (numSpillFiles.isDefined) {
+ assert(tempFilesCreated.length == numSpillFiles.get)
+ }
+ } finally {
+ writer.stop(true)
+ }
+ teardown()
+ }
+ }
+
+ def teardown(): Unit = {
+ FileUtils.deleteDirectory(tempDir)
+ tempFilesCreated.clear()
+ filenameToFile.clear()
+ }
+
+ protected class DataIterator (size: Int)
+ extends Iterator[Product2[String, String]] {
+ val random = new Random(123)
+ var count = 0
+ override def hasNext: Boolean = {
+ count < size
+ }
+
+ override def next(): Product2[String, String] = {
+ count += 1
+ val string = random.alphanumeric.take(DEFAULT_DATA_STRING_SIZE).mkString
+ (string, string)
+ }
+ }
+
+
+ def createDataIterator(size: Int): DataIterator = {
+ new DataIterator(size)
+ }
+
+}
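A minimal sketch of the timer-based Benchmark pattern that addBenchmarkCase above relies on: per-iteration setup runs outside startTiming()/stopTiming(), so only the write itself is measured. The sizes and the None output are placeholders.

import org.apache.spark.benchmark.Benchmark

val benchmark = new Benchmark("example", 1000, minNumIters = 5, output = None)
benchmark.addTimerCase("timed body only") { timer =>
  val data = (1 to 1000).toArray // untimed setup, analogous to setup() and writerSupplier()
  timer.startTiming()
  val sum = data.sum             // only this region contributes to the reported time
  timer.stopTiming()
  require(sum == 500500)         // untimed verification, analogous to the spill-file asserts
}
benchmark.run()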
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala
new file mode 100644
index 0000000000000..b0ff15cb1f790
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.mockito.Mockito.when
+
+import org.apache.spark.{Aggregator, SparkEnv, TaskContext}
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
+import org.apache.spark.storage.BlockManagerId
+
+/**
+ * Benchmark to measure the write performance of SortShuffleWriter.
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt: bin/spark-submit --class
+ * 2. build/sbt "sql/test:runMain "
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain "
+ * Results will be written to "benchmarks/-results.txt".
+ * }}}
+ */
+object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase {
+
+ private val shuffleHandle: BaseShuffleHandle[String, String, String] =
+ new BaseShuffleHandle(
+ shuffleId = 0,
+ numMaps = 1,
+ dependency = dependency)
+
+ private val MIN_NUM_ITERS = 10
+ private val DATA_SIZE_SMALL = 1000
+ private val DATA_SIZE_LARGE =
+ PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/4/DEFAULT_DATA_STRING_SIZE
+
+ def getWriter(aggregator: Option[Aggregator[String, String, String]],
+ sorter: Option[Ordering[String]]): SortShuffleWriter[String, String, String] = {
+ // we need this since SortShuffleWriter uses SparkEnv to get lots of its private vars
+ SparkEnv.set(new SparkEnv(
+ "0",
+ null,
+ serializer,
+ null,
+ serializerManager,
+ null,
+ null,
+ null,
+ blockManager,
+ null,
+ null,
+ null,
+ null,
+ defaultConf
+ ))
+
+ if (aggregator.isEmpty && sorter.isEmpty) {
+ when(dependency.mapSideCombine).thenReturn(false)
+ } else {
+ when(dependency.mapSideCombine).thenReturn(false)
+ when(dependency.aggregator).thenReturn(aggregator)
+ when(dependency.keyOrdering).thenReturn(sorter)
+ }
+
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
+ TaskContext.setTaskContext(taskContext)
+ val writeSupport = new DefaultShuffleWriteSupport(
+ defaultConf,
+ blockResolver,
+ BlockManagerId("0", "localhost", 9099))
+
+ val shuffleWriter = new SortShuffleWriter[String, String, String](
+ blockResolver,
+ shuffleHandle,
+ 0,
+ taskContext,
+ writeSupport)
+ shuffleWriter
+ }
+
+ def writeBenchmarkWithSmallDataset(): Unit = {
+ val size = DATA_SIZE_SMALL
+ val benchmark = new Benchmark("SortShuffleWriter without spills",
+ size,
+ minNumIters = MIN_NUM_ITERS,
+ output = output)
+ addBenchmarkCase(benchmark,
+ "small dataset without spills",
+ size,
+ () => getWriter(Option.empty, Option.empty),
+ Some(0))
+ benchmark.run()
+ }
+
+ def writeBenchmarkWithSpill(): Unit = {
+ val size = DATA_SIZE_LARGE
+ val benchmark = new Benchmark("SortShuffleWriter with spills",
+ size,
+ minNumIters = MIN_NUM_ITERS,
+ output = output,
+ outputPerIteration = true)
+ addBenchmarkCase(benchmark,
+ "no map side combine",
+ size,
+ () => getWriter(Option.empty, Option.empty),
+ Some(7))
+
+ def createCombiner(i: String): String = i
+ def mergeValue(i: String, j: String): String = if (Ordering.String.compare(i, j) > 0) i else j
+ def mergeCombiners(i: String, j: String): String =
+ if (Ordering.String.compare(i, j) > 0) i else j
+ val aggregator =
+ new Aggregator[String, String, String](createCombiner, mergeValue, mergeCombiners)
+ addBenchmarkCase(benchmark,
+ "with map side aggregation",
+ size,
+ () => getWriter(Some(aggregator), Option.empty),
+ Some(7))
+
+ val sorter = Ordering.String
+ addBenchmarkCase(benchmark,
+ "with map side sort",
+ size,
+ () => getWriter(Option.empty, Some(sorter)),
+ Some(7))
+ benchmark.run()
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runBenchmark("SortShuffleWriter writer") {
+ writeBenchmarkWithSmallDataset()
+ writeBenchmarkWithSpill()
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala
new file mode 100644
index 0000000000000..7066ba8fb44df
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle.sort
+
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
+import org.apache.spark.storage.BlockManagerId
+
+/**
+ * Benchmark to measure the write performance of UnsafeShuffleWriter.
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt: bin/spark-submit --class
+ * 2. build/sbt "sql/test:runMain "
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain "
+ * Results will be written to "benchmarks/-results.txt".
+ * }}}
+ */
+object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase {
+
+ private val shuffleHandle: SerializedShuffleHandle[String, String] =
+ new SerializedShuffleHandle[String, String](0, 0, this.dependency)
+
+ private val MIN_NUM_ITERS = 10
+ private val DATA_SIZE_SMALL = 1000
+ private val DATA_SIZE_LARGE =
+ PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES/2/DEFAULT_DATA_STRING_SIZE
+
+ def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = {
+ val conf = new SparkConf(loadDefaults = false)
+ conf.set("spark.file.transferTo", String.valueOf(transferTo))
+ val shuffleWriteSupport = new DefaultShuffleWriteSupport(
+ conf, blockResolver, BlockManagerId("0", "localhost", 9099))
+
+ TaskContext.setTaskContext(taskContext)
+ new UnsafeShuffleWriter[String, String](
+ blockManager,
+ taskMemoryManager,
+ shuffleHandle,
+ 0,
+ taskContext,
+ conf,
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ shuffleWriteSupport)
+ }
+
+ def writeBenchmarkWithSmallDataset(): Unit = {
+ val size = DATA_SIZE_SMALL
+ val benchmark = new Benchmark("UnsafeShuffleWriter without spills",
+ size,
+ minNumIters = MIN_NUM_ITERS,
+ output = output)
+ addBenchmarkCase(benchmark,
+ "small dataset without spills",
+ size,
+ () => getWriter(false),
+ Some(1)) // The single temp file is for the temp index file
+ benchmark.run()
+ }
+
+ def writeBenchmarkWithSpill(): Unit = {
+ val size = DATA_SIZE_LARGE
+ val benchmark = new Benchmark("UnsafeShuffleWriter with spills",
+ size,
+ minNumIters = MIN_NUM_ITERS,
+ output = output,
+ outputPerIteration = true)
+ addBenchmarkCase(benchmark, "without transferTo", size, () => getWriter(false), Some(7))
+ addBenchmarkCase(benchmark, "with transferTo", size, () => getWriter(true), Some(7))
+ benchmark.run()
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runBenchmark("UnsafeShuffleWriter write") {
+ writeBenchmarkWithSmallDataset()
+ writeBenchmarkWithSpill()
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala
new file mode 100644
index 0000000000000..1f4ef0f203994
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io
+
+import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
+import java.math.BigInteger
+import java.nio.ByteBuffer
+import java.nio.channels.{Channels, WritableByteChannel}
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.{any, anyInt, anyLong}
+import org.mockito.Mock
+import org.mockito.Mockito.{doAnswer, doNothing, when}
+import org.mockito.MockitoAnnotations
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.api.shuffle.SupportsTransferTo
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.network.util.LimitedInputStream
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.Utils
+
+class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var shuffleWriteMetrics: ShuffleWriteMetrics = _
+
+ private val NUM_PARTITIONS = 4
+ private val D_LEN = 10
+ private val data: Array[Array[Int]] = (0 until NUM_PARTITIONS).map {
+ p => (1 to D_LEN).map(_ + p).toArray }.toArray
+
+ private var tempFile: File = _
+ private var mergedOutputFile: File = _
+ private var tempDir: File = _
+ private var partitionSizesInMergedFile: Array[Long] = _
+ private var conf: SparkConf = _
+ private var mapOutputWriter: DefaultShuffleMapOutputWriter = _
+
+ override def afterEach(): Unit = {
+ try {
+ Utils.deleteRecursively(tempDir)
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ override def beforeEach(): Unit = {
+ MockitoAnnotations.initMocks(this)
+ tempDir = Utils.createTempDir(null, "test")
+ mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir)
+ tempFile = File.createTempFile("tempfile", "", tempDir)
+ partitionSizesInMergedFile = null
+ conf = new SparkConf()
+ .set("spark.app.id", "example.spark.app")
+ .set("spark.shuffle.unsafe.file.output.buffer", "16k")
+ when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile)
+
+ doNothing().when(shuffleWriteMetrics).incWriteTime(anyLong)
+
+ doAnswer(new Answer[Void] {
+ def answer(invocationOnMock: InvocationOnMock): Void = {
+ partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
+ val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ if (tmp != null) {
+ mergedOutputFile.delete
+ tmp.renameTo(mergedOutputFile)
+ }
+ null
+ }
+ }).when(blockResolver)
+ .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
+ mapOutputWriter = new DefaultShuffleMapOutputWriter(
+ 0,
+ 0,
+ NUM_PARTITIONS,
+ BlockManagerId("0", "localhost", 9099),
+ shuffleWriteMetrics,
+ blockResolver,
+ conf)
+ }
+
+ private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = {
+ var startOffset = 0L
+ val result = new Array[Array[Int]](NUM_PARTITIONS)
+ (0 until NUM_PARTITIONS).foreach { p =>
+ val partitionSize = partitionSizesInMergedFile(p).toInt
+ lazy val inner = new Array[Int](partitionSize)
+ lazy val innerBytebuffer = ByteBuffer.allocate(partitionSize)
+ if (partitionSize > 0) {
+ val in = new FileInputStream(mergedOutputFile)
+ in.getChannel.position(startOffset)
+ val lin = new LimitedInputStream(in, partitionSize)
+ var nonEmpty = true
+ var count = 0
+ while (nonEmpty) {
+ try {
+ val readBit = lin.read()
+ if (fromByte) {
+ innerBytebuffer.put(readBit.toByte)
+ } else {
+ inner(count) = readBit
+ }
+ count += 1
+ } catch {
+ case _: Exception =>
+ nonEmpty = false
+ }
+ }
+ in.close()
+ }
+ if (fromByte) {
+ result(p) = innerBytebuffer.array().sliding(4, 4).map { b =>
+ new BigInteger(b).intValue()
+ }.toArray
+ } else {
+ result(p) = inner
+ }
+ startOffset += partitionSize
+ }
+ result
+ }
+
+ test("writing to an outputstream") {
+ (0 until NUM_PARTITIONS).foreach{ p =>
+ val writer = mapOutputWriter.getPartitionWriter(p)
+ val stream = writer.openStream()
+ data(p).foreach { i => stream.write(i)}
+ stream.close()
+ intercept[IllegalStateException] {
+ stream.write(p)
+ }
+ assert(writer.getNumBytesWritten() == D_LEN)
+ }
+ mapOutputWriter.commitAllPartitions()
+ val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile(false))
+ }
+
+ test("writing to a channel") {
+ (0 until NUM_PARTITIONS).foreach{ p =>
+ val writer = mapOutputWriter.getPartitionWriter(p)
+ val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel()
+ val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
+ val intBuffer = byteBuffer.asIntBuffer()
+ intBuffer.put(data(p))
+ val numBytes = byteBuffer.remaining()
+ val outputTempFile = File.createTempFile("channelTemp", "", tempDir)
+ val outputTempFileStream = new FileOutputStream(outputTempFile)
+ Utils.copyStream(
+ new ByteBufferInputStream(byteBuffer),
+ outputTempFileStream,
+ closeStreams = true)
+ val tempFileInput = new FileInputStream(outputTempFile)
+ channel.transferFrom(tempFileInput.getChannel, 0L, numBytes)
+ // Each Int occupies 4 bytes, so D_LEN ints transfer D_LEN * 4 bytes through the channel
+ channel.close()
+ tempFileInput.close()
+ assert(writer.getNumBytesWritten == D_LEN * 4)
+ }
+ mapOutputWriter.commitAllPartitions()
+ val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile(true))
+ }
+
+ test("copyStreams with an outputstream") {
+ (0 until NUM_PARTITIONS).foreach{ p =>
+ val writer = mapOutputWriter.getPartitionWriter(p)
+ val stream = writer.openStream()
+ val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
+ val intBuffer = byteBuffer.asIntBuffer()
+ intBuffer.put(data(p))
+ val in = new ByteArrayInputStream(byteBuffer.array())
+ Utils.copyStream(in, stream, false, false)
+ in.close()
+ stream.close()
+ assert(writer.getNumBytesWritten == D_LEN * 4)
+ }
+ mapOutputWriter.commitAllPartitions()
+ val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile(true))
+ }
+
+ test("copyStreamsWithNIO with a channel") {
+ (0 until NUM_PARTITIONS).foreach{ p =>
+ val writer = mapOutputWriter.getPartitionWriter(p)
+ val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel()
+ val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
+ val intBuffer = byteBuffer.asIntBuffer()
+ intBuffer.put(data(p))
+ val out = new FileOutputStream(tempFile)
+ out.write(byteBuffer.array())
+ out.close()
+ val in = new FileInputStream(tempFile)
+ channel.transferFrom(in.getChannel, 0L, byteBuffer.remaining())
+ channel.close()
+ assert(writer.getNumBytesWritten == D_LEN * 4)
+ }
+ mapOutputWriter.commitAllPartitions()
+ val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile(true))
+ }
+}
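The per-partition write protocol these tests exercise, gathered in one place; it assumes a mapOutputWriter built as in beforeEach() above.

// For each reduce partition, get a partition writer, write that partition's serialized
// bytes, then commit once every partition has been written.
val numPartitions = 4
(0 until numPartitions).foreach { p =>
  val partitionWriter = mapOutputWriter.getPartitionWriter(p)
  val stream = partitionWriter.openStream()
  try {
    (0 until 8).foreach(_ => stream.write(p)) // placeholder bytes for partition p
  } finally {
    stream.close() // further writes are rejected, per the IllegalStateException check above
  }
}
// Signals that all partitions are done so the lengths can be committed to the index file.
mapOutputWriter.commitAllPartitions()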
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 3ab2f0bf4596f..63cfc0b70e7b7 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -22,13 +22,12 @@ import java.nio.ByteBuffer
import java.util.UUID
import java.util.concurrent.Semaphore
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.Future
-
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{mock, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.scalatest.PrivateMethodTester
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.Future
import org.apache.spark.{SparkFunSuite, TaskContext}
import org.apache.spark.network._
@@ -124,7 +123,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
for (i <- 0 until 5) {
assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
- val (blockId, inputStream) = iterator.next()
+ val inputStream = iterator.next()
+ val blockId = iterator.getCurrentBlock()
// Make sure we release buffers when a wrapped input stream is closed.
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
@@ -198,11 +198,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
taskContext.taskMetrics.createTempShuffleReadMetrics())
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
- iterator.next()._2.close() // close() first block's input stream
+ iterator.next().close() // close() first block's input stream
verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
// Get the 2nd block but do not exhaust the iterator
- val subIter = iterator.next()._2
+ val subIter = iterator.next()
// Complete the task; then the 2nd block buffer should be exhausted
verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
@@ -416,7 +416,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
sem.acquire()
// The first block should be returned without an exception
- val (id1, _) = iterator.next()
+ iterator.next()
+ val id1 = iterator.getCurrentBlock()
assert(id1 === ShuffleBlockId(0, 0, 0))
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
@@ -434,6 +435,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
sem.acquire()
+
intercept[FetchFailedException] { iterator.next() }
}
@@ -545,16 +547,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
true,
true,
taskContext.taskMetrics.createTempShuffleReadMetrics())
- val (id, st) = iterator.next()
- // Check that the test setup is correct -- make sure we have a concatenated stream.
- assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream])
-
- val dst = new DataInputStream(st)
- for (i <- 1 to 2500) {
- assert(i === dst.readInt())
- }
- assert(dst.read() === -1)
- dst.close()
+ // Blocks should be returned without exceptions.
+ iterator.next()
+ val blockId1 = iterator.getCurrentBlock()
+ iterator.next()
+ val blockId2 = iterator.getCurrentBlock()
+ assert(Set(blockId1, blockId2) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0)))
}
test("retry corrupt blocks (disabled)") {
@@ -611,11 +609,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
sem.acquire()
// The first block should be returned without an exception
- val (id1, _) = iterator.next()
+ iterator.next()
+ val id1 = iterator.getCurrentBlock()
assert(id1 === ShuffleBlockId(0, 0, 0))
- val (id2, _) = iterator.next()
+ iterator.next()
+ val id2 = iterator.getCurrentBlock()
assert(id2 === ShuffleBlockId(0, 1, 0))
- val (id3, _) = iterator.next()
+ iterator.next()
+ val id3 = iterator.getCurrentBlock()
assert(id3 === ShuffleBlockId(0, 2, 0))
}
diff --git a/dev/run-spark-25299-benchmarks.sh b/dev/run-spark-25299-benchmarks.sh
new file mode 100755
index 0000000000000..2d60f9d5a06ec
--- /dev/null
+++ b/dev/run-spark-25299-benchmarks.sh
@@ -0,0 +1,94 @@
+#!/usr/bin/env bash
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# Runs the SPARK-25299 shuffle reader/writer benchmarks and collects the
+# generated result files under /tmp/artifacts.
+# With -u, the combined results are also posted as a review comment on the
+# corresponding pull request.
+
+set -euo pipefail
+
+
+function usage {
+ echo "Usage: $(basename $0) [-h] [-u]"
+ echo ""
+ echo "Runs the perf tests and optionally uploads the results as a comment to a PR"
+ echo ""
+ echo " -h help"
+ echo " -u Upload the perf results as a comment"
+ # Exit as error for nesting scripts
+ exit 1
+}
+
+UPLOAD=false
+while getopts "hu" opt; do
+ case $opt in
+ h)
+ usage;;
+ u)
+ UPLOAD=true;;
+ esac
+done
+
+echo "Running SPARK-25299 benchmarks"
+
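+# SPARK_GENERATE_BENCHMARK_FILES=1 tells the benchmark harness to write result
+# files under sql/core/benchmarks/ (copied to /tmp/artifacts below).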
+SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.BlockStoreShuffleReaderBenchmark"
+SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriterBenchmark"
+SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.SortShuffleWriterBenchmark"
+SPARK_GENERATE_BENCHMARK_FILES=1 ./build/sbt "sql/test:runMain org.apache.spark.shuffle.sort.UnsafeShuffleWriterBenchmark"
+
+SPARK_DIR=`pwd`
+
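+# Gather the generated result files in one place; the optional upload step below
+# reads everything from /tmp/artifacts.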
+mkdir -p /tmp/artifacts
+cp $SPARK_DIR/sql/core/benchmarks/BlockStoreShuffleReaderBenchmark-results.txt /tmp/artifacts/
+cp $SPARK_DIR/sql/core/benchmarks/BypassMergeSortShuffleWriterBenchmark-results.txt /tmp/artifacts/
+cp $SPARK_DIR/sql/core/benchmarks/SortShuffleWriterBenchmark-results.txt /tmp/artifacts/
+cp $SPARK_DIR/sql/core/benchmarks/UnsafeShuffleWriterBenchmark-results.txt /tmp/artifacts/
+
+if [ "$UPLOAD" = false ]; then
+ exit 0
+fi
+
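+# Disable word splitting so the multi-line benchmark output below is preserved as-is.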
+IFS=
+RESULTS=""
+for benchmark_file in /tmp/artifacts/*.txt; do
+ RESULTS+=$(cat $benchmark_file)
+ RESULTS+=$'\n\n'
+done
+
+echo "$RESULTS"
+# Take the last commit message, drop blank lines, grab the last token of its first non-blank line, and strip "(", ")" and "#" to get the PR number
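+# e.g. a squash-merge commit titled "Improve shuffle writer benchmarks (#1234)" yields PULL_REQUEST_NUM=1234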
+PULL_REQUEST_NUM=$(git log -1 --pretty=%B | awk NF | awk '{print $NF}' | head -1 | sed 's/(//g' | sed 's/)//g' | sed 's/#//g')
+
+
+USERNAME=svc-spark-25299
+PASSWORD=$SVC_SPARK_25299_PASSWORD
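+# Build a GitHub review payload whose body is the benchmark output wrapped in a
+# fenced code block; awk turns real newlines into literal "\n" so the body stays
+# valid JSON.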
+message='{"body": "```'
+message+=$'\n'
+message+=$RESULTS
+message+=$'\n'
+json_message=$(echo "$message" | awk '{printf "%s\\n", $0}')
+json_message+='```", "event":"COMMENT"}'
+echo "$json_message" > benchmark_results.json
+
+echo "Sending benchmark requests to PR $PULL_REQUEST_NUM"
+curl -XPOST https://${USERNAME}:${PASSWORD}@api.github.com/repos/palantir/spark/pulls/${PULL_REQUEST_NUM}/reviews -d @benchmark_results.json
+rm benchmark_results.json
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 079ff25fcb67e..22cfbf506c645 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -156,10 +156,11 @@ class ShuffledRowRDD(
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition]
- val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
// `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
// as well as the `tempMetrics` for basic shuffle metrics.
- val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
+ context.taskMetrics().decorateTempShuffleReadMetrics(
+ tempMetrics => new SQLShuffleReadMetricsReporter(tempMetrics, metrics))
+ val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
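+ // Registering the decorator before creating the temp read metrics ensures the
+ // reporter passed to the shuffle reader below also updates the SQL exchange metrics.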
// The range of pre-shuffle partitions that we are fetching at here is
// [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1].
val reader =
@@ -168,7 +169,7 @@ class ShuffledRowRDD(
shuffledRowPartition.startPreShufflePartitionIndex,
shuffledRowPartition.endPreShufflePartitionIndex,
context,
- sqlMetricsReporter)
+ tempMetrics)
reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
}