diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 8815ac4eb..c215754b9 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -272,18 +272,17 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) - val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = conf( - s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") - .doc( - "The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. " + - "Compression can be disabled by setting spark.shuffle.compress=false.") - .stringConf - .checkValues(Set("zstd")) - .createWithDefault("zstd") + val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") + .doc("The codec of Comet native shuffle used to compress shuffle data. Only lz4 and zstd " + + "are supported. Compression can be disabled by setting spark.shuffle.compress=false.") + .stringConf + .checkValues(Set("zstd", "lz4")) + .createWithDefault("lz4") val COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL: ConfigEntry[Int] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.level") - .doc("The compression level to use when compression shuffle files.") + .doc("The compression level to use when compressing shuffle files with zstd.") .intConf .createWithDefault(1) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala deleted file mode 100644 index d1d5af350..000000000 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/IpcInputStreamIterator.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.spark.sql.comet.execution.shuffle - -import java.io.EOFException -import java.io.InputStream -import java.nio.ByteBuffer -import java.nio.ByteOrder -import java.nio.channels.Channels -import java.nio.channels.ReadableByteChannel - -import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.LimitedInputStream - -case class IpcInputStreamIterator( - var in: InputStream, - decompressingNeeded: Boolean, - taskContext: TaskContext) - extends Iterator[ReadableByteChannel] - with Logging { - - private[execution] val channel: ReadableByteChannel = if (in != null) { - Channels.newChannel(in) - } else { - null - } - - private val ipcLengthsBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) - - // NOTE: - // since all ipcs are sharing the same input stream and channel, the second - // hasNext() must be called after the first ipc has been completely processed. - - private[execution] var consumed = true - private var finished = false - private var currentIpcLength = 0L - private var currentLimitedInputStream: LimitedInputStream = _ - - taskContext.addTaskCompletionListener[Unit](_ => { - closeInputStream() - }) - - override def hasNext: Boolean = { - if (in == null || finished) { - return false - } - - // If we've read the length of the next IPC, we don't need to read it again. - if (!consumed) { - return true - } - - if (currentLimitedInputStream != null) { - currentLimitedInputStream.skip(Int.MaxValue) - currentLimitedInputStream = null - } - - // Reads the length of IPC bytes - ipcLengthsBuf.clear() - while (ipcLengthsBuf.hasRemaining && channel.read(ipcLengthsBuf) >= 0) {} - - // If we reach the end of the stream, we are done, or if we read partial length - // then the stream is corrupted. - if (ipcLengthsBuf.hasRemaining) { - if (ipcLengthsBuf.position() == 0) { - finished = true - closeInputStream() - return false - } - throw new EOFException("Data corrupt: unexpected EOF while reading compressed ipc lengths") - } - - ipcLengthsBuf.flip() - currentIpcLength = ipcLengthsBuf.getLong - - // Skips empty IPC - if (currentIpcLength == 0) { - return hasNext - } - consumed = false - return true - } - - override def next(): ReadableByteChannel = { - if (!hasNext) { - throw new NoSuchElementException - } - assert(!consumed) - consumed = true - - val is = new LimitedInputStream(Channels.newInputStream(channel), currentIpcLength, false) - currentLimitedInputStream = is - - if (decompressingNeeded) { - ShuffleUtils.compressionCodecForShuffling match { - case Some(codec) => Channels.newChannel(codec.compressedInputStream(is)) - case _ => Channels.newChannel(is) - } - } else { - Channels.newChannel(is) - } - } - - private def closeInputStream(): Unit = - synchronized { - if (in != null) { - in.close() - in = null - } - } -} diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 7881f0763..d6f861136 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -50,8 +50,8 @@ Comet provides the following configuration settings. | spark.comet.exec.memoryFraction | The fraction of memory from Comet memory overhead that the native memory manager can use for execution. The purpose of this config is to set aside memory for untracked data structures, as well as imprecise size estimation during memory acquisition. | 0.7 | | spark.comet.exec.project.enabled | Whether to enable project by default. 
| true | | spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false | -| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. Compression can be disabled by setting spark.shuffle.compress=false. | zstd | -| spark.comet.exec.shuffle.compression.level | The compression level to use when compression shuffle files. | 1 | +| spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only lz4 and zstd are supported. Compression can be disabled by setting spark.shuffle.compress=false. | lz4 | +| spark.comet.exec.shuffle.compression.level | The compression level to use when compressing shuffle files with zstd. | 1 | | spark.comet.exec.shuffle.enabled | Whether to enable Comet native shuffle. Note that this requires setting 'spark.shuffle.manager' to 'org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager'. 'spark.shuffle.manager' must be set before starting the Spark application and cannot be changed during the application. | true | | spark.comet.exec.sort.enabled | Whether to enable sort by default. | true | | spark.comet.exec.sortMergeJoin.enabled | Whether to enable sortMergeJoin by default. | true | diff --git a/native/Cargo.lock b/native/Cargo.lock index 538c40ee2..dfed0cc70 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -920,7 +920,7 @@ dependencies = [ "lazy_static", "log", "log4rs", - "lz4", + "lz4_flex", "mimalloc", "num", "once_cell", @@ -2111,25 +2111,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "lz4" -version = "1.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725" -dependencies = [ - "lz4-sys", -] - -[[package]] -name = "lz4-sys" -version = "1.11.1+lz4-1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "lz4_flex" version = "0.11.3" diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 489da46d4..197bf4318 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -55,7 +55,7 @@ jni = "0.21" snap = "1.1" brotli = "3.3" flate2 = "1.0" -lz4 = "1.24" +lz4_flex = "0.11.3" zstd = "0.11" rand = { workspace = true} num = { workspace = true } diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index eb73675b5..a90a91d2f 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -17,6 +17,7 @@ //! Define JNI APIs which can be called from Java/Scala. 
+use super::{serde, utils::SparkArrowConvert, CometMemoryPool}; use arrow::datatypes::DataType as ArrowDataType; use arrow_array::RecordBatch; use datafusion::{ @@ -40,8 +41,6 @@ use jni::{ use std::time::{Duration, Instant}; use std::{collections::HashMap, sync::Arc, task::Poll}; -use super::{serde, utils::SparkArrowConvert, CometMemoryPool}; - use crate::{ errors::{try_unwrap_or_throw, CometError, CometResult}, execution::{ @@ -60,6 +59,7 @@ use jni::{ use tokio::runtime::Runtime; use crate::execution::operators::ScanExec; +use crate::execution::shuffle::read_ipc_compressed; use crate::execution::spark_plan::SparkPlan; use log::info; @@ -95,7 +95,7 @@ struct ExecutionContext { /// Accept serialized query plan and return the address of the native query plan. /// # Safety -/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( e: JNIEnv, @@ -231,7 +231,7 @@ fn prepare_output( array_addrs: jlongArray, schema_addrs: jlongArray, output_batch: RecordBatch, - exec_context: &mut ExecutionContext, + debug_native: bool, ) -> CometResult { let array_address_array = unsafe { JLongArray::from_raw(array_addrs) }; let num_cols = env.get_array_length(&array_address_array)? as usize; @@ -255,7 +255,7 @@ fn prepare_output( ))); } - if exec_context.debug_native { + if debug_native { // Validate the output arrays. for array in results.iter() { let array_data = array.to_data(); @@ -275,9 +275,6 @@ fn prepare_output( i += 1; } - // Update metrics - update_metrics(env, exec_context)?; - Ok(num_rows as jlong) } @@ -298,7 +295,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometEr /// Accept serialized query plan and the addresses of Arrow Arrays from Spark, /// then execute the query. Return addresses of arrow vector. /// # Safety -/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( e: JNIEnv, @@ -358,12 +355,14 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( match poll_output { Poll::Ready(Some(output)) => { + // Update metrics + update_metrics(&mut env, exec_context)?; return prepare_output( &mut env, array_addrs, schema_addrs, output?, - exec_context, + exec_context.debug_native, ); } Poll::Ready(None) => { @@ -459,7 +458,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext { /// Used by Comet shuffle external sorter to write sorted records to disk. /// # Safety -/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( e: JNIEnv, @@ -544,3 +543,25 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( Ok(()) }) } + +#[no_mangle] +/// Used by Comet native shuffle reader +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. 
+pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( + e: JNIEnv, + _class: JClass, + byte_array: jbyteArray, + array_addrs: jlongArray, + schema_addrs: jlongArray, +) -> jlong { + try_unwrap_or_throw(&e, |mut env| { + let value_array = unsafe { JPrimitiveArray::from_raw(byte_array) }; + let length = env.get_array_length(&value_array)?; + let elements = unsafe { env.get_array_elements(&value_array, ReleaseMode::NoCopyBack)? }; + let raw_pointer = elements.as_ptr(); + let slice = unsafe { std::slice::from_raw_parts(raw_pointer, length as usize) }; + let batch = read_ipc_compressed(slice)?; + prepare_output(&mut env, array_addrs, schema_addrs, batch, false) + }) +} diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 0a7493354..199854b9c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1055,6 +1055,7 @@ impl PhysicalPlanner { Ok(SparkCompressionCodec::Zstd) => { Ok(CompressionCodec::Zstd(writer.compression_level)) } + Ok(SparkCompressionCodec::Lz4) => Ok(CompressionCodec::Lz4Frame), _ => Err(ExecutionError::GeneralError(format!( "Unsupported shuffle compression codec: {:?}", writer.codec diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs index 8111f5eed..178aff1fa 100644 --- a/native/core/src/execution/shuffle/mod.rs +++ b/native/core/src/execution/shuffle/mod.rs @@ -19,4 +19,6 @@ mod list; mod map; pub mod row; mod shuffle_writer; -pub use shuffle_writer::{write_ipc_compressed, CompressionCodec, ShuffleWriterExec}; +pub use shuffle_writer::{ + read_ipc_compressed, write_ipc_compressed, CompressionCodec, ShuffleWriterExec, +}; diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index 01117199e..8ff8b9693 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -21,6 +21,7 @@ use crate::{ common::bit::ceil, errors::{CometError, CometResult}, }; +use arrow::ipc::reader::StreamReader; use arrow::{datatypes::*, ipc::writer::StreamWriter}; use async_trait::async_trait; use bytes::Buf; @@ -788,6 +789,8 @@ impl ShuffleRepartitioner { Partitioning::Hash(exprs, _) => { let (partition_starts, shuffled_partition_ids): (Vec, Vec) = { let mut timer = self.metrics.repart_time.timer(); + + // evaluate partition expressions let arrays = exprs .iter() .map(|expr| expr.evaluate(&input)?.into_array(input.num_rows())) @@ -1547,6 +1550,7 @@ impl Checksum { #[derive(Debug, Clone)] pub enum CompressionCodec { None, + Lz4Frame, Zstd(i32), } @@ -1565,9 +1569,20 @@ pub fn write_ipc_compressed( let mut timer = ipc_time.timer(); let start_pos = output.stream_position()?; - // write ipc_length placeholder + // write message length placeholder output.write_all(&[0u8; 8])?; + // write number of columns because JVM side needs to know how many addresses to allocate + let field_count = batch.schema().fields().len(); + output.write_all(&field_count.to_le_bytes())?; + + // write codec used + match codec { + &CompressionCodec::Lz4Frame => output.write_all("LZ4_".as_bytes())?, + &CompressionCodec::Zstd(_) => output.write_all("ZSTD".as_bytes())?, + &CompressionCodec::None => output.write_all("NONE".as_bytes())?, + } + let output = match codec { CompressionCodec::None => { let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; @@ -1575,6 +1590,23 @@ pub fn write_ipc_compressed( arrow_writer.finish()?; 
arrow_writer.into_inner()? } + CompressionCodec::Lz4Frame => { + // write IPC first without compression + let mut buffer = vec![]; + let mut arrow_writer = StreamWriter::try_new(&mut buffer, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + let ipc_encoded = arrow_writer.into_inner()?; + + // compress + let mut reader = Cursor::new(ipc_encoded); + let mut wtr = lz4_flex::frame::FrameEncoder::new(output); + std::io::copy(&mut reader, &mut wtr)?; + let output = wtr + .finish() + .map_err(|e| DataFusionError::Execution(format!("lz4 compression error: {}", e)))?; + output + } CompressionCodec::Zstd(level) => { let encoder = zstd::Encoder::new(output, *level)?; let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; @@ -1587,11 +1619,11 @@ pub fn write_ipc_compressed( // fill ipc length let end_pos = output.stream_position()?; - let ipc_length = end_pos - start_pos - 8; + let compressed_length = end_pos - start_pos - 8; // fill ipc length output.seek(SeekFrom::Start(start_pos))?; - output.write_all(&ipc_length.to_le_bytes()[..])?; + output.write_all(&compressed_length.to_le_bytes()[..])?; output.seek(SeekFrom::Start(end_pos))?; timer.stop(); @@ -1599,6 +1631,26 @@ pub fn write_ipc_compressed( Ok((end_pos - start_pos) as usize) } +pub fn read_ipc_compressed(bytes: &[u8]) -> Result { + match &bytes[0..4] { + b"LZ4_" => { + let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); + let mut reader = StreamReader::try_new(decoder, None)?; + // TODO check for None + reader.next().unwrap().map_err(|e| e.into()) + } + b"ZSTD" => { + let decoder = zstd::Decoder::new(&bytes[4..])?; + let mut reader = StreamReader::try_new(decoder, None)?; + // TODO check for None + reader.next().unwrap().map_err(|e| e.into()) + } + _ => Err(DataFusionError::Execution( + "invalid shuffle block codec".to_owned(), + )), + } +} + /// A stream that yields no record batches which represent end of output. 
pub struct EmptyStream { /// Schema representing the data @@ -1648,18 +1700,44 @@ mod test { #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn write_ipc_zstd() { + fn roundtrip_ipc_zstd() { let batch = create_batch(8192); let mut output = vec![]; let mut cursor = Cursor::new(&mut output); - write_ipc_compressed( + let length = write_ipc_compressed( &batch, &mut cursor, &CompressionCodec::Zstd(1), &Time::default(), ) .unwrap(); - assert_eq!(40218, output.len()); + assert_eq!(40230, output.len()); + assert_eq!(40230, length); + + let ipc_without_length_prefix = &output[16..]; + let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap(); + assert_eq!(batch, batch2); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn roundtrip_ipc_lz4() { + let batch = create_batch(8192); + let mut output = vec![]; + let mut cursor = Cursor::new(&mut output); + let length = write_ipc_compressed( + &batch, + &mut cursor, + &CompressionCodec::Lz4Frame, + &Time::default(), + ) + .unwrap(); + assert_eq!(61756, output.len()); + assert_eq!(61756, length); + + let ipc_without_length_prefix = &output[16..]; + let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap(); + assert_eq!(batch, batch2); } #[test] diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 5cb2802da..08c16a802 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -85,6 +85,7 @@ message Limit { enum CompressionCodec { None = 0; Zstd = 1; + Lz4 = 2; } message ShuffleWriter { diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java index 108e1f2e1..aa200aecf 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java @@ -51,6 +51,7 @@ import org.apache.spark.shuffle.api.WritableByteChannelWrapper; import org.apache.spark.shuffle.comet.CometShuffleChecksumSupport; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.comet.shuffle.CometShuffleDependency; import org.apache.spark.sql.types.StructType; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.FileSegment; diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java index 7690e1d8b..0eb7cc91e 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java @@ -66,6 +66,7 @@ import org.apache.spark.shuffle.sort.SortShuffleManager; import org.apache.spark.shuffle.sort.UnsafeShuffleWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.comet.shuffle.CometShuffleDependency; import org.apache.spark.sql.types.StructType; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 
8bff6b5fb..b7d61fd48 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.MetadataColumnHelper import org.apache.spark.sql.comet._ -import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager} +import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, CometShuffleManager} +import org.apache.spark.sql.comet.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 083c0f2b5..e901058b2 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -139,4 +139,19 @@ class Native extends NativeBase { * the size of the array. */ @native def sortRowPartitionsNative(addr: Long, size: Long): Unit + + /** + * Decompress and decode a native shuffle block. + * @param shuffleBlock + * the encoded and compressed shuffle block. + * @param arrayAddrs + * the addresses of the Arrow Array structures that receive the decoded batch. + * @param schemaAddrs + * the addresses of the Arrow Schema structures that receive the decoded batch. + */ + @native def decodeShuffleBlock( + shuffleBlock: Array[Byte], + arrayAddrs: Array[Long], + schemaAddrs: Array[Long]): Long + } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index ccf218cf6..0a2016514 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.comet +import java.io.DataInputStream +import java.nio.channels.Channels import java.util.UUID import java.util.concurrent.{Future, TimeoutException, TimeUnit} @@ -26,13 +28,15 @@ import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal -import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext} +import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec +import org.apache.spark.io.CompressionCodec import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.comet.shuffle.ArrowReaderIterator import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} @@ -299,8 +303,24 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] partition.value.value.toIterator - .flatMap(CometExec.decodeBatches(_,
this.getClass.getSimpleName)) + .flatMap(decodeBatches(_, this.getClass.getSimpleName)) } + + /** + * Decodes the byte arrays back to ColumnarBatches. + */ + private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { + if (bytes.size == 0) { + return Iterator.empty + } + + // decompress with the Spark codec, not Comet's, so this is not compatible with Comet shuffle + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val cbbis = bytes.toInputStream() + val ins = new DataInputStream(codec.compressedInputStream(cbbis)) + new ArrowReaderIterator(Channels.newChannel(ins), source) + } + } class CometBatchPartition( diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala index f75af5076..cf3eb08e4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala @@ -24,7 +24,8 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} -import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.shuffle.CometShuffledBatchRDD import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode, UnsafeRowSerializer} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.vectorized.ColumnarBatch diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index 19586628a..f3ec11341 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -24,7 +24,8 @@ import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.shuffle.CometShuffledBatchRDD import org.apache.spark.sql.execution.{SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode, UnsafeRowSerializer} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.vectorized.ColumnarBatch diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 3a11b8b28..6c8ef01b9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -40,6 +40,7 @@ import
org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan} import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor +import org.apache.spark.sql.comet.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffledBatchRDD, CometShuffleDependency, ShuffleType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -557,6 +558,7 @@ class CometShuffleWriteProcessor( if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { case "zstd" => CompressionCodec.Zstd + case "lz4" => CompressionCodec.Lz4 case other => throw new UnsupportedOperationException(s"invalid codec: $other") } shuffleWriterBuilder.setCodec(codec) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index b2cc2c2ba..e24d532ec 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -29,11 +29,10 @@ import org.apache.spark.SparkConf import org.apache.spark.SparkEnv import org.apache.spark.TaskContext import org.apache.spark.internal.{config, Logging} -import org.apache.spark.internal.config.IO_COMPRESSION_CODEC -import org.apache.spark.io.CompressionCodec import org.apache.spark.shuffle._ import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.shuffle.sort.{BypassMergeSortShuffleHandle, SerializedShuffleHandle, SortShuffleManager, SortShuffleWriter} +import org.apache.spark.sql.comet.shuffle.{CometBlockStoreShuffleReader, CometShuffleDependency} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.collection.OpenHashSet @@ -241,18 +240,6 @@ object CometShuffleManager extends Logging { executorComponents } - lazy val compressionCodecForShuffling: CompressionCodec = { - val sparkConf = SparkEnv.get.conf - val codecName = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get(SQLConf.get) - - // only zstd compression is supported at the moment - if (codecName != "zstd") { - logWarning( - s"Overriding config ${IO_COMPRESSION_CODEC}=${codecName} in shuffling, force using zstd") - } - CompressionCodec.createCodec(sparkConf, "zstd") - } - def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { // We cannot bypass sorting if we need to do map-side aggregation. 
if (dep.mapSideCombine) { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 77188312e..c70f7464e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.comet -import java.io.{ByteArrayOutputStream, DataInputStream} -import java.nio.channels.Channels +import java.io.ByteArrayOutputStream import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.io.CompressionCodec +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder} @@ -34,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} -import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} @@ -78,18 +76,6 @@ abstract class CometExec extends CometPlan { // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec. override def outputPartitioning: Partitioning = originalPlan.outputPartitioning - /** - * Executes the Comet operator and returns the result as an iterator of ColumnarBatch. - */ - def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = { - val countsAndBytes = CometExec.getByteArrayRdd(this).collect() - val total = countsAndBytes.map(_._1).sum - val rows = countsAndBytes.iterator - .flatMap(countAndBytes => - CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName)) - (total, rows) - } - protected def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = { sparkPlan.children.foreach(setSubqueries(planId, _)) @@ -161,21 +147,6 @@ object CometExec { Utils.serializeBatches(iter) } } - - /** - * Decodes the byte arrays back to ColumnarBatchs and put them into buffer. 
- */ - def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { - if (bytes.size == 0) { - return Iterator.empty - } - - val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val cbbis = bytes.toInputStream() - val ins = new DataInputStream(codec.compressedInputStream(cbbis)) - - new ArrowReaderIterator(Channels.newChannel(ins), source) - } } /** diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ArrowReaderIterator.scala similarity index 97% rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala rename to spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ArrowReaderIterator.scala index 933e0b661..8b62963cf 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ArrowReaderIterator.scala @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.spark.sql.comet.execution.shuffle +package org.apache.spark.sql.comet.shuffle import java.nio.channels.ReadableByteChannel diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometBlockStoreShuffleReader.scala similarity index 83% rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala rename to spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometBlockStoreShuffleReader.scala index e026cbeb1..9680e5f70 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometBlockStoreShuffleReader.scala @@ -17,25 +17,17 @@ * under the License. */ -package org.apache.spark.sql.comet.execution.shuffle +package org.apache.spark.sql.comet.shuffle import java.io.InputStream -import org.apache.spark.InterruptibleIterator -import org.apache.spark.MapOutputTracker -import org.apache.spark.SparkEnv -import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager -import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.shuffle.ShuffleReader -import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.storage.BlockId -import org.apache.spark.storage.BlockManager -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.storage.ShuffleBlockFetcherIterator +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader, ShuffleReadMetricsReporter} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator /** @@ -88,7 +80,7 @@ class CometBlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - var currentReadIterator: ArrowReaderIterator = null + var currentReadIterator: NativeBatchDecoderIterator = null // Closes last read iterator after the task is finished. 
// We need to close read iterator during iterating input streams, @@ -100,18 +92,15 @@ class CometBlockStoreShuffleReader[K, C]( } } - val recordIter = fetchIterator - .flatMap { case (_, inputStream) => - IpcInputStreamIterator(inputStream, decompressingNeeded = true, context) - .flatMap { channel => - if (currentReadIterator != null) { - // Closes previous read iterator. - currentReadIterator.close() - } - currentReadIterator = new ArrowReaderIterator(channel, this.getClass.getSimpleName) - currentReadIterator.map((0, _)) // use 0 as key since it's not used - } - } + val recordIter: Iterator[(Int, ColumnarBatch)] = fetchIterator + .flatMap(blockIdAndStream => { + if (currentReadIterator != null) { + currentReadIterator.close() + } + currentReadIterator = NativeBatchDecoderIterator(blockIdAndStream._2, context) + currentReadIterator + }) + .map(b => (0, b)) // Update the context task metrics for each record read. val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometShuffleDependency.scala similarity index 97% rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala rename to spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometShuffleDependency.scala index 7b1d1f127..345457545 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometShuffleDependency.scala @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.spark.sql.comet.execution.shuffle +package org.apache.spark.sql.comet.shuffle import scala.reflect.ClassTag diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometShuffledRowRDD.scala similarity index 94% rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala rename to spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometShuffledRowRDD.scala index af78ed290..35b3919ae 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometShuffledRowRDD.scala @@ -17,12 +17,12 @@ * under the License. 
*/ -package org.apache.spark.sql.comet.execution.shuffle +package org.apache.spark.sql.comet.shuffle -import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.execution.{CoalescedMapperPartitionSpec, CoalescedPartitioner, CoalescedPartitionSpec, PartialMapperPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/NativeBatchDecoderIterator.scala new file mode 100644 index 000000000..3f16d1361 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/NativeBatchDecoderIterator.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shuffle + +import java.io.{EOFException, InputStream} +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.channels.{Channels, ReadableByteChannel} + +import org.apache.spark.TaskContext +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.Native +import org.apache.comet.vector.NativeUtil + +/** + * This iterator wraps a Spark input stream that reads shuffle blocks generated by the Comet + * native ShuffleWriterExec, calls native code to decompress and decode each shuffle block, and + * uses Arrow FFI to return the decoded Arrow record batches.
+ */ +case class NativeBatchDecoderIterator(var in: InputStream, taskContext: TaskContext) + extends Iterator[ColumnarBatch] { + private val SPARK_LZ4_MAGIC = Array[Byte](76, 90, 52, 66, 108, 111, 99, 107) // "LZ4Block" + private var nextBatch: Option[ColumnarBatch] = None + private var finished = false; + private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + private val native = new Native() + private val nativeUtil = new NativeUtil() + + if (taskContext != null) { + taskContext.addTaskCompletionListener[Unit](_ => { + close() + }) + } + + private val channel: ReadableByteChannel = if (in != null) { + Channels.newChannel(in) + } else { + null + } + + def hasNext(): Boolean = { + if (channel == null || finished) { + return false + } + if (nextBatch.isDefined) { + return true + } + + // read compressed batch size from header + try { + longBuf.clear() + while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {} + } catch { + case _: EOFException => + close() + return false + } + + // If we reach the end of the stream, we are done, or if we read partial length + // then the stream is corrupted. + if (longBuf.hasRemaining) { + if (longBuf.position() == 0) { + finished = true + close() + return false + } + throw new EOFException("Data corrupt: unexpected EOF while reading compressed ipc lengths") + } + + // make troubleshooting easier + if (longBuf.array().sameElements(SPARK_LZ4_MAGIC)) { + throw new IllegalStateException( + "Attempting to read Spark LZ4 stream with Comet shuffle block decoder") + } + + // get compressed length (including headers) + longBuf.flip() + val compressedLength = longBuf.getLong.toInt + + // read field count from header + longBuf.clear() + while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {} + longBuf.flip() + val fieldCount = longBuf.getLong.toInt + + // read body + val buffer = new Array[Byte](compressedLength - 8) + fillBuffer(in, buffer) + + // make native call to decode batch + nextBatch = nativeUtil.getNextBatch( + fieldCount, + (arrayAddrs, schemaAddrs) => { + native.decodeShuffleBlock(buffer, arrayAddrs, schemaAddrs) + }) + + true + } + + def next(): ColumnarBatch = { + if (nextBatch.isDefined) { + val ret = nextBatch.get + nextBatch = None + ret + } else { + throw new IllegalStateException() + } + } + + private def fillBuffer(in: InputStream, buffer: Array[Byte]): Unit = { + var bytesRead = 0 + while (bytesRead < buffer.length) { + val result = in.read(buffer, bytesRead, buffer.length - bytesRead) + if (result == -1) { + throw new EOFException(s"Expected ${buffer.length} bytes, only $bytesRead available") + } + bytesRead += result + } + } + + def close(): Unit = { + synchronized { + if (in != null) { + in.close() + in = null + } + finished = true + } + } + +} diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleUtils.scala similarity index 97% rename from common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala rename to spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleUtils.scala index 23b4a5ec2..3d713ddb9 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ShuffleUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleUtils.scala @@ -17,7 +17,7 @@ * under the License. 
*/ -package org.apache.spark.sql.comet.execution.shuffle +package org.apache.spark.sql.comet.shuffle import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala index 965b6851e..2f50263cc 100644 --- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala +++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala @@ -21,7 +21,8 @@ package org.apache.comet.shims import org.apache.spark.ShuffleDependency import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.shuffle.ShuffleType import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.types.{StructField, StructType} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala index 559e327b4..7a3f35154 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala @@ -22,7 +22,8 @@ package org.apache.comet.shims import org.apache.spark.ShuffleDependency import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.shuffle.ShuffleType import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.types.StructType diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 6130e4cd5..418cc7a38 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -27,7 +27,8 @@ import org.scalatest.Tag import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkConf} import org.apache.spark.sql.{CometTestBase, DataFrame, RandomDataGenerator, Row} -import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} +import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, CometShuffleManager} +import org.apache.spark.sql.comet.shuffle.CometShuffleDependency import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 102769537..aeaf89136 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -22,8 +22,6 @@ package org.apache.comet.exec import java.sql.Date import java.time.{Duration, Period} -import 
scala.collection.JavaConverters._ -import scala.collection.mutable import scala.util.Random import org.scalactic.source.Position @@ -36,7 +34,8 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, Cat import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Hex} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate} import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometSparkToColumnarExec, CometTakeOrderedAndProjectExec} -import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.shuffle.CometColumnarShuffle import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -462,37 +461,6 @@ class CometExecSuite extends CometTestBase { } } - test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") { - assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") - withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") { - withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") { - val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3") - - val nativeProject = find(df.queryExecution.executedPlan) { - case _: CometProjectExec => true - case _ => false - }.get.asInstanceOf[CometProjectExec] - - val (rows, batches) = nativeProject.executeColumnarCollectIterator() - assert(rows == 46) - - val column1 = mutable.ArrayBuffer.empty[Int] - val column2 = mutable.ArrayBuffer.empty[Int] - - batches.foreach(batch => { - batch.rowIterator().asScala.foreach { row => - assert(row.numFields == 2) - column1 += row.getInt(0) - column2 += row.getInt(1) - } - }) - - assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray) - assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray) - } - } - } - test("scalar subquery") { val dataTypes = Seq( diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 213ec7efe..130c96939 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -37,7 +37,8 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark._ import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER} import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec} -import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.execution.{ColumnarToRowExec, ExtendedMode, InputAdapter, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.internal._
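
For reference, the block layout that write_ipc_compressed and read_ipc_compressed agree on above is: an 8-byte little-endian block length (counting everything after the length field itself), an 8-byte little-endian field count, a 4-byte codec tag ("LZ4_", "ZSTD", or "NONE"), and then the Arrow IPC stream, compressed according to the tag. The following standalone Rust sketch (not part of this patch) round-trips one batch through that framing for the lz4 case; the helper names encode_block and decode_block are illustrative only, and the sketch assumes the arrow and lz4_flex crates already used by this diff.

use std::io::{Cursor, Seek, SeekFrom, Write};
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

// Write one block: [length][field count][codec tag][LZ4-framed Arrow IPC stream].
fn encode_block(batch: &RecordBatch) -> Vec<u8> {
    let mut out = Cursor::new(Vec::new());
    out.write_all(&[0u8; 8]).unwrap(); // length placeholder, back-filled below
    let field_count = batch.schema().fields().len() as u64;
    out.write_all(&field_count.to_le_bytes()).unwrap();
    out.write_all(b"LZ4_").unwrap(); // codec tag

    // encode the batch as an Arrow IPC stream, then wrap it in an LZ4 frame
    let schema = batch.schema();
    let mut ipc_writer = StreamWriter::try_new(Vec::new(), schema.as_ref()).unwrap();
    ipc_writer.write(batch).unwrap();
    ipc_writer.finish().unwrap();
    let ipc_bytes = ipc_writer.into_inner().unwrap();
    let mut encoder = lz4_flex::frame::FrameEncoder::new(&mut out);
    std::io::copy(&mut Cursor::new(ipc_bytes), &mut encoder).unwrap();
    encoder.finish().unwrap();

    // back-fill the length: everything after the 8-byte length field itself
    let end = out.stream_position().unwrap();
    out.seek(SeekFrom::Start(0)).unwrap();
    out.write_all(&(end - 8).to_le_bytes()).unwrap();
    out.into_inner()
}

// Read one block back, skipping the length and field-count headers the same
// way the JVM-side reader does before handing the bytes to the native decoder.
fn decode_block(block: &[u8]) -> RecordBatch {
    let body = &block[16..];
    match &body[0..4] {
        b"LZ4_" => {
            let decoder = lz4_flex::frame::FrameDecoder::new(&body[4..]);
            let mut reader = StreamReader::try_new(decoder, None).unwrap();
            reader.next().unwrap().unwrap()
        }
        other => panic!("unsupported codec tag: {:?}", other),
    }
}

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema, vec![col]).unwrap();
    let block = encode_block(&batch);
    assert_eq!(batch, decode_block(&block));
}

Note the design difference between the two codec branches in the patch itself: the zstd branch streams the Arrow IPC writer directly into a zstd::Encoder wrapping the output, while the lz4 branch buffers the uncompressed IPC bytes first and then copies them through an lz4_flex FrameEncoder, which is the shape the sketch above follows. The LZ4 frame format produced by lz4_flex is also not Spark's LZ4Block format, which is why NativeBatchDecoderIterator checks for the "LZ4Block" magic and fails fast with a clear error.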
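
On the JVM side, NativeBatchDecoderIterator consumes a stream of such blocks per shuffle partition: it reads the 8-byte length, then the 8-byte field count, then hands the remaining (length - 8) bytes (codec tag plus compressed IPC) to Native.decodeShuffleBlock via Arrow FFI. The short Rust sketch below is again illustrative rather than part of the patch (read_blocks is a hypothetical helper); it shows the same framing loop over a raw byte stream.

use std::io::{ErrorKind, Read};

// Iterate the per-partition stream layout consumed by NativeBatchDecoderIterator:
// each block is [8-byte LE length][8-byte LE field count][codec tag + compressed IPC],
// where the length counts everything after the length field itself.
fn read_blocks<R: Read>(mut input: R) -> std::io::Result<Vec<(u64, Vec<u8>)>> {
    let mut blocks = Vec::new();
    loop {
        // Read the 8-byte block length. A clean EOF here means no more blocks.
        // (The real iterator also distinguishes a partial header read and
        // reports it as stream corruption.)
        let mut len_buf = [0u8; 8];
        match input.read_exact(&mut len_buf) {
            Ok(()) => {}
            Err(e) if e.kind() == ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(e),
        }
        let block_len = u64::from_le_bytes(len_buf);

        // Read the 8-byte field count (number of columns in the encoded batch);
        // the JVM reader uses this to allocate the Arrow FFI array/schema addresses.
        let mut count_buf = [0u8; 8];
        input.read_exact(&mut count_buf)?;
        let field_count = u64::from_le_bytes(count_buf);

        // The remaining body (codec tag + compressed Arrow IPC stream) is what
        // gets passed to the native decoder as a single byte array.
        let mut body = vec![0u8; (block_len - 8) as usize];
        input.read_exact(&mut body)?;
        blocks.push((field_count, body));
    }
    Ok(blocks)
}

fn main() -> std::io::Result<()> {
    // An empty stream simply yields no blocks.
    let blocks = read_blocks(std::io::empty())?;
    assert!(blocks.is_empty());
    Ok(())
}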