
[SEDONA-637] Show spatial filters pushed to GeoParquet scans in the query plan; allow disabling spatial filter pushdown #1540

Merged

docs/api/sql/Optimizer.md (2 additions, 0 deletions)

@@ -343,3 +343,5 @@ We can compare the metrics of querying the GeoParquet dataset with or without th
 | Without spatial predicate | With spatial predicate |
 | ----------- | ----------- |
 | ![](../../image/scan-parquet-without-spatial-pred.png) | ![](../../image/scan-parquet-with-spatial-pred.png) |
+
+Spatial predicate push-down to GeoParquet is enabled by default. Users can manually disable it by setting the Spark configuration `spark.sedona.geoparquet.spatialFilterPushDown` to `false`.
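
For example, the new flag can be toggled at runtime through the standard Spark conf API (a minimal sketch; the configuration key is the one introduced by this PR, everything else is the stock runtime configuration API):

```scala
// Disable spatial filter push-down for subsequent GeoParquet scans,
// then re-enable it.
spark.conf.set("spark.sedona.geoparquet.spatialFilterPushDown", "false")
// ... run queries that should scan without the spatial filter ...
spark.conf.set("spark.sedona.geoparquet.spatialFilterPushDown", "true")
```
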
@@ -29,20 +29,25 @@ import org.locationtech.jts.geom.Geometry
  */
 trait GeoParquetSpatialFilter {
   def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean
+  def simpleString: String
 }
 
 object GeoParquetSpatialFilter {
 
   case class AndFilter(left: GeoParquetSpatialFilter, right: GeoParquetSpatialFilter)
       extends GeoParquetSpatialFilter {
-    override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean =
+    override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean = {
       left.evaluate(columns) && right.evaluate(columns)
+    }
+
+    override def simpleString: String = s"(${left.simpleString}) AND (${right.simpleString})"
   }
 
   case class OrFilter(left: GeoParquetSpatialFilter, right: GeoParquetSpatialFilter)
       extends GeoParquetSpatialFilter {
     override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean =
       left.evaluate(columns) || right.evaluate(columns)
+    override def simpleString: String = s"(${left.simpleString}) OR (${right.simpleString})"
   }
 
   /**
@@ -77,5 +82,6 @@ object GeoParquetSpatialFilter {
         }
       }
     }
+    override def simpleString: String = s"$columnName ${predicateType.name} $queryWindow"
   }
 }
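
To make the new `simpleString` contract concrete, here is a rough sketch of how composed filters render. The leaf filter definition is only partially visible in this diff, so the leaf below is a hypothetical stand-in:

```scala
// Hypothetical stand-in for the real leaf filter, used only to show how
// AndFilter composes simpleString; evaluate is stubbed out.
case class StubLeaf(desc: String) extends GeoParquetSpatialFilter {
  override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean = true
  override def simpleString: String = desc
}

val combined = GeoParquetSpatialFilter.AndFilter(
  StubLeaf("geom INTERSECTS POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))"),
  StubLeaf("geom INTERSECTS POINT (5 5)"))
// combined.simpleString ==
//   "(geom INTERSECTS POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))) AND (geom INTERSECTS POINT (5 5))"
```
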
@@ -62,23 +62,30 @@ import org.locationtech.jts.geom.Point
 
 class SpatialFilterPushDownForGeoParquet(sparkSession: SparkSession) extends Rule[LogicalPlan] {
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case filter @ Filter(condition, lr: LogicalRelation) if isGeoParquetRelation(lr) =>
-      val filters = splitConjunctivePredicates(condition)
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, lr.output)
-      val (_, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
-      val geoParquetSpatialFilters =
-        translateToGeoParquetSpatialFilters(normalizedFiltersWithoutSubquery)
-      val hadoopFsRelation = lr.relation.asInstanceOf[HadoopFsRelation]
-      val fileFormat = hadoopFsRelation.fileFormat.asInstanceOf[GeoParquetFileFormatBase]
-      if (geoParquetSpatialFilters.isEmpty) filter
-      else {
-        val combinedSpatialFilter = geoParquetSpatialFilters.reduce(AndFilter)
-        val newFileFormat = fileFormat.withSpatialPredicates(combinedSpatialFilter)
-        val newRelation = hadoopFsRelation.copy(fileFormat = newFileFormat)(sparkSession)
-        filter.copy(child = lr.copy(relation = newRelation))
-      }
-  }
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    val enableSpatialFilterPushDown =
+      sparkSession.conf.get("spark.sedona.geoparquet.spatialFilterPushDown", "true").toBoolean
+    if (!enableSpatialFilterPushDown) plan
+    else {
+      plan transform {
+        case filter @ Filter(condition, lr: LogicalRelation) if isGeoParquetRelation(lr) =>
+          val filters = splitConjunctivePredicates(condition)
+          val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, lr.output)
+          val (_, normalizedFiltersWithoutSubquery) =
+            normalizedFilters.partition(SubqueryExpression.hasSubquery)
+          val geoParquetSpatialFilters =
+            translateToGeoParquetSpatialFilters(normalizedFiltersWithoutSubquery)
+          val hadoopFsRelation = lr.relation.asInstanceOf[HadoopFsRelation]
+          val fileFormat = hadoopFsRelation.fileFormat.asInstanceOf[GeoParquetFileFormatBase]
+          if (geoParquetSpatialFilters.isEmpty) filter
+          else {
+            val combinedSpatialFilter = geoParquetSpatialFilters.reduce(AndFilter)
+            val newFileFormat = fileFormat.withSpatialPredicates(combinedSpatialFilter)
+            val newRelation = hadoopFsRelation.copy(fileFormat = newFileFormat)(sparkSession)
+            filter.copy(child = lr.copy(relation = newRelation))
+          }
+      }
+    }
+  }
 
   private def isGeoParquetRelation(lr: LogicalRelation): Boolean =
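
For context, optimizer rules like this one are normally wired into Spark through `SparkSessionExtensions`; the sketch below is illustrative and not part of this diff:

```scala
import org.apache.spark.sql.SparkSessionExtensions

// Illustrative sketch (not part of this PR): injecting a Rule[LogicalPlan]
// such as SpatialFilterPushDownForGeoParquet into Spark's logical optimizer.
class SedonaLikeExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectOptimizerRule(session => new SpatialFilterPushDownForGeoParquet(session))
  }
}
// Activated with: --conf spark.sql.extensions=<fully.qualified.SedonaLikeExtensions>
```
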
@@ -66,6 +66,14 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
 
   override def hashCode(): Int = getClass.hashCode()
 
+  override def toString(): String = {
+    // HACK: This is the only place we can inject spatial filter information into the described query plan.
+    // Please see org.apache.spark.sql.execution.DataSourceScanExec#simpleString for more details.
+    "GeoParquet" + spatialFilter
+      .map(filter => " with spatial filter [" + filter.simpleString + "]")
+      .getOrElse("")
+  }
+
   def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat =
     new GeoParquetFileFormat(Some(spatialFilter))
 
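
With this `toString` override in place, the pushed filter becomes visible in `EXPLAIN` output, roughly as sketched below (the load path is a placeholder, and the exact rendering of the predicate name and geometry may differ):

```scala
// Illustrative only: the path is hypothetical and the commented output is an
// approximation of what DataSourceScanExec#simpleString renders via toString.
val df = spark.read
  .format("geoparquet")
  .load("/tmp/example.geoparquet") // hypothetical path
  .where("ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
df.explain()
// ... FileScan geoparquet with spatial filter [geom INTERSECTS POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))] ...
```
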
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.execution.SimpleMode
 import org.locationtech.jts.geom.Coordinate
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.GeometryFactory
@@ -223,6 +224,25 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive
       "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))",
       Seq(1, 3))
   }
+
+  it("Explain geoparquet scan with spatial filter push-down") {
+    val dfFiltered = geoParquetDf.where(
+      "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
+    val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+    assert(explainString.contains("FileScan geoparquet"))
+    assert(explainString.contains("with spatial filter"))
+  }
+
+  it("Manually disable spatial filter push-down") {
+    withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> "false")) {
+      val dfFiltered = geoParquetDf.where(
+        "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
+      val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+      assert(explainString.contains("FileScan geoparquet"))
+      assert(!explainString.contains("with spatial filter"))
+      assert(getPushedDownSpatialFilter(dfFiltered).isEmpty)
+    }
+  }
 }
 
 /**

@@ -54,4 +54,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   def loadCsv(path: String): DataFrame = {
     sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
   }
+
+  def withConf[T](conf: Map[String, String])(f: => T): T = {
+    val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key))
+    conf.foreach { case (key, value) => sparkSession.conf.set(key, value) }
+    try {
+      f
+    } finally {
+      oldConf.foreach { case (key, value) =>
+        value match {
+          case Some(v) => sparkSession.conf.set(key, v)
+          case None => sparkSession.conf.unset(key)
+        }
+      }
+    }
+  }
 }
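
The helper snapshots the previous values (including keys that were unset) and restores them in the `finally` block, so an overridden conf cannot leak into later tests even when an assertion fails. A minimal usage sketch:

```scala
// Minimal usage sketch: flip the push-down flag only for the block's duration.
withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> "false")) {
  // Inside the block the override is visible...
  assert(sparkSession.conf.get("spark.sedona.geoparquet.spatialFilterPushDown") == "false")
}
// ...afterwards the prior value (or the unset state) is back in effect.
```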