From fcdc8a99f9f0e2efbea4ce517fc72d327da86540 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Fri, 30 Jun 2023 18:41:48 +0100 Subject: [PATCH] Fix coercion of cell geometries. --- .../labs/mosaic/core/index/IndexSystem.scala | 9 ++- .../index/MosaicExplodeBehaviors.scala | 79 ++++++++++++++----- .../expressions/index/MosaicExplodeTest.scala | 1 + 3 files changed, 67 insertions(+), 22 deletions(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala b/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala index 4cac86e49..0144760a0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/index/IndexSystem.scala @@ -160,7 +160,7 @@ abstract class IndexSystem(var cellIdType: DataType) extends Serializable { val intersections = for (index <- borderIndices) yield { val indexGeom = indexToGeometry(index, geometryAPI) val intersect = geometry.intersection(indexGeom) - val coerced = coerceChipGeometry(intersect, index, geometryAPI) + val coerced = coerceChipGeometry(intersect, indexGeom, geometry) val isCore = coerced.equals(indexGeom) val chipGeom = if (!isCore || keepCoreGeom) coerced else null @@ -276,12 +276,13 @@ abstract class IndexSystem(var cellIdType: DataType) extends Serializable { def area(index: String): Double = area(parse(index)) - def coerceChipGeometry(geom: MosaicGeometry, cell: Long, geometryAPI: GeometryAPI): MosaicGeometry = { + def coerceChipGeometry(geom: MosaicGeometry, indexGeom: MosaicGeometry, originGeom: MosaicGeometry): MosaicGeometry = { val geomType = GeometryTypeEnum.fromString(geom.getGeometryType) - if (geomType == GEOMETRYCOLLECTION) { + val originGeomType = GeometryTypeEnum.fromString(originGeom.getGeometryType) + if (geomType == GEOMETRYCOLLECTION || geomType != originGeomType) { // This case can occur if partial geometry is a geometry collection // or if the intersection includes a part of the boundary of the cell - geom.difference(indexToGeometry(cell, geometryAPI).getBoundary) + geom.difference(indexGeom.getBoundary) } else { geom } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeBehaviors.scala index 4d0166af1..dd5b71718 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeBehaviors.scala @@ -21,9 +21,9 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { mc.register(spark) val resolution = mc.getIndexSystem match { - case H3IndexSystem => 3 + case H3IndexSystem => 3 case BNGIndexSystem => 5 - case _ => 3 + case _ => 3 } val boroughs: DataFrame = getBoroughs(mc) @@ -56,7 +56,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { val resolution = mc.getIndexSystem match { case H3IndexSystem => 3 case BNGIndexSystem => 5 - case _ => 3 + case _ => 3 } val rdd = spark.sparkContext.makeRDD( @@ -112,7 +112,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { val resolution = mc.getIndexSystem match { case H3IndexSystem => 3 case BNGIndexSystem => 5 - case _ => 3 + case _ => 3 } val rdd = spark.sparkContext.makeRDD( @@ -149,7 +149,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { val resolution = mc.getIndexSystem match { case H3IndexSystem => 3 case BNGIndexSystem => 3 - case _ => 3 + case _ => 3 } val wktRows: DataFrame = getWKTRowsDf(mc.getIndexSystem).where(col("wkt").contains("LINESTRING")) @@ -202,7 +202,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { ) val res = noEmptyChips.collect() res.length should be > 0 - case _ => // do nothing + case _ => // do nothing } } @@ -216,7 +216,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { val resolution = mc.getIndexSystem match { case H3IndexSystem => 3 case BNGIndexSystem => 5 - case _ => 3 + case _ => 3 } val boroughs: DataFrame = getBoroughs(mc) @@ -249,7 +249,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { val resolution = mc.getIndexSystem match { case H3IndexSystem => 3 case BNGIndexSystem => 5 - case _ => 3 + case _ => 3 } val boroughs: DataFrame = getBoroughs(mc) @@ -282,7 +282,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { val resolution = mc.getIndexSystem match { case H3IndexSystem => 3 case BNGIndexSystem => 5 - case _ => 3 + case _ => 3 } val boroughs: DataFrame = getBoroughs(mc) @@ -313,7 +313,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { noException should be thrownBy funcs.grid_tessellateexplode(col("wkt"), 3, keepCoreGeometries = true) noException should be thrownBy funcs.grid_tessellateexplode(col("wkt"), 3, lit(false)) noException should be thrownBy funcs.grid_tessellateexplode(col("wkt"), lit(3), lit(false)) - //legacy APIs + // legacy APIs noException should be thrownBy funcs.mosaic_explode(col("wkt"), 3) noException should be thrownBy funcs.mosaic_explode(col("wkt"), lit(3)) noException should be thrownBy funcs.mosaic_explode(col("wkt"), 3, keepCoreGeometries = true) @@ -332,7 +332,7 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { val resExpr = mc.getIndexSystem match { case H3IndexSystem => lit(mc.getIndexSystem.resolutions.head).expr case BNGIndexSystem => lit("100m").expr - case _ => lit(3).expr + case _ => lit(3).expr } val mosaicExplodeExpr = MosaicExplode( @@ -390,14 +390,14 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { mc.getIndexSystem match { case H3IndexSystem => val rdd = spark.sparkContext.makeRDD( - Seq( - Row("LINESTRING (-85.0040681 42.2975028, -85.0073029 42.2975266)") - ) + Seq( + Row("LINESTRING (-85.0040681 42.2975028, -85.0073029 42.2975266)") + ) ) val schema = StructType( - List( - StructField("wkt", StringType) - ) + List( + StructField("wkt", StringType) + ) ) val df = spark.createDataFrame(rdd, schema) @@ -408,7 +408,50 @@ trait MosaicExplodeBehaviors extends MosaicSpatialQueryTest { df.select(expr(s"grid_tessellateexplode(wkt, 13, true)")) .collect() .length shouldEqual 48 - case _ => // do nothing + case _ => // do nothing } } + + def issue382(mosaicContext: MosaicContext): Unit = { + assume(mosaicContext.getIndexSystem == H3IndexSystem) + val sc = spark + import sc.implicits._ + import mosaicContext.functions._ + + val wkt = "POLYGON ((-8.522721910163417 53.40846416712235, -8.522828495418493 53.40871094834742," + + " -8.523239522405696 53.40879676331252, -8.52334611088906 53.409043543609435," + + " -8.523757142297253 53.409129356978674, -8.523863734008978 53.409376136347404," + + " -8.523559290871438 53.40953710231036, -8.523665882370468 53.40978388071435," + + " -8.523361436771772 53.40994484500841, -8.523468028058108 53.41019162244766," + + " -8.523163579998224 53.410352585072815, -8.52275254184475 53.41026676959102," + + " -8.52244809251643 53.41042772987954, -8.522037056535808 53.41034191209765," + + " -8.521732605939153 53.41050287004956, -8.52132157213149 53.41041704996761," + + " -8.521214991168797 53.410170272637956, -8.520803961782489 53.41008445096018," + + " -8.520697384048132 53.40983767270238, -8.520286359083132 53.40975184942885," + + " -8.520179784577046 53.409505070242936, -8.520484231594777 53.40934411429393," + + " -8.52037765687575 53.409097334143304, -8.520682101432444 53.40893637652535," + + " -8.520575526500501 53.408689595410024, -8.520879968596168 53.40852863612313," + + " -8.521290986816735 53.408614457283946, -8.521595427644318 53.408453495660524," + + " -8.522006448037782 53.40853931452139, -8.522310887597179 53.408378350561435," + + " -8.522721910163417 53.40846416712235))" + + val rdd = spark.sparkContext.makeRDD(Seq(Row(wkt))) + val schema = StructType(List(StructField("wkt", StringType))) + val df = spark.createDataFrame(rdd, schema) + + val result = df + .select(grid_tessellateexplode(col("wkt"), 11).alias("grid")) + .select(col("grid.wkb")) + .select(st_aswkt(col("wkb"))) + + val chips = result.as[String].collect() + val resultGeom = chips.map(mosaicContext.getGeometryAPI.geometry(_, "WKT")) + .reduce(_ union _) + + val expected = mosaicContext.getGeometryAPI.geometry(wkt, "WKT") + + math.abs(expected.getArea - resultGeom.getArea) should be < 1e-8 + + } + } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeTest.scala index 67c3d0e0e..3ed35f825 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/index/MosaicExplodeTest.scala @@ -16,5 +16,6 @@ class MosaicExplodeTest extends MosaicSpatialQueryTest with SharedSparkSession w testAllNoCodegen("MosaicExplode column function signatures") { columnFunctionSignatures } testAllNoCodegen("MosaicExplode auxiliary methods") { auxiliaryMethods } testAllNoCodegen("MosaicExplode Line cases identified by issue 360") { issue360 } + testAllNoCodegen("MosaicExplode Should properly handle polygons that are a union of cells") { issue382 } }