diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md
index 3a96718dc0..f522581503 100644
--- a/docs/api/sql/Optimizer.md
+++ b/docs/api/sql/Optimizer.md
@@ -343,3 +343,5 @@ We can compare the metrics of querying the GeoParquet dataset with or without th
 | Without spatial predicate | With spatial predicate |
 | ----------- | ----------- |
 | ![](../../image/scan-parquet-without-spatial-pred.png) | ![](../../image/scan-parquet-with-spatial-pred.png) |
+
+Spatial predicate push-down to GeoParquet is enabled by default. To disable it, set the Spark configuration `spark.sedona.geoparquet.spatialFilterPushDown` to `false`.
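
As a quick orientation before the code changes below, here is a minimal usage sketch of the flag documented above. It assumes a Sedona-enabled SparkSession named `spark`, a hypothetical dataset path, and a geometry column named `geom`; none of these names come from this patch.

```scala
// Sketch: observing spatial predicate push-down in a query plan.
// Assumptions (not part of this patch): SparkSession `spark` with Sedona
// registered, a GeoParquet dataset at the path below, geometry column `geom`.
val df = spark.read.format("geoparquet").load("/path/to/example.geoparquet")
df.where("ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
  .explain()
// With push-down enabled (the default), the FileScan line should contain
// "with spatial filter [...]"; the test suites below assert this substring.
```
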
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala
index 57ca2161d2..5aa782e5bd 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala
@@ -29,20 +29,25 @@ import org.locationtech.jts.geom.Geometry
  */
 trait GeoParquetSpatialFilter {
   def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean
+  def simpleString: String
 }
 
 object GeoParquetSpatialFilter {
 
   case class AndFilter(left: GeoParquetSpatialFilter, right: GeoParquetSpatialFilter)
       extends GeoParquetSpatialFilter {
-    override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean =
+    override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean = {
       left.evaluate(columns) && right.evaluate(columns)
+    }
+
+    override def simpleString: String = s"(${left.simpleString}) AND (${right.simpleString})"
   }
 
   case class OrFilter(left: GeoParquetSpatialFilter, right: GeoParquetSpatialFilter)
       extends GeoParquetSpatialFilter {
     override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean =
       left.evaluate(columns) || right.evaluate(columns)
+    override def simpleString: String = s"(${left.simpleString}) OR (${right.simpleString})"
   }
 
   /**
@@ -77,5 +82,6 @@ object GeoParquetSpatialFilter {
         }
       }
     }
+    override def simpleString: String = s"$columnName ${predicateType.name} $queryWindow"
   }
 }
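
The `simpleString` methods added above render a filter tree recursively: a leaf prints `columnName predicateType queryWindow`, and the combinators parenthesize both operands. A self-contained sketch with stand-in types (not Sedona's classes; `LeafFilter`'s full definition is outside this hunk) illustrates the rendering scheme:

```scala
// Stand-in types mirroring the rendering logic added above.
trait SimpleFilter { def simpleString: String }
case class SimpleLeaf(rendered: String) extends SimpleFilter {
  override def simpleString: String = rendered
}
case class SimpleAnd(left: SimpleFilter, right: SimpleFilter) extends SimpleFilter {
  override def simpleString: String = s"(${left.simpleString}) AND (${right.simpleString})"
}

// The leaf texts are illustrative renderings, not real Sedona output.
val combined = SimpleAnd(
  SimpleLeaf("geom INTERSECTS POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))"),
  SimpleLeaf("geom COVERS POINT (5 5)"))
// combined.simpleString ==
//   "(geom INTERSECTS POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))) AND (geom COVERS POINT (5 5))"
```
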
+ "GeoParquet" + spatialFilter + .map(filter => " with spatial filter [" + filter.simpleString + "]") + .getOrElse("") + } + def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat = new GeoParquetFileFormat(Some(spatialFilter)) diff --git a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala index 8f3cc3f1e5..a2a257e8f5 100644 --- a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala +++ b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter +import org.apache.spark.sql.execution.SimpleMode import org.locationtech.jts.geom.Coordinate import org.locationtech.jts.geom.Geometry import org.locationtech.jts.geom.GeometryFactory @@ -223,6 +224,25 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", Seq(1, 3)) } + + it("Explain geoparquet scan with spatial filter push-down") { + val dfFiltered = geoParquetDf.where( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))") + val explainString = dfFiltered.queryExecution.explainString(SimpleMode) + assert(explainString.contains("FileScan geoparquet")) + assert(explainString.contains("with spatial filter")) + } + + it("Manually disable spatial filter push-down") { + withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> "false")) { + val dfFiltered = geoParquetDf.where( + "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))") + val explainString = dfFiltered.queryExecution.explainString(SimpleMode) + assert(explainString.contains("FileScan geoparquet")) + assert(!explainString.contains("with spatial filter")) + assert(getPushedDownSpatialFilter(dfFiltered).isEmpty) + } + } } /** diff --git a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index 2da12eceb0..5dd5d93091 100644 --- a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -54,4 +54,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { def loadCsv(path: String): DataFrame = { sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path) } + + def withConf[T](conf: Map[String, String])(f: => T): T = { + val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key)) + conf.foreach { case (key, value) => sparkSession.conf.set(key, value) } + try { + f + } finally { + oldConf.foreach { case (key, value) => + value match { + case Some(v) => sparkSession.conf.set(key, v) + case None => sparkSession.conf.unset(key) + } + } + } + } } diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala index dde566ba23..325a720982 
diff --git a/spark/spark-3.0/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.0/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index 702c6f31fb..1924bbfbaf 100644
--- a/spark/spark-3.0/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/spark/spark-3.0/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -66,6 +66,14 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
 
   override def hashCode(): Int = getClass.hashCode()
 
+  override def toString(): String = {
+    // HACK: This is the only place where we can inject spatial filter information into the query plan description produced by EXPLAIN.
+    // Please see org.apache.spark.sql.execution.DataSourceScanExec#simpleString for more details.
+    "GeoParquet" + spatialFilter
+      .map(filter => " with spatial filter [" + filter.simpleString + "]")
+      .getOrElse("")
+  }
+
   def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat =
     new GeoParquetFileFormat(Some(spatialFilter))
 
diff --git a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
index 8f3cc3f1e5..a2a257e8f5 100644
--- a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
+++ b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.execution.SimpleMode
 import org.locationtech.jts.geom.Coordinate
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.GeometryFactory
@@ -223,6 +224,25 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive
       "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))",
       Seq(1, 3))
   }
+
+  it("Explain geoparquet scan with spatial filter push-down") {
+    val dfFiltered = geoParquetDf.where(
+      "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
+    val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+    assert(explainString.contains("FileScan geoparquet"))
+    assert(explainString.contains("with spatial filter"))
+  }
+
+  it("Manually disable spatial filter push-down") {
+    withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> "false")) {
+      val dfFiltered = geoParquetDf.where(
+        "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
+      val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+      assert(explainString.contains("FileScan geoparquet"))
+      assert(!explainString.contains("with spatial filter"))
+      assert(getPushedDownSpatialFilter(dfFiltered).isEmpty)
+    }
+  }
 }
 
 /**
diff --git a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 2da12eceb0..5dd5d93091 100644
--- a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -54,4 +54,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   def loadCsv(path: String): DataFrame = {
     sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
   }
+
+  def withConf[T](conf: Map[String, String])(f: => T): T = {
+    val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key))
+    conf.foreach { case (key, value) => sparkSession.conf.set(key, value) }
+    try {
+      f
+    } finally {
+      oldConf.foreach { case (key, value) =>
+        value match {
+          case Some(v) => sparkSession.conf.set(key, v)
+          case None => sparkSession.conf.unset(key)
+        }
+      }
+    }
+  }
 }
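
The `withConf` helper snapshots the touched keys before mutating them and restores each one (or unsets keys that were previously absent) in a `finally` block, so a failing test cannot leak configuration into later tests. A hypothetical extra test (not part of this patch) showing the restore behavior, assuming the key is unset on the session beforehand:

```scala
it("restores the session conf after withConf") {
  withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> "false")) {
    assert(sparkSession.conf.get("spark.sedona.geoparquet.spatialFilterPushDown") == "false")
  }
  // The key was unset before the block, so it is unset again afterwards and
  // reads back as its documented default of "true".
  assert(sparkSession.conf.get("spark.sedona.geoparquet.spatialFilterPushDown", "true") == "true")
}
```
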
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index dde566ba23..325a720982 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -65,6 +65,14 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
 
   override def hashCode(): Int = getClass.hashCode()
 
+  override def toString(): String = {
+    // HACK: This is the only place where we can inject spatial filter information into the query plan description produced by EXPLAIN.
+    // Please see org.apache.spark.sql.execution.DataSourceScanExec#simpleString for more details.
+    "GeoParquet" + spatialFilter
+      .map(filter => " with spatial filter [" + filter.simpleString + "]")
+      .getOrElse("")
+  }
+
   def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat =
     new GeoParquetFileFormat(Some(spatialFilter))
 
diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
index 8f3cc3f1e5..a2a257e8f5 100644
--- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
+++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.execution.SimpleMode
 import org.locationtech.jts.geom.Coordinate
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.GeometryFactory
@@ -223,6 +224,25 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive
       "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))",
       Seq(1, 3))
   }
+
+  it("Explain geoparquet scan with spatial filter push-down") {
+    val dfFiltered = geoParquetDf.where(
+      "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
+    val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+    assert(explainString.contains("FileScan geoparquet"))
+    assert(explainString.contains("with spatial filter"))
+  }
+
+  it("Manually disable spatial filter push-down") {
+    withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> "false")) {
+      val dfFiltered = geoParquetDf.where(
+        "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
+      val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+      assert(explainString.contains("FileScan geoparquet"))
+      assert(!explainString.contains("with spatial filter"))
+      assert(getPushedDownSpatialFilter(dfFiltered).isEmpty)
+    }
+  }
 }
 
 /**
diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 2da12eceb0..5dd5d93091 100644
--- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -54,4 +54,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   def loadCsv(path: String): DataFrame = {
     sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
   }
+
+  def withConf[T](conf: Map[String, String])(f: => T): T = {
+    val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key))
+    conf.foreach { case (key, value) => sparkSession.conf.set(key, value) }
+    try {
+      f
+    } finally {
+      oldConf.foreach { case (key, value) =>
+        value match {
+          case Some(v) => sparkSession.conf.set(key, v)
+          case None => sparkSession.conf.unset(key)
+        }
+      }
+    }
+  }
 }
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index b8d422dceb..06c9683cd1 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -64,6 +64,14 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
 
   override def hashCode(): Int = getClass.hashCode()
 
+  override def toString(): String = {
+    // HACK: This is the only place where we can inject spatial filter information into the query plan description produced by EXPLAIN.
+    // Please see org.apache.spark.sql.execution.DataSourceScanExec#simpleString for more details.
+    "GeoParquet" + spatialFilter
+      .map(filter => " with spatial filter [" + filter.simpleString + "]")
+      .getOrElse("")
+  }
+
   def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat =
     new GeoParquetFileFormat(Some(spatialFilter))
 
diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
index 8f3cc3f1e5..a2a257e8f5 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.execution.SimpleMode
 import org.locationtech.jts.geom.Coordinate
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.GeometryFactory
@@ -223,6 +224,25 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive
       "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))",
       Seq(1, 3))
   }
+
+  it("Explain geoparquet scan with spatial filter push-down") {
+    val dfFiltered = geoParquetDf.where(
+      "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
+    val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+    assert(explainString.contains("FileScan geoparquet"))
+    assert(explainString.contains("with spatial filter"))
+  }
+
+  it("Manually disable spatial filter push-down") {
+    withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> "false")) {
+      val dfFiltered = geoParquetDf.where(
+        "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
+      val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+      assert(explainString.contains("FileScan geoparquet"))
+      assert(!explainString.contains("with spatial filter"))
+      assert(getPushedDownSpatialFilter(dfFiltered).isEmpty)
+    }
+  }
 }
 
 /**
diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 2da12eceb0..5dd5d93091 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -54,4 +54,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   def loadCsv(path: String): DataFrame = {
     sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
   }
+
+  def withConf[T](conf: Map[String, String])(f: => T): T = {
+    val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key))
+    conf.foreach { case (key, value) => sparkSession.conf.set(key, value) }
+    try {
+      f
+    } finally {
+      oldConf.foreach { case (key, value) =>
+        value match {
+          case Some(v) => sparkSession.conf.set(key, v)
+          case None => sparkSession.conf.unset(key)
+        }
+      }
+    }
+  }
 }