diff --git a/bin/pyspark b/bin/pyspark
index d3b512eeb1209..dd286277c1fc1 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
export PYTHONHASHSEED=0
- exec "$PYSPARK_DRIVER_PYTHON" -m "$1"
+ exec "$PYSPARK_DRIVER_PYTHON" -m "$@"
exit
fi
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index c1325318d52fa..1a6515be51cff 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar
api-asn1-api-1.0.0-M20.jar
api-util-1.0.0-M20.jar
arpack_combined_all-0.1.jar
+arrow-format-0.4.0.jar
+arrow-memory-0.4.0.jar
+arrow-vector-0.4.0.jar
avro-1.7.7.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
@@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar
datanucleus-rdbms-3.2.9.jar
derby-10.12.1.1.jar
eigenbase-properties-1.1.5.jar
+flatbuffers-1.2.0-3f79e055.jar
gson-2.2.4.jar
guava-14.0.1.jar
guice-3.0.jar
@@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar
hk2-api-2.4.0-b34.jar
hk2-locator-2.4.0-b34.jar
hk2-utils-2.4.0-b34.jar
+hppc-0.7.1.jar
htrace-core-3.0.4.jar
httpclient-4.5.2.jar
httpcore-4.4.4.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index ac5abd21807b6..09e5a4288ca50 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar
api-asn1-api-1.0.0-M20.jar
api-util-1.0.0-M20.jar
arpack_combined_all-0.1.jar
+arrow-format-0.4.0.jar
+arrow-memory-0.4.0.jar
+arrow-vector-0.4.0.jar
avro-1.7.7.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
@@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar
datanucleus-rdbms-3.2.9.jar
derby-10.12.1.1.jar
eigenbase-properties-1.1.5.jar
+flatbuffers-1.2.0-3f79e055.jar
gson-2.2.4.jar
guava-14.0.1.jar
guice-3.0.jar
@@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar
hk2-api-2.4.0-b34.jar
hk2-locator-2.4.0-b34.jar
hk2-utils-2.4.0-b34.jar
+hppc-0.7.1.jar
htrace-core-3.1.0-incubating.jar
httpclient-4.5.2.jar
httpcore-4.4.4.jar
diff --git a/pom.xml b/pom.xml
index 5f524079495c0..f124ba45007b7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -181,6 +181,7 @@
2.6
1.8
1.0.0
+    <arrow.version>0.4.0</arrow.version>
${java.home}
@@ -1878,6 +1879,25 @@
        <artifactId>paranamer</artifactId>
        <version>${paranamer.version}</version>
      </dependency>
+      <dependency>
+        <groupId>org.apache.arrow</groupId>
+        <artifactId>arrow-vector</artifactId>
+        <version>${arrow.version}</version>
+        <exclusions>
+          <exclusion>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-annotations</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>io.netty</groupId>
+            <artifactId>netty-handler</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index ea5e00e9eeef5..d5c2a7518b18f 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -182,6 +182,23 @@ def loads(self, obj):
raise NotImplementedError
+class ArrowSerializer(FramedSerializer):
+ """
+ Deserializes an Arrow file-format stream; serialization (dumps) is not implemented.
+ """
+
+ def dumps(self, obj):
+ raise NotImplementedError
+
+ def loads(self, obj):
+ import pyarrow as pa
+ reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
+ return reader.read_all()
+
+ def __repr__(self):
+ return "ArrowSerializer"
+
+
class BatchedSerializer(Serializer):
"""
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 27a6dad8917d3..944739bcd2078 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -29,7 +29,8 @@
from pyspark import copy_func, since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
+from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
+ UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string
@@ -1710,7 +1711,8 @@ def toDF(self, *cols):
@since(1.3)
def toPandas(self):
- """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
+ """
+ Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
This is only available if Pandas is installed and available.
@@ -1723,18 +1725,42 @@ def toPandas(self):
1 5 Bob
"""
import pandas as pd
+ if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true":
+ try:
+ import pyarrow
+ tables = self._collectAsArrow()
+ if tables:
+ table = pyarrow.concat_tables(tables)
+ return table.to_pandas()
+ else:
+ return pd.DataFrame.from_records([], columns=self.columns)
+ except ImportError as e:
+ msg = "note: pyarrow must be installed and available on calling Python process " \
+ "if using spark.sql.execution.arrow.enable=true"
+ raise ImportError("%s\n%s" % (e.message, msg))
+ else:
+ dtype = {}
+ for field in self.schema:
+ pandas_type = _to_corrected_pandas_type(field.dataType)
+ if pandas_type is not None:
+ dtype[field.name] = pandas_type
- dtype = {}
- for field in self.schema:
- pandas_type = _to_corrected_pandas_type(field.dataType)
- if pandas_type is not None:
- dtype[field.name] = pandas_type
+ pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
- pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+ for f, t in dtype.items():
+ pdf[f] = pdf[f].astype(t, copy=False)
+ return pdf
- for f, t in dtype.items():
- pdf[f] = pdf[f].astype(t, copy=False)
- return pdf
+ def _collectAsArrow(self):
+ """
+ Returns all records as a list of deserialized ArrowPayloads; pyarrow must be installed
+ and available on the calling Python process.
+
+ .. note:: Experimental.
+ """
+ with SCCallSiteSync(self._sc) as css:
+ port = self._jdf.collectAsArrowToPython()
+ return list(_load_from_socket(port, ArrowSerializer()))
##########################################################################################
# Pandas compatibility
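
A hedged usage sketch of the new toPandas() path (assumes an existing SparkSession named spark and pyarrow installed on the driver; the conf name comes from the patch itself):

    df = spark.createDataFrame([("a", 1), ("b", 2)], ["k", "v"])

    spark.conf.set("spark.sql.execution.arrow.enable", "true")
    pdf_arrow = df.toPandas()   # collected as Arrow payloads and decoded with pyarrow

    spark.conf.set("spark.sql.execution.arrow.enable", "false")
    pdf_plain = df.toPandas()   # original collect()-and-pickle path

Both calls should return equal pandas DataFrames; the Arrow path only changes how rows are transferred and converted.
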
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 9db2f40474f70..bd8477e35f37a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -58,12 +58,21 @@
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
-from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
+from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
+_have_arrow = False
+try:
+ import pyarrow
+ _have_arrow = True
+except ImportError:
+ # No Arrow, but that's okay, we'll skip those tests
+ pass
+
+
class UTCOffsetTimezone(datetime.tzinfo):
"""
Specifies timezone in UTC offset
@@ -2843,6 +2852,73 @@ def __init__(self, **kwargs):
_make_type_verifier(data_type, nullable=False)(obj)
+@unittest.skipIf(not _have_arrow, "Arrow not installed")
+class ArrowTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.spark = SparkSession(cls.sc)
+ cls.spark.conf.set("spark.sql.execution.arrow.enable", "true")
+ cls.schema = StructType([
+ StructField("1_str_t", StringType(), True),
+ StructField("2_int_t", IntegerType(), True),
+ StructField("3_long_t", LongType(), True),
+ StructField("4_float_t", FloatType(), True),
+ StructField("5_double_t", DoubleType(), True)])
+ cls.data = [("a", 1, 10, 0.2, 2.0),
+ ("b", 2, 20, 0.4, 4.0),
+ ("c", 3, 30, 0.8, 6.0)]
+
+ def assertFramesEqual(self, df_with_arrow, df_without):
+ msg = ("DataFrame from Arrow is not equal" +
+ ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
+ ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
+ self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
+
+ def test_unsupported_datatype(self):
+ schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)])
+ df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema)
+ with QuietTest(self.sc):
+ self.assertRaises(Exception, lambda: df.toPandas())
+
+ def test_null_conversion(self):
+ df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
+ self.data)
+ pdf = df_null.toPandas()
+ null_counts = pdf.isnull().sum().tolist()
+ self.assertTrue(all([c == 1 for c in null_counts]))
+
+ def test_toPandas_arrow_toggle(self):
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+ self.spark.conf.set("spark.sql.execution.arrow.enable", "false")
+ pdf = df.toPandas()
+ self.spark.conf.set("spark.sql.execution.arrow.enable", "true")
+ pdf_arrow = df.toPandas()
+ self.assertFramesEqual(pdf_arrow, pdf)
+
+ def test_pandas_round_trip(self):
+ import pandas as pd
+ import numpy as np
+ data_dict = {}
+ for j, name in enumerate(self.schema.names):
+ data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+ # need to convert these to numpy types first
+ data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
+ data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
+ pdf = pd.DataFrame(data=data_dict)
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+ pdf_arrow = df.toPandas()
+ self.assertFramesEqual(pdf_arrow, pdf)
+
+ def test_filtered_frame(self):
+ df = self.spark.range(3).toDF("i")
+ pdf = df.filter("i < 0").toPandas()
+ self.assertEqual(len(pdf.columns), 1)
+ self.assertEqual(pdf.columns[0], "i")
+ self.assertTrue(pdf.empty)
+
+
if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 25152f3e32d6b..643587a6eb09d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -855,6 +855,24 @@ object SQLConf {
.intConf
.createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+ val ARROW_EXECUTION_ENABLE =
+ buildConf("spark.sql.execution.arrow.enable")
+ .internal()
+ .doc("Make use of Apache Arrow for columnar data transfers. Currently available " +
+ "for use with pyspark.sql.DataFrame.toPandas with the following data types: " +
+ "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " +
+ "LongType, ShortType")
+ .booleanConf
+ .createWithDefault(false)
+
+ val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH =
+ buildConf("spark.sql.execution.arrow.maxRecordsPerBatch")
+ .internal()
+ .doc("When using Apache Arrow, limit the maximum number of records that can be written " +
+ "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.")
+ .intConf
+ .createWithDefault(10000)
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -1115,6 +1133,10 @@ class SQLConf extends Serializable with Logging {
def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO)
+ def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
+
+ def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
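
To illustrate how the two new confs interact, a small sketch (again assuming a SparkSession spark and pyarrow on the driver): with ten rows in two partitions and maxRecordsPerBatch set to 3, each partition is written out as record batches of at most three rows before the driver concatenates them.

    spark.conf.set("spark.sql.execution.arrow.enable", "true")
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "3")

    # LongType column, which is on the supported-type list above.
    pdf = spark.range(0, 10, 1, numPartitions=2).toPandas()

    spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
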
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 1bc34a6b069d9..661c31ded7148 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -103,6 +103,10 @@
      <artifactId>jackson-databind</artifactId>
      <version>${fasterxml.jackson.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-vector</artifactId>
+    </dependency>
    <dependency>
      <groupId>org.apache.xbean</groupId>
      <artifactId>xbean-asm5-shaded</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index dfb51192c69bc..a7773831df075 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -46,6 +46,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.python.EvaluatePython
@@ -2907,6 +2908,16 @@ class Dataset[T] private[sql](
}
}
+ /**
+ * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+ */
+ private[sql] def collectAsArrowToPython(): Int = {
+ withNewExecutionId {
+ val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
+ PythonRDD.serveIterator(iter, "serve-Arrow")
+ }
+ }
+
private[sql] def toPythonIterator(): Int = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
@@ -2988,4 +2999,13 @@ class Dataset[T] private[sql](
Dataset(sparkSession, logicalPlan)
}
}
+
+ /** Convert to an RDD of ArrowPayload byte arrays */
+ private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+ val schemaCaptured = this.schema
+ val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+ queryExecution.toRdd.mapPartitionsInternal { iter =>
+ ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch)
+ }
+ }
}
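
collectAsArrowToPython serves one byte array per ArrowPayload, each of which is a self-contained Arrow file-format stream holding a single record batch. A sketch of the driver-side decoding this enables, mirroring ArrowSerializer plus the concat step in toPandas (payloads is a hypothetical list of such byte strings; pyarrow assumed installed):

    import pyarrow as pa

    def payloads_to_pandas(payloads):
        # Each payload decodes independently; the resulting tables share a schema
        # and can simply be concatenated.
        tables = [pa.RecordBatchFileReader(pa.BufferReader(p)).read_all()
                  for p in payloads]
        return pa.concat_tables(tables).to_pandas()
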
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
new file mode 100644
index 0000000000000..6af5c73422377
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -0,0 +1,429 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.arrow
+
+import java.io.ByteArrayOutputStream
+import java.nio.channels.Channels
+
+import scala.collection.JavaConverters._
+
+import io.netty.buffer.ArrowBuf
+import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.BaseValueVector.BaseMutator
+import org.apache.arrow.vector.file._
+import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
+import org.apache.arrow.vector.types.FloatingPointPrecision
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+
+/**
+ * Store Arrow data in a form that can be serialized by Spark and served to a Python process.
+ */
+private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
+
+ /**
+ * Convert the ArrowPayload to an ArrowRecordBatch.
+ */
+ def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = {
+ ArrowConverters.byteArrayToBatch(payload, allocator)
+ }
+
+ /**
+ * Get the ArrowPayload as a type that can be served to Python.
+ */
+ def asPythonSerializable: Array[Byte] = payload
+}
+
+private[sql] object ArrowPayload {
+
+ /**
+ * Create an ArrowPayload from an ArrowRecordBatch and Spark schema.
+ */
+ def apply(
+ batch: ArrowRecordBatch,
+ schema: StructType,
+ allocator: BufferAllocator): ArrowPayload = {
+ new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator))
+ }
+}
+
+private[sql] object ArrowConverters {
+
+ /**
+ * Map a Spark DataType to ArrowType.
+ */
+ private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = {
+ dataType match {
+ case BooleanType => ArrowType.Bool.INSTANCE
+ case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true)
+ case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true)
+ case LongType => new ArrowType.Int(8 * LongType.defaultSize, true)
+ case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+ case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+ case ByteType => new ArrowType.Int(8, true)
+ case StringType => ArrowType.Utf8.INSTANCE
+ case BinaryType => ArrowType.Binary.INSTANCE
+ case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
+ }
+ }
+
+ /**
+ * Convert a Spark Dataset schema to Arrow schema.
+ */
+ private[arrow] def schemaToArrowSchema(schema: StructType): Schema = {
+ val arrowFields = schema.fields.map { f =>
+ new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava)
+ }
+ new Schema(arrowFields.toList.asJava)
+ }
+
+ /**
+ * Maps an Iterator of InternalRow to an Iterator of ArrowPayload. Limit the ArrowRecordBatch
+ * size in each ArrowPayload by setting maxRecordsPerBatch, or use 0 to fully consume rowIter.
+ */
+ private[sql] def toPayloadIterator(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ maxRecordsPerBatch: Int): Iterator[ArrowPayload] = {
+ new Iterator[ArrowPayload] {
+ private val _allocator = new RootAllocator(Long.MaxValue)
+ private var _nextPayload = if (rowIter.nonEmpty) convert() else null
+
+ override def hasNext: Boolean = _nextPayload != null
+
+ override def next(): ArrowPayload = {
+ val obj = _nextPayload
+ if (hasNext) {
+ if (rowIter.hasNext) {
+ _nextPayload = convert()
+ } else {
+ _allocator.close()
+ _nextPayload = null
+ }
+ }
+ obj
+ }
+
+ private def convert(): ArrowPayload = {
+ val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch)
+ ArrowPayload(batch, schema, _allocator)
+ }
+ }
+ }
+
+ /**
+ * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed
+ * or the number of records in the batch equals maxRecordsPerBatch. If maxRecordsPerBatch is
+ * 0 or negative, then rowIter will be fully consumed.
+ */
+ private def internalRowIterToArrowBatch(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ allocator: BufferAllocator,
+ maxRecordsPerBatch: Int = 0): ArrowRecordBatch = {
+
+ val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) =>
+ ColumnWriter(field.dataType, ordinal, allocator).init()
+ }
+
+ val writerLength = columnWriters.length
+ var recordsInBatch = 0
+ while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) {
+ val row = rowIter.next()
+ var i = 0
+ while (i < writerLength) {
+ columnWriters(i).write(row)
+ i += 1
+ }
+ recordsInBatch += 1
+ }
+
+ val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip
+ val buffers = bufferArrays.flatten
+
+ val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0
+ val recordBatch = new ArrowRecordBatch(rowLength,
+ fieldNodes.toList.asJava, buffers.toList.asJava)
+
+ buffers.foreach(_.release())
+ recordBatch
+ }
+
+ /**
+ * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed,
+ * the batch can no longer be used.
+ */
+ private[arrow] def batchToByteArray(
+ batch: ArrowRecordBatch,
+ schema: StructType,
+ allocator: BufferAllocator): Array[Byte] = {
+ val arrowSchema = ArrowConverters.schemaToArrowSchema(schema)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ val out = new ByteArrayOutputStream()
+ val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
+
+ // Write a batch to byte stream, ensure the batch, allocator and writer are closed
+ Utils.tryWithSafeFinally {
+ val loader = new VectorLoader(root)
+ loader.load(batch)
+ writer.writeBatch() // writeBatch can throw IOException
+ } {
+ batch.close()
+ root.close()
+ writer.close()
+ }
+ out.toByteArray
+ }
+
+ /**
+ * Convert a byte array to an ArrowRecordBatch.
+ */
+ private[arrow] def byteArrayToBatch(
+ batchBytes: Array[Byte],
+ allocator: BufferAllocator): ArrowRecordBatch = {
+ val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
+ val reader = new ArrowFileReader(in, allocator)
+
+ // Read a batch from a byte stream, ensure the reader is closed
+ Utils.tryWithSafeFinally {
+ val root = reader.getVectorSchemaRoot // throws IOException
+ val unloader = new VectorUnloader(root)
+ reader.loadNextBatch() // throws IOException
+ unloader.getRecordBatch
+ } {
+ reader.close()
+ }
+ }
+}
+
+/**
+ * Interface for writing InternalRows to Arrow Buffers.
+ */
+private[arrow] trait ColumnWriter {
+ def init(): this.type
+ def write(row: InternalRow): Unit
+
+ /**
+ * Clear the column writer and return the ArrowFieldNode and ArrowBuf.
+ * This should be called only once after all the data is written.
+ */
+ def finish(): (ArrowFieldNode, Array[ArrowBuf])
+}
+
+/**
+ * Base class for flat Arrow column writers, i.e., columns without children.
+ */
+private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int)
+ extends ColumnWriter {
+
+ def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype)
+
+ def valueVector: BaseDataValueVector
+ def valueMutator: BaseMutator
+
+ def setNull(): Unit
+ def setValue(row: InternalRow): Unit
+
+ protected var count = 0
+ protected var nullCount = 0
+
+ override def init(): this.type = {
+ valueVector.allocateNew()
+ this
+ }
+
+ override def write(row: InternalRow): Unit = {
+ if (row.isNullAt(ordinal)) {
+ setNull()
+ nullCount += 1
+ } else {
+ setValue(row)
+ }
+ count += 1
+ }
+
+ override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = {
+ valueMutator.setValueCount(count)
+ val fieldNode = new ArrowFieldNode(count, nullCount)
+ val valueBuffers = valueVector.getBuffers(true)
+ (fieldNode, valueBuffers)
+ }
+}
+
+private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableBitVector
+ = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit
= valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0)
+}
+
+private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableSmallIntVector
+ = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator)
+ override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit
+ = valueMutator.setSafe(count, row.getShort(ordinal))
+}
+
+private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableIntVector
+ = new NullableIntVector("IntValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit
+ = valueMutator.setSafe(count, row.getInt(ordinal))
+}
+
+private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableBigIntVector
+ = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit
+ = valueMutator.setSafe(count, row.getLong(ordinal))
+}
+
+private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableFloat4Vector
+ = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit
+ = valueMutator.setSafe(count, row.getFloat(ordinal))
+}
+
+private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableFloat8Vector
+ = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit
+ = valueMutator.setSafe(count, row.getDouble(ordinal))
+}
+
+private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableUInt1Vector
+ = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit
+ = valueMutator.setSafe(count, row.getByte(ordinal))
+}
+
+private[arrow] class UTF8StringColumnWriter(
+ dtype: ArrowType,
+ ordinal: Int,
+ allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableVarCharVector
+ = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit = {
+ val str = row.getUTF8String(ordinal)
+ valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes)
+ }
+}
+
+private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableVarBinaryVector
+ = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit = {
+ val bytes = row.getBinary(ordinal)
+ valueMutator.setSafe(count, bytes, 0, bytes.length)
+ }
+}
+
+private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableDateDayVector
+ = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit = {
+ valueMutator.setSafe(count, row.getInt(ordinal))
+ }
+}
+
+private[arrow] class TimeStampColumnWriter(
+ dtype: ArrowType,
+ ordinal: Int,
+ allocator: BufferAllocator)
+ extends PrimitiveColumnWriter(ordinal) {
+ override val valueVector: NullableTimeStampMicroVector
+ = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator)
+ override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator
+
+ override def setNull(): Unit = valueMutator.setNull(count)
+ override def setValue(row: InternalRow): Unit = {
+ valueMutator.setSafe(count, row.getLong(ordinal))
+ }
+}
+
+private[arrow] object ColumnWriter {
+
+ /**
+ * Create an Arrow ColumnWriter given the Spark DataType and the column ordinal within the row.
+ */
+ def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = {
+ val dtype = ArrowConverters.sparkTypeToArrowType(dataType)
+ dataType match {
+ case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator)
+ case ShortType => new ShortColumnWriter(dtype, ordinal, allocator)
+ case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator)
+ case LongType => new LongColumnWriter(dtype, ordinal, allocator)
+ case FloatType => new FloatColumnWriter(dtype, ordinal, allocator)
+ case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator)
+ case ByteType => new ByteColumnWriter(dtype, ordinal, allocator)
+ case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator)
+ case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator)
+ case DateType => new DateColumnWriter(dtype, ordinal, allocator)
+ case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator)
+ case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
+ }
+ }
+}
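
For reference, the Spark-to-Arrow type mapping implemented by sparkTypeToArrowType above, restated with pyarrow type factories (a sketch for orientation only; the Scala code is what actually runs on the executors):

    import pyarrow as pa

    # Signed integers are sized as 8 * defaultSize bits, matching the Scala mapping.
    SPARK_TO_ARROW = {
        "boolean": pa.bool_(),
        "byte": pa.int8(),
        "short": pa.int16(),
        "integer": pa.int32(),
        "long": pa.int64(),
        "float": pa.float32(),    # FloatingPointPrecision.SINGLE
        "double": pa.float64(),   # FloatingPointPrecision.DOUBLE
        "string": pa.string(),    # ArrowType.Utf8
        "binary": pa.binary(),    # ArrowType.Binary
    }
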
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
new file mode 100644
index 0000000000000..159328cc0d958
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -0,0 +1,1222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.arrow
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+import java.text.SimpleDateFormat
+import java.util.Locale
+
+import com.google.common.io.Files
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
+import org.apache.arrow.vector.file.json.JsonFileReader
+import org.apache.arrow.vector.util.Validator
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+
+class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
+ import testImplicits._
+
+ private var tempDataPath: String = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ tempDataPath = Utils.createTempDir(namePrefix = "arrow").getAbsolutePath
+ }
+
+ test("collect to arrow record batch") {
+ val indexData = (1 to 6).toDF("i")
+ val arrowPayloads = indexData.toArrowPayload.collect()
+ assert(arrowPayloads.nonEmpty)
+ assert(arrowPayloads.length == indexData.rdd.getNumPartitions)
+ val allocator = new RootAllocator(Long.MaxValue)
+ val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
+ val rowCount = arrowRecordBatches.map(_.getLength).sum
+ assert(rowCount === indexData.count())
+ arrowRecordBatches.foreach(batch => assert(batch.getNodes.size() > 0))
+ arrowRecordBatches.foreach(_.close())
+ allocator.close()
+ }
+
+ test("short conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_s",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 16
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 16
+ | } ]
+ | }
+ | }, {
+ | "name" : "b_s",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 16
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 16
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "a_s",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, -1, 2, -2, 32767, -32768 ]
+ | }, {
+ | "name" : "b_s",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ],
+ | "DATA" : [ 1, 0, 0, -2, 0, -32768 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val a_s = List[Short](1, -1, 2, -2, 32767, -32768)
+ val b_s = List[Option[Short]](Some(1), None, None, Some(-2), None, Some(-32768))
+ val df = a_s.zip(b_s).toDF("a_s", "b_s")
+
+ collectAndValidate(df, json, "integer-16bit.json")
+ }
+
+ test("int conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_i",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "b_i",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "a_i",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ]
+ | }, {
+ | "name" : "b_i",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ],
+ | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648)
+ val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648))
+ val df = a_i.zip(b_i).toDF("a_i", "b_i")
+
+ collectAndValidate(df, json, "integer-32bit.json")
+ }
+
+ test("long conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_l",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 64
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 64
+ | } ]
+ | }
+ | }, {
+ | "name" : "b_l",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 64
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 64
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "a_l",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, -1, 2, -2, 9223372036854775807, -9223372036854775808 ]
+ | }, {
+ | "name" : "b_l",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ],
+ | "DATA" : [ 1, 0, 0, -2, 0, -9223372036854775808 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val a_l = List[Long](1, -1, 2, -2, 9223372036854775807L, -9223372036854775808L)
+ val b_l = List[Option[Long]](Some(1), None, None, Some(-2), None, Some(-9223372036854775808L))
+ val df = a_l.zip(b_l).toDF("a_l", "b_l")
+
+ collectAndValidate(df, json, "integer-64bit.json")
+ }
+
+ test("float conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_f",
+ | "type" : {
+ | "name" : "floatingpoint",
+ | "precision" : "SINGLE"
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "b_f",
+ | "type" : {
+ | "name" : "floatingpoint",
+ | "precision" : "SINGLE"
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "a_f",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0 ]
+ | }, {
+ | "name" : "b_f",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ],
+ | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val a_f = List(1.0f, 2.0f, 0.01f, 200.0f, 0.0001f, 20000.0f)
+ val b_f = List[Option[Float]](Some(1.1f), None, None, Some(2.2f), None, Some(3.3f))
+ val df = a_f.zip(b_f).toDF("a_f", "b_f")
+
+ collectAndValidate(df, json, "floating_point-single_precision.json")
+ }
+
+ test("double conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_d",
+ | "type" : {
+ | "name" : "floatingpoint",
+ | "precision" : "DOUBLE"
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 64
+ | } ]
+ | }
+ | }, {
+ | "name" : "b_d",
+ | "type" : {
+ | "name" : "floatingpoint",
+ | "precision" : "DOUBLE"
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 64
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "a_d",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1.0, 2.0, 0.01, 200.0, 1.0E-4, 20000.0 ]
+ | }, {
+ | "name" : "b_d",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ],
+ | "DATA" : [ 1.1, 0.0, 0.0, 2.2, 0.0, 3.3 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0)
+ val b_d = List[Option[Double]](Some(1.1), None, None, Some(2.2), None, Some(3.3))
+ val df = a_d.zip(b_d).toDF("a_d", "b_d")
+
+ collectAndValidate(df, json, "floating_point-double_precision.json")
+ }
+
+ test("index conversion") {
+ val data = List[Int](1, 2, 3, 4, 5, 6)
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "i",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "i",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, 2, 3, 4, 5, 6 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+ val df = data.toDF("i")
+
+ collectAndValidate(df, json, "indexData-ints.json")
+ }
+
+ test("mixed numeric type conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 16
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 16
+ | } ]
+ | }
+ | }, {
+ | "name" : "b",
+ | "type" : {
+ | "name" : "floatingpoint",
+ | "precision" : "SINGLE"
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "c",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "d",
+ | "type" : {
+ | "name" : "floatingpoint",
+ | "precision" : "DOUBLE"
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 64
+ | } ]
+ | }
+ | }, {
+ | "name" : "e",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 64
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 64
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "a",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, 2, 3, 4, 5, 6 ]
+ | }, {
+ | "name" : "b",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ]
+ | }, {
+ | "name" : "c",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, 2, 3, 4, 5, 6 ]
+ | }, {
+ | "name" : "d",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ]
+ | }, {
+ | "name" : "e",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, 2, 3, 4, 5, 6 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val data = List(1, 2, 3, 4, 5, 6)
+ val data_tuples = for (d <- data) yield {
+ (d.toShort, d.toFloat, d.toInt, d.toDouble, d.toLong)
+ }
+ val df = data_tuples.toDF("a", "b", "c", "d", "e")
+
+ collectAndValidate(df, json, "mixed_numeric_types.json")
+ }
+
+ test("string type conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "upper_case",
+ | "type" : {
+ | "name" : "utf8"
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 8
+ | } ]
+ | }
+ | }, {
+ | "name" : "lower_case",
+ | "type" : {
+ | "name" : "utf8"
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 8
+ | } ]
+ | }
+ | }, {
+ | "name" : "null_str",
+ | "type" : {
+ | "name" : "utf8"
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 8
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 3,
+ | "columns" : [ {
+ | "name" : "upper_case",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "OFFSET" : [ 0, 1, 2, 3 ],
+ | "DATA" : [ "A", "B", "C" ]
+ | }, {
+ | "name" : "lower_case",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "OFFSET" : [ 0, 1, 2, 3 ],
+ | "DATA" : [ "a", "b", "c" ]
+ | }, {
+ | "name" : "null_str",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 0 ],
+ | "OFFSET" : [ 0, 2, 5, 5 ],
+ | "DATA" : [ "ab", "CDE", "" ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val upperCase = Seq("A", "B", "C")
+ val lowerCase = Seq("a", "b", "c")
+ val nullStr = Seq("ab", "CDE", null)
+ val df = (upperCase, lowerCase, nullStr).zipped.toList
+ .toDF("upper_case", "lower_case", "null_str")
+
+ collectAndValidate(df, json, "stringData.json")
+ }
+
+ test("boolean type conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_bool",
+ | "type" : {
+ | "name" : "bool"
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 1
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 4,
+ | "columns" : [ {
+ | "name" : "a_bool",
+ | "count" : 4,
+ | "VALIDITY" : [ 1, 1, 1, 1 ],
+ | "DATA" : [ true, true, false, true ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+ val df = Seq(true, true, false, true).toDF("a_bool")
+ collectAndValidate(df, json, "boolData.json")
+ }
+
+ test("byte type conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_byte",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 8
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 8
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 4,
+ | "columns" : [ {
+ | "name" : "a_byte",
+ | "count" : 4,
+ | "VALIDITY" : [ 1, 1, 1, 1 ],
+ | "DATA" : [ 1, -1, 64, 127 ]
+ | } ]
+ | } ]
+ |}
+ |
+ """.stripMargin
+ val df = List[Byte](1.toByte, (-1).toByte, 64.toByte, Byte.MaxValue).toDF("a_byte")
+ collectAndValidate(df, json, "byteData.json")
+ }
+
+ test("binary type conversion") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_binary",
+ | "type" : {
+ | "name" : "binary"
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "OFFSET",
+ | "typeBitWidth" : 32
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 8
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 3,
+ | "columns" : [ {
+ | "name" : "a_binary",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "OFFSET" : [ 0, 3, 4, 6 ],
+ | "DATA" : [ "616263", "64", "6566" ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val data = Seq("abc", "d", "ef")
+ val rdd = sparkContext.parallelize(data.map(s => Row(s.getBytes("utf-8"))))
+ val df = spark.createDataFrame(rdd, StructType(Seq(StructField("a_binary", BinaryType))))
+
+ collectAndValidate(df, json, "binaryData.json")
+ }
+
+ test("floating-point NaN") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "NaN_f",
+ | "type" : {
+ | "name" : "floatingpoint",
+ | "precision" : "SINGLE"
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "NaN_d",
+ | "type" : {
+ | "name" : "floatingpoint",
+ | "precision" : "DOUBLE"
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 64
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 2,
+ | "columns" : [ {
+ | "name" : "NaN_f",
+ | "count" : 2,
+ | "VALIDITY" : [ 1, 1 ],
+ | "DATA" : [ 1.2000000476837158, "NaN" ]
+ | }, {
+ | "name" : "NaN_d",
+ | "count" : 2,
+ | "VALIDITY" : [ 1, 1 ],
+ | "DATA" : [ "NaN", 1.2 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val fnan = Seq(1.2F, Float.NaN)
+ val dnan = Seq(Double.NaN, 1.2)
+ val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d")
+
+ collectAndValidate(df, json, "nanData-floating_point.json")
+ }
+
+ test("partitioned DataFrame") {
+ val json1 =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "b",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 3,
+ | "columns" : [ {
+ | "name" : "a",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "DATA" : [ 1, 1, 2 ]
+ | }, {
+ | "name" : "b",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "DATA" : [ 1, 2, 1 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+ val json2 =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "b",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 3,
+ | "columns" : [ {
+ | "name" : "a",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "DATA" : [ 2, 3, 3 ]
+ | }, {
+ | "name" : "b",
+ | "count" : 3,
+ | "VALIDITY" : [ 1, 1, 1 ],
+ | "DATA" : [ 2, 1, 2 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val arrowPayloads = testData2.toArrowPayload.collect()
+ // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload
+ assert(arrowPayloads.length === 2)
+ val schema = testData2.schema
+
+ val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json")
+ val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json")
+ Files.write(json1, tempFile1, StandardCharsets.UTF_8)
+ Files.write(json2, tempFile2, StandardCharsets.UTF_8)
+
+ validateConversion(schema, arrowPayloads(0), tempFile1)
+ validateConversion(schema, arrowPayloads(1), tempFile2)
+ }
+
+ test("empty frame collect") {
+ val arrowPayload = spark.emptyDataFrame.toArrowPayload.collect()
+ assert(arrowPayload.isEmpty)
+
+ val filteredDF = List[Int](1, 2, 3, 4, 5, 6).toDF("i")
+ val filteredArrowPayload = filteredDF.filter("i < 0").toArrowPayload.collect()
+ assert(filteredArrowPayload.isEmpty)
+ }
+
+ test("empty partition collect") {
+ val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i")
+ val arrowPayloads = emptyPart.toArrowPayload.collect()
+ assert(arrowPayloads.length === 1)
+ val allocator = new RootAllocator(Long.MaxValue)
+ val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
+ assert(arrowRecordBatches.head.getLength == 1)
+ arrowRecordBatches.foreach(_.close())
+ allocator.close()
+ }
+
+ test("max records in batch conf") {
+ val totalRecords = 10
+ val maxRecordsPerBatch = 3
+ spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch)
+ val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i")
+ val arrowPayloads = df.toArrowPayload.collect()
+ val allocator = new RootAllocator(Long.MaxValue)
+ val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
+ var recordCount = 0
+ arrowRecordBatches.foreach { batch =>
+ assert(batch.getLength > 0)
+ assert(batch.getLength <= maxRecordsPerBatch)
+ recordCount += batch.getLength
+ batch.close()
+ }
+ assert(recordCount == totalRecords)
+ allocator.close()
+ spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
+ }
+
+ testQuietly("unsupported types") {
+ def runUnsupported(block: => Unit): Unit = {
+ val msg = intercept[SparkException] {
+ block
+ }
+ assert(msg.getMessage.contains("Unsupported data type"))
+ assert(msg.getCause.getClass === classOf[UnsupportedOperationException])
+ }
+
+ runUnsupported { decimalData.toArrowPayload.collect() }
+ runUnsupported { arrayData.toDF().toArrowPayload.collect() }
+ runUnsupported { mapData.toDF().toArrowPayload.collect() }
+ runUnsupported { complexData.toArrowPayload.collect() }
+
+ val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
+ val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime)
+ val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime)
+ runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() }
+
+ val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime)
+ val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime)
+ runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() }
+ }
+
+ test("test Arrow Validator") {
+ val json =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "a_i",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "b_i",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "a_i",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ]
+ | }, {
+ | "name" : "b_i",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ],
+ | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+ val json_diff_col_order =
+ s"""
+ |{
+ | "schema" : {
+ | "fields" : [ {
+ | "name" : "b_i",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : true,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | }, {
+ | "name" : "a_i",
+ | "type" : {
+ | "name" : "int",
+ | "isSigned" : true,
+ | "bitWidth" : 32
+ | },
+ | "nullable" : false,
+ | "children" : [ ],
+ | "typeLayout" : {
+ | "vectors" : [ {
+ | "type" : "VALIDITY",
+ | "typeBitWidth" : 1
+ | }, {
+ | "type" : "DATA",
+ | "typeBitWidth" : 32
+ | } ]
+ | }
+ | } ]
+ | },
+ | "batches" : [ {
+ | "count" : 6,
+ | "columns" : [ {
+ | "name" : "a_i",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 1, 1, 1, 1, 1 ],
+ | "DATA" : [ 1, -1, 2, -2, 2147483647, -2147483648 ]
+ | }, {
+ | "name" : "b_i",
+ | "count" : 6,
+ | "VALIDITY" : [ 1, 0, 0, 1, 0, 1 ],
+ | "DATA" : [ 1, 0, 0, -2, 0, -2147483648 ]
+ | } ]
+ | } ]
+ |}
+ """.stripMargin
+
+ val a_i = List[Int](1, -1, 2, -2, 2147483647, -2147483648)
+ val b_i = List[Option[Int]](Some(1), None, None, Some(-2), None, Some(-2147483648))
+ val df = a_i.zip(b_i).toDF("a_i", "b_i")
+
+ // Different schema
+ intercept[IllegalArgumentException] {
+ collectAndValidate(df, json_diff_col_order, "validator_diff_schema.json")
+ }
+
+ // Different values
+ intercept[IllegalArgumentException] {
+ collectAndValidate(df.sort($"a_i".desc), json, "validator_diff_values.json")
+ }
+ }
+
+ /** Test that a DataFrame converted to an Arrow record batch equals a batch read from the JSON file */
+ private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = {
+ // NOTE: coalesce to a single partition because the validator can only load one batch at a time
+ val arrowPayload = df.coalesce(1).toArrowPayload.collect().head
+ val tempFile = new File(tempDataPath, file)
+ Files.write(json, tempFile, StandardCharsets.UTF_8)
+ validateConversion(df.schema, arrowPayload, tempFile)
+ }
+
+ private def validateConversion(
+ sparkSchema: StructType,
+ arrowPayload: ArrowPayload,
+ jsonFile: File): Unit = {
+ val allocator = new RootAllocator(Long.MaxValue)
+ val jsonReader = new JsonFileReader(jsonFile, allocator)
+
+ val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema)
+ val jsonSchema = jsonReader.start()
+ Validator.compareSchemas(arrowSchema, jsonSchema)
+
+ val arrowRoot = VectorSchemaRoot.create(arrowSchema, allocator)
+ val vectorLoader = new VectorLoader(arrowRoot)
+ val arrowRecordBatch = arrowPayload.loadBatch(allocator)
+ vectorLoader.load(arrowRecordBatch)
+ val jsonRoot = jsonReader.read()
+ Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
+
+ jsonRoot.close()
+ jsonReader.close()
+ arrowRecordBatch.close()
+ arrowRoot.close()
+ allocator.close()
+ }
+}