diff --git a/.github/linters/codespell.txt b/.github/linters/codespell.txt index a20350263c..7ea89413af 100644 --- a/.github/linters/codespell.txt +++ b/.github/linters/codespell.txt @@ -1,3 +1,4 @@ +LOD actualy afterall atmost diff --git a/common/src/main/java/org/apache/sedona/common/FunctionsGeoTools.java b/common/src/main/java/org/apache/sedona/common/FunctionsGeoTools.java index 28bf6b9045..0fa390023e 100644 --- a/common/src/main/java/org/apache/sedona/common/FunctionsGeoTools.java +++ b/common/src/main/java/org/apache/sedona/common/FunctionsGeoTools.java @@ -101,6 +101,28 @@ public static Geometry transformToGivenTarget( } } + /** + * Get the SRID of a CRS from a WKT string + * + * @param crsWKT WKT string for CRS + * @return SRID + */ + public static int wktCRSToSRID(String crsWKT) { + try { + CoordinateReferenceSystem crs = CRS.parseWKT(crsWKT); + int srid = crsToSRID(crs); + if (srid == 0) { + Integer epsgCode = CRS.lookupEpsgCode(crs, true); + if (epsgCode != null) { + srid = epsgCode; + } + } + return srid; + } catch (FactoryException e) { + throw new IllegalArgumentException("Cannot parse CRS WKT", e); + } + } + /** * Get the SRID of a CRS. We use the EPSG code of the CRS if available. * diff --git a/docs/api/sql/Constructor.md b/docs/api/sql/Constructor.md index d84f2c8b84..a1f715a0a7 100644 --- a/docs/api/sql/Constructor.md +++ b/docs/api/sql/Constructor.md @@ -1,49 +1,3 @@ -## Read ESRI Shapefile - -Introduction: Construct a DataFrame from a Shapefile - -Since: `v1.0.0` - -SparkSQL example: - -```scala -var spatialRDD = new SpatialRDD[Geometry] -spatialRDD.rawSpatialRDD = ShapefileReader.readToGeometryRDD(sparkSession.sparkContext, shapefileInputLocation) -var rawSpatialDf = Adapter.toDf(spatialRDD,sparkSession) -rawSpatialDf.createOrReplaceTempView("rawSpatialDf") -var spatialDf = sparkSession.sql(""" - | ST_GeomFromWKT(rddshape), _c1, _c2 - | FROM rawSpatialDf - """.stripMargin) -spatialDf.show() -spatialDf.printSchema() -``` - -!!!note - The path to the shapefile is the path to the folder that contains the .shp file, not the path to the .shp file itself. The file extensions of .shp, .shx, .dbf must be in lowercase. Assume you have a shape file called ==myShapefile==, the path should be `XXX/myShapefile`. The file structure should be like this: - ``` - - shapefile1 - - shapefile2 - - myshapefile - - myshapefile.shp - - myshapefile.shx - - myshapefile.dbf - - myshapefile... - - ... - ``` - -!!!warning - Please make sure you use ==ST_GeomFromWKT== to create Geometry type column otherwise that column cannot be used in SedonaSQL. - -If the file you are reading contains non-ASCII characters you'll need to explicitly set the Spark config before initializing the SparkSession, then you can use `ShapefileReader.readToGeometryRDD`. - -Example: - -```scala -spark.driver.extraJavaOptions -Dsedona.global.charset=utf8 -spark.executor.extraJavaOptions -Dsedona.global.charset=utf8 -``` - ## ST_GeomCollFromText Introduction: Constructs a GeometryCollection from the WKT with the given SRID. If SRID is not provided then it defaults to 0. It returns `null` if the WKT is not a `GEOMETRYCOLLECTION`. diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md index 0b617384be..a04a60d3b8 100644 --- a/docs/tutorial/sql.md +++ b/docs/tutorial/sql.md @@ -459,9 +459,96 @@ root |-- prop0: string (nullable = true) ``` -## Load Shapefile using SpatialRDD +## Load Shapefile -Shapefile can be loaded by SpatialRDD and converted to DataFrame using Adapter. 
Please read [Load SpatialRDD](rdd.md#create-a-generic-spatialrdd) and [DataFrame <-> RDD](#convert-between-dataframe-and-spatialrdd). +Since v`1.7.0`, Sedona supports loading a Shapefile as a DataFrame. + +=== "Scala" + + ```scala + val df = sedona.read.format("shapefile").load("/path/to/shapefile") + ``` + +=== "Java" + + ```java + Dataset<Row> df = sedona.read().format("shapefile").load("/path/to/shapefile"); + ``` + +=== "Python" + + ```python + df = sedona.read.format("shapefile").load("/path/to/shapefile") + ``` + +The input path can be a directory containing one or multiple shapefiles, or a path to a `.shp` file. + +- When the input path is a directory, all shapefiles directly under the directory will be loaded. If you want to load all shapefiles in subdirectories, please specify `.option("recursiveFileLookup", "true")`. +- When the input path is a `.shp` file, that shapefile will be loaded. Sedona will look for sibling files (`.dbf`, `.shx`, etc.) with the same main file name and load them automatically. + +The name of the geometry column is `geometry` by default. You can change the name of the geometry column using the `geometry.name` option. If one of the non-spatial attributes is named "geometry", `geometry.name` must be configured to avoid a conflict. + +=== "Scala" + + ```scala + val df = sedona.read.format("shapefile").option("geometry.name", "geom").load("/path/to/shapefile") + ``` + +=== "Java" + + ```java + Dataset<Row> df = sedona.read().format("shapefile").option("geometry.name", "geom").load("/path/to/shapefile"); + ``` + +=== "Python" + + ```python + df = sedona.read.format("shapefile").option("geometry.name", "geom").load("/path/to/shapefile") + ``` + +Each record in a shapefile has a unique record number, which is not loaded by default. If you want to include the record number in the loaded DataFrame, you can set the `key.name` option to the name of the record number column: + +=== "Scala" + + ```scala + val df = sedona.read.format("shapefile").option("key.name", "FID").load("/path/to/shapefile") + ``` + +=== "Java" + + ```java + Dataset<Row> df = sedona.read().format("shapefile").option("key.name", "FID").load("/path/to/shapefile"); + ``` + +=== "Python" + + ```python + df = sedona.read.format("shapefile").option("key.name", "FID").load("/path/to/shapefile") + ``` + +The character encoding of string attributes is inferred from the `.cpg` file. If you see garbled values in string fields, you can manually specify the correct charset using the `charset` option. For example: + +=== "Scala" + + ```scala + val df = sedona.read.format("shapefile").option("charset", "UTF-8").load("/path/to/shapefile") + ``` + +=== "Java" + + ```java + Dataset<Row> df = sedona.read().format("shapefile").option("charset", "UTF-8").load("/path/to/shapefile"); + ``` + +=== "Python" + + ```python + df = sedona.read.format("shapefile").option("charset", "UTF-8").load("/path/to/shapefile") + ``` + +### (Deprecated) Loading Shapefile using SpatialRDD + +If you are using a Sedona version earlier than v`1.7.0`, you can load shapefiles as a SpatialRDD and convert it to a DataFrame using the Adapter. Please read [Load SpatialRDD](rdd.md#create-a-generic-spatialrdd) and [DataFrame <-> RDD](#convert-between-dataframe-and-spatialrdd). 
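The options documented above can also be combined in a single read. A minimal Scala sketch, using only the documented `recursiveFileLookup`, `geometry.name`, `key.name`, and `charset` options; the input path is illustrative:

```scala
// Load all shapefiles under a directory tree (path is illustrative), rename the
// geometry column, keep the per-record number as "FID", and force the charset.
val df = sedona.read
  .format("shapefile")
  .option("recursiveFileLookup", "true") // also pick up shapefiles in subdirectories
  .option("geometry.name", "geom")       // rename the default "geometry" column
  .option("key.name", "FID")             // include the record number column
  .option("charset", "UTF-8")            // override the charset inferred from .cpg
  .load("/path/to/shapefile_dir")
df.printSchema()
```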
## Load GeoParquet diff --git a/pom.xml b/pom.xml index f71e7b680e..333c7c7c5e 100644 --- a/pom.xml +++ b/pom.xml @@ -709,7 +709,7 @@ spark - 3.2.3 + 3.2 diff --git a/python/tests/sql/test_shapefile.py b/python/tests/sql/test_shapefile.py new file mode 100644 index 0000000000..1565ee6ea8 --- /dev/null +++ b/python/tests/sql/test_shapefile.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import os.path +import datetime + +from tests.test_base import TestBase +from tests.tools import tests_resource + + +class TestShapefile(TestBase): + def test_read_simple(self): + input_location = os.path.join(tests_resource, "shapefiles/polygon") + df = self.spark.read.format("shapefile").load(input_location) + assert df.count() == 10000 + rows = df.take(100) + for row in rows: + assert len(row) == 1 + assert row["geometry"].geom_type in ("Polygon", "MultiPolygon") + + def test_read_osm_pois(self): + input_location = os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp") + df = self.spark.read.format("shapefile").load(input_location) + assert df.count() == 12873 + rows = df.take(100) + for row in rows: + assert len(row) == 5 + assert row["geometry"].geom_type == "Point" + assert isinstance(row['osm_id'], str) + assert isinstance(row['fclass'], str) + assert isinstance(row['name'], str) + assert isinstance(row['code'], int) + + def test_customize_geom_and_key_columns(self): + input_location = os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1") + df = self.spark.read.format("shapefile").option("geometry.name", "geom").option("key.name", "fid").load(input_location) + assert df.count() == 12873 + rows = df.take(100) + for row in rows: + assert len(row) == 6 + assert row["geom"].geom_type == "Point" + assert isinstance(row['fid'], int) + assert isinstance(row['osm_id'], str) + assert isinstance(row['fclass'], str) + assert isinstance(row['name'], str) + assert isinstance(row['code'], int) + + def test_read_multiple_shapefiles(self): + input_location = os.path.join(tests_resource, "shapefiles/datatypes") + df = self.spark.read.format("shapefile").load(input_location) + rows = df.collect() + assert len(rows) == 9 + for row in rows: + id = row['id'] + assert row['aInt'] == id + if id is not None: + assert row['aUnicode'] == "测试" + str(id) + if id < 10: + assert row['aDecimal'] * 10 == id * 10 + id + assert row['aDecimal2'] is None + assert row['aDate'] == datetime.date(2020 + id, id, id) + else: + assert row['aDecimal'] is None + assert row['aDecimal2'] * 100 == id * 100 + id + assert row['aDate'] is None + else: + assert row['aUnicode'] == '' + assert row['aDecimal'] is None + assert row['aDecimal2'] is None + assert row['aDate'] is None diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/parseUtils/dbf/DbfParseUtil.java b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/parseUtils/dbf/DbfParseUtil.java index 23d5431792..83a6fced89 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/parseUtils/dbf/DbfParseUtil.java +++ b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/parseUtils/dbf/DbfParseUtil.java @@ -137,13 +137,13 @@ public void parseFileHead(DataInputStream inputStream) throws IOException { } /** - * draw raw byte array of effective record + * Parse the next record in the .dbf file * - * @param inputStream - * @return - * @throws IOException + * @param inputStream input stream of .dbf file + * @return a list of fields as their original representation in the dbf file + * @throws IOException if an I/O error occurs */ - public String parsePrimitiveRecord(DataInputStream inputStream) throws IOException { + public List<byte[]> parse(DataInputStream inputStream) throws IOException { if (isDone()) { return null; } @@ -160,50 +160,34 @@ public String parsePrimitiveRecord(DataInputStream inputStream) throws IOExcepti byte[] primitiveBytes = new byte[recordLength]; inputStream.readFully(primitiveBytes); numRecordRead++; // update number of record read - return primitiveToAttributes(ByteBuffer.wrap(primitiveBytes)); + return extractFieldBytes(ByteBuffer.wrap(primitiveBytes)); } - /** - * abstract attributes from primitive bytes according to field descriptors. - * - * @param inputStream - * @return - * @throws IOException - */ - public String primitiveToAttributes(DataInputStream inputStream) throws IOException { - byte[] delimiter = {'\t'}; - Text attributes = new Text(); - for (int i = 0; i < fieldDescriptors.size(); ++i) { - FieldDescriptor descriptor = fieldDescriptors.get(i); + /** Extract attributes from primitive bytes according to field descriptors. */ + private List<byte[]> extractFieldBytes(ByteBuffer buffer) { + int numFields = fieldDescriptors.size(); + List<byte[]> fieldBytesList = new ArrayList<>(numFields); + for (FieldDescriptor descriptor : fieldDescriptors) { byte[] fldBytes = new byte[descriptor.getFieldLength()]; - inputStream.readFully(fldBytes); - // System.out.println(descriptor.getFiledName() + " " + new String(fldBytes)); - byte[] attr = new String(fldBytes).trim().getBytes(); - if (i > 0) { - attributes.append(delimiter, 0, 1); // first attribute doesn't append '\t' - } - attributes.append(attr, 0, attr.length); + buffer.get(fldBytes, 0, fldBytes.length); + fieldBytesList.add(fldBytes); } - String attrs = attributes.toString(); - return attributes.toString(); + return fieldBytesList; } /** * abstract attributes from primitive bytes according to field descriptors. 
* - * @param buffer - * @return - * @throws IOException + * @param fieldBytesList a list of primitive bytes + * @return string attributes delimited by '\t' */ - public String primitiveToAttributes(ByteBuffer buffer) throws IOException { + public static String fieldBytesToString(List<byte[]> fieldBytesList) { byte[] delimiter = {'\t'}; Text attributes = new Text(); - for (int i = 0; i < fieldDescriptors.size(); ++i) { - FieldDescriptor descriptor = fieldDescriptors.get(i); - byte[] fldBytes = new byte[descriptor.getFieldLength()]; - buffer.get(fldBytes, 0, fldBytes.length); + for (int i = 0; i < fieldBytesList.size(); ++i) { + byte[] fldBytes = fieldBytesList.get(i); String charset = System.getProperty("sedona.global.charset", "default"); - Boolean utf8flag = charset.equalsIgnoreCase("utf8"); + boolean utf8flag = charset.equalsIgnoreCase("utf8"); byte[] attr = utf8flag ? fldBytes : fastParse(fldBytes, 0, fldBytes.length).trim().getBytes(); if (i > 0) { attributes.append(delimiter, 0, 1); // first attribute doesn't append '\t' diff --git a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/parseUtils/shp/PolygonParser.java b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/parseUtils/shp/PolygonParser.java index f1dcd7712c..5e26a37c69 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/parseUtils/shp/PolygonParser.java +++ b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/parseUtils/shp/PolygonParser.java @@ -18,10 +18,9 @@ */ package org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.shp; -import java.io.IOException; import java.util.ArrayList; import java.util.List; -import org.geotools.geometry.jts.coordinatesequence.CoordinateSequences; +import org.locationtech.jts.algorithm.Orientation; import org.locationtech.jts.geom.CoordinateSequence; import org.locationtech.jts.geom.Geometry; import org.locationtech.jts.geom.GeometryFactory; @@ -44,7 +43,6 @@ public PolygonParser(GeometryFactory geometryFactory) { * * @param reader the reader * @return the geometry - * @throws IOException Signals that an I/O exception has occurred. 
*/ @Override public Geometry parseShape(ShapeReader reader) { @@ -72,16 +70,13 @@ public Geometry parseShape(ShapeReader reader) { LinearRing ring = geometryFactory.createLinearRing(csRing); if (shell == null) { shell = ring; - shellsCCW = CoordinateSequences.isCCW(csRing); - } else if (CoordinateSequences.isCCW(csRing) != shellsCCW) { + shellsCCW = Orientation.isCCW(csRing); + } else if (Orientation.isCCW(csRing) != shellsCCW) { holes.add(ring); } else { - if (shell != null) { - Polygon polygon = - geometryFactory.createPolygon(shell, GeometryFactory.toLinearRingArray(holes)); - polygons.add(polygon); - } - + Polygon polygon = + geometryFactory.createPolygon(shell, GeometryFactory.toLinearRingArray(holes)); + polygons.add(polygon); shell = ring; holes.clear(); } diff --git a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/CombineShapeReader.java b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/CombineShapeReader.java index 9414fceab6..c5b2b3fc3d 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/CombineShapeReader.java +++ b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/CombineShapeReader.java @@ -19,24 +19,21 @@ package org.apache.sedona.core.formatMapper.shapefileParser.shapes; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.IntBuffer; import org.apache.commons.io.FilenameUtils; -import org.apache.hadoop.fs.FSDataInputStream; -import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit; import org.apache.hadoop.mapreduce.lib.input.FileSplit; -import org.apache.log4j.Logger; import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.shp.ShapeType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class CombineShapeReader extends RecordReader { /** dubug logger */ - static final Logger logger = Logger.getLogger(CombineShapeReader.class); + private static final Logger logger = LoggerFactory.getLogger(CombineShapeReader.class); /** suffix of attribute file */ private static final String DBF_SUFFIX = "dbf"; /** suffix of shape record file */ @@ -93,20 +90,7 @@ public void initialize(InputSplit split, TaskAttemptContext context) if (shxSplit != null) { // shape file exists, extract .shp with .shx // first read all indexes into memory - Path filePath = shxSplit.getPath(); - FileSystem fileSys = filePath.getFileSystem(context.getConfiguration()); - FSDataInputStream shxInpuStream = fileSys.open(filePath); - shxInpuStream.skip(24); - int shxFileLength = - shxInpuStream.readInt() * 2 - 100; // get length in bytes, exclude header - // skip following 72 bytes in header - shxInpuStream.skip(72); - byte[] bytes = new byte[shxFileLength]; - // read all indexes into memory, skip first 50 bytes(header) - shxInpuStream.readFully(bytes, 0, bytes.length); - IntBuffer buffer = ByteBuffer.wrap(bytes).asIntBuffer(); - int[] indexes = new int[shxFileLength / 4]; - buffer.get(indexes); + int[] indexes = ShxFileReader.readAll(shxSplit, context); shapeFileReader = new ShapeFileReader(indexes); } else { shapeFileReader = new ShapeFileReader(); // no index, construct with no parameter @@ -122,7 +106,7 @@ public void initialize(InputSplit split, TaskAttemptContext context) } } - 
public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { boolean hasNextShp = shapeFileReader.nextKeyValue(); if (hasDbf) { @@ -132,10 +116,8 @@ public boolean nextKeyValue() throws IOException, InterruptedException { ShapeType curShapeType = shapeFileReader.getCurrentValue().getType(); while (hasNextShp && !curShapeType.isSupported()) { logger.warn( - "[SEDONA] Shapefile type " - + curShapeType.name() - + " is not supported. Skipped this record." - + " Please use QGIS or GeoPandas to convert it to a type listed in ShapeType.java"); + "[SEDONA] Shapefile type {} is not supported. Skipped this record. Please use QGIS or GeoPandas to convert it to a type listed in ShapeType.java", + curShapeType.name()); if (hasDbf) { hasNextDbf = dbfFileReader.nextKeyValue(); } @@ -149,20 +131,20 @@ public boolean nextKeyValue() throws IOException, InterruptedException { new Exception( "shape record loses attributes in .dbf file at ID=" + shapeFileReader.getCurrentKey().getIndex()); - e.printStackTrace(); + logger.warn(e.getMessage(), e); } else if (!hasNextShp && hasNextDbf) { Exception e = new Exception("Redundant attributes in .dbf exists"); - e.printStackTrace(); + logger.warn(e.getMessage(), e); } } return hasNextShp; } - public ShapeKey getCurrentKey() throws IOException, InterruptedException { + public ShapeKey getCurrentKey() { return shapeFileReader.getCurrentKey(); } - public PrimitiveShape getCurrentValue() throws IOException, InterruptedException { + public PrimitiveShape getCurrentValue() { PrimitiveShape value = new PrimitiveShape(shapeFileReader.getCurrentValue()); if (hasDbf && hasNextDbf) { value.setAttributes(dbfFileReader.getCurrentValue()); @@ -170,7 +152,7 @@ public PrimitiveShape getCurrentValue() throws IOException, InterruptedException return value; } - public float getProgress() throws IOException, InterruptedException { + public float getProgress() { return shapeFileReader.getProgress(); } diff --git a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/DbfFileReader.java b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/DbfFileReader.java index 78ba9617e2..285e3ae2b0 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/DbfFileReader.java +++ b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/DbfFileReader.java @@ -19,6 +19,7 @@ package org.apache.sedona.core.formatMapper.shapefileParser.shapes; import java.io.IOException; +import java.util.List; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -26,6 +27,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil; +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor; public class DbfFileReader extends org.apache.hadoop.mapreduce.RecordReader { @@ -34,45 +36,60 @@ public class DbfFileReader extends org.apache.hadoop.mapreduce.RecordReader value = null; /** key value of current row */ private ShapeKey key = null; /** generated id of current row */ private int id = 0; - public void initialize(InputSplit split, TaskAttemptContext context) - throws IOException, InterruptedException { + public void initialize(InputSplit split, TaskAttemptContext context) throws 
IOException { FileSplit fileSplit = (FileSplit) split; Path inputPath = fileSplit.getPath(); FileSystem fileSys = inputPath.getFileSystem(context.getConfiguration()); - inputStream = fileSys.open(inputPath); + FSDataInputStream stream = fileSys.open(inputPath); + initialize(stream); + } + + public void initialize(FSDataInputStream stream) throws IOException { + inputStream = stream; dbfParser = new DbfParseUtil(); dbfParser.parseFileHead(inputStream); } - public boolean nextKeyValue() throws IOException, InterruptedException { + public List getFieldDescriptors() { + return dbfParser.getFieldDescriptors(); + } + + public boolean nextKeyValue() throws IOException { // first check deleted flag - String curbytes = dbfParser.parsePrimitiveRecord(inputStream); - if (curbytes == null) { + List fieldBytesList = dbfParser.parse(inputStream); + if (fieldBytesList == null) { value = null; return false; } else { - value = curbytes; + value = fieldBytesList; key = new ShapeKey(); key.setIndex(id++); return true; } } - public ShapeKey getCurrentKey() throws IOException, InterruptedException { + public ShapeKey getCurrentKey() { return key; } - public String getCurrentValue() throws IOException, InterruptedException { + public List getCurrentFieldBytes() { return value; } - public float getProgress() throws IOException, InterruptedException { + public String getCurrentValue() { + if (value == null) { + return null; + } + return DbfParseUtil.fieldBytesToString(value); + } + + public float getProgress() { return dbfParser.getProgress(); } diff --git a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/ShapeFileReader.java b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/ShapeFileReader.java index ab30d15574..2549c0f3fe 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/ShapeFileReader.java +++ b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/ShapeFileReader.java @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.shp.ShapeType; import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.shp.ShpFileParser; public class ShapeFileReader extends RecordReader { @@ -36,7 +37,7 @@ public class ShapeFileReader extends RecordReader { private ShapeKey recordKey = null; /** primitive bytes value */ private ShpRecord recordContent = null; - /** inputstream for .shp file */ + /** input stream for .shp file */ private FSDataInputStream shpInputStream = null; /** Iterator of indexes of records */ private int[] indexes; @@ -53,7 +54,7 @@ public ShapeFileReader() {} /** * constructor with index * - * @param indexes + * @param indexes offsets of records in the .shp file */ public ShapeFileReader(int[] indexes) { this.indexes = indexes; @@ -65,28 +66,40 @@ public void initialize(InputSplit split, TaskAttemptContext context) FileSplit fileSplit = (FileSplit) split; Path filePath = fileSplit.getPath(); FileSystem fileSys = filePath.getFileSystem(context.getConfiguration()); - shpInputStream = fileSys.open(filePath); - // assign inputstream to parser and parse file header to init; - parser = new ShpFileParser(shpInputStream); + FSDataInputStream stream = fileSys.open(filePath); + initialize(stream); + } + + public void initialize(FSDataInputStream stream) throws 
IOException { + shpInputStream = stream; + parser = new ShpFileParser(stream); parser.parseShapeFileHead(); } - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { if (useIndex) { - /** with index, iterate until end and extract bytes with information from indexes */ + /* with index, iterate until end and extract bytes with information from indexes */ if (indexId == indexes.length) { return false; } // check offset, if current offset in inputStream not match with information in shx, move it - if (shpInputStream.getPos() < indexes[indexId] * 2) { - shpInputStream.skip(indexes[indexId] * 2 - shpInputStream.getPos()); + long pos = indexes[indexId] * 2L; + if (shpInputStream.getPos() < pos) { + long skipBytes = pos - shpInputStream.getPos(); + if (shpInputStream.skip(skipBytes) != skipBytes) { + throw new IOException("Failed to seek to the right place in .shp file"); + } } int currentLength = indexes[indexId + 1] * 2 - 4; recordKey = new ShapeKey(); recordKey.setIndex(parser.parseRecordHeadID()); - recordContent = parser.parseRecordPrimitiveContent(currentLength); + if (currentLength >= 0) { + recordContent = parser.parseRecordPrimitiveContent(currentLength); + } else { + // Ignore this index entry + recordContent = new ShpRecord(new byte[0], ShapeType.NULL.getId()); + } indexId += 2; - return true; } else { if (getProgress() >= 1) { return false; @@ -94,19 +107,19 @@ public boolean nextKeyValue() throws IOException, InterruptedException { recordKey = new ShapeKey(); recordKey.setIndex(parser.parseRecordHeadID()); recordContent = parser.parseRecordPrimitiveContent(); - return true; } + return true; } - public ShapeKey getCurrentKey() throws IOException, InterruptedException { + public ShapeKey getCurrentKey() { return recordKey; } - public ShpRecord getCurrentValue() throws IOException, InterruptedException { + public ShpRecord getCurrentValue() { return recordContent; } - public float getProgress() throws IOException, InterruptedException { + public float getProgress() { return parser.getProgress(); } diff --git a/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/ShxFileReader.java b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/ShxFileReader.java new file mode 100644 index 0000000000..63d66e559d --- /dev/null +++ b/spark/common/src/main/java/org/apache/sedona/core/formatMapper/shapefileParser/shapes/ShxFileReader.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.core.formatMapper.shapefileParser.shapes; + +import java.io.DataInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.IntBuffer; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; + +public class ShxFileReader { + + public static int[] readAll(InputSplit split, TaskAttemptContext context) throws IOException { + FileSplit fileSplit = (FileSplit) split; + Path inputPath = fileSplit.getPath(); + FileSystem fileSys = inputPath.getFileSystem(context.getConfiguration()); + try (FSDataInputStream stream = fileSys.open(inputPath)) { + return readAll(stream); + } + } + + public static int[] readAll(DataInputStream stream) throws IOException { + if (stream.skip(24) != 24) { + throw new IOException("Failed to skip 24 bytes in .shx file"); + } + int shxFileLength = stream.readInt() * 2 - 100; // get length in bytes, exclude header + // skip following 72 bytes in header + if (stream.skip(72) != 72) { + throw new IOException("Failed to skip 72 bytes in .shx file"); + } + byte[] bytes = new byte[shxFileLength]; + // read all indexes into memory, skip first 50 bytes(header) + stream.readFully(bytes, 0, bytes.length); + IntBuffer buffer = ByteBuffer.wrap(bytes).asIntBuffer(); + int[] indexes = new int[shxFileLength / 4]; + buffer.get(indexes); + return indexes; + } +} diff --git a/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.cpg b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.cpg new file mode 100644 index 0000000000..3ad133c048 --- /dev/null +++ b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.cpg @@ -0,0 +1 @@ +UTF-8 \ No newline at end of file diff --git a/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.dbf b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.dbf new file mode 100644 index 0000000000..c6afba665f Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.dbf differ diff --git a/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.prj b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.prj new file mode 100644 index 0000000000..a30c00a55d --- /dev/null +++ b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.prj @@ -0,0 +1 @@ +GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]] \ No newline at end of file diff --git a/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.qpj b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.qpj new file mode 100644 index 0000000000..5fbc831e74 --- /dev/null +++ b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.qpj @@ -0,0 +1 @@ +GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]] diff --git a/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.shp b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.shp new file mode 100644 index 0000000000..f9d9cf06a3 Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.shp differ diff --git 
a/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.shx b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.shx new file mode 100644 index 0000000000..d665243605 Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/bad_shx/bad_shx.shx differ diff --git a/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.cpg b/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.cpg new file mode 100644 index 0000000000..cd89cb9758 --- /dev/null +++ b/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.cpg @@ -0,0 +1 @@ +ISO-8859-1 \ No newline at end of file diff --git a/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.dbf b/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.dbf new file mode 100644 index 0000000000..b9993cc477 Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.dbf differ diff --git a/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.shp b/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.shp new file mode 100644 index 0000000000..ee279704b0 Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.shp differ diff --git a/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.shx b/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.shx new file mode 100644 index 0000000000..c1fa99eb18 Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/contains_null_geom/contains_null_geom.shx differ diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.cpg b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.cpg new file mode 100644 index 0000000000..3ad133c048 --- /dev/null +++ b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.cpg @@ -0,0 +1 @@ +UTF-8 \ No newline at end of file diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.dbf b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.dbf new file mode 100644 index 0000000000..f32642785c Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.dbf differ diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.prj b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.prj new file mode 100644 index 0000000000..5ded4bcacb --- /dev/null +++ b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.prj @@ -0,0 +1 @@ +GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] \ No newline at end of file diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.shp b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.shp new file mode 100644 index 0000000000..10cca479d4 Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.shp differ diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.shx b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.shx new file mode 100644 index 0000000000..969416c787 Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/datatypes/datatypes1.shx differ diff --git 
a/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.cpg b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.cpg new file mode 100644 index 0000000000..0909932abd --- /dev/null +++ b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.cpg @@ -0,0 +1 @@ +GB2312 \ No newline at end of file diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.dbf b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.dbf new file mode 100644 index 0000000000..64dce7cc1b Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.dbf differ diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.prj b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.prj new file mode 100644 index 0000000000..5ded4bcacb --- /dev/null +++ b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.prj @@ -0,0 +1 @@ +GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] \ No newline at end of file diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.shp b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.shp new file mode 100644 index 0000000000..9bc1ad9bb9 Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.shp differ diff --git a/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.shx b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.shx new file mode 100644 index 0000000000..5292febe4b Binary files /dev/null and b/spark/common/src/test/resources/shapefiles/datatypes/datatypes2.shx differ diff --git a/spark/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e5f994e203..d2f1d03406 100644 --- a/spark/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/spark/spark-3.0/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,3 @@ org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource +org.apache.sedona.sql.datasources.shapefile.ShapefileDataSource diff --git a/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala new file mode 100644 index 0000000000..7cd6d03a6d --- /dev/null +++ b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +/** + * A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the + * following components of shapefiles: + * + *
 *   • .shp: the main file
 *   • .dbf: (optional) the attribute file
 *   • .shx: (optional) the index file
 *   • .cpg: (optional) the code page file
 *   • .prj: (optional) the projection file
 *
The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 0000000000..3fc5b41eb9 --- /dev/null +++ b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = new Path(file.filePath).getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> new Path(file.filePath) + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. 
+ val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 0000000000..ba25c92dad --- /dev/null +++ b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: FilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 0000000000..ebc02fae85 --- /dev/null +++ b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 0000000000..081bc623db --- /dev/null +++ b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = new Path(partitionedFile.filePath) + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(FilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } + 
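For reference, `planInputPartitions` above groups the files discovered by the default `FileScan` planning into one partition per shapefile: sibling components are keyed by `<parent directory>/<main file stem>`, bucketed by extension, and any group lacking the mandatory `.shp` member is dropped with a warning. Below is a minimal standalone sketch of that grouping logic; the object name, helper name, and sample paths are hypothetical illustrations, not part of this patch.

```scala
import java.util.Locale

object ShapefileGroupingSketch {
  // Recognized shapefile sibling extensions and the mandatory one, mirroring ShapefileUtils.
  val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj")
  val mandatoryFileExtensions: Set[String] = Set("shp")

  /** Group candidate paths by "<parent>/<stem>", keeping only groups that contain a .shp file. */
  def groupByStem(paths: Seq[String]): Map[String, Map[String, String]] = {
    val keyed = paths.flatMap { path =>
      val slash = path.lastIndexOf('/')
      val fileName = path.substring(slash + 1)
      val dot = fileName.lastIndexOf('.')
      if (dot == -1) None
      else {
        val stem = fileName.substring(0, dot)
        val ext = fileName.substring(dot + 1).toLowerCase(Locale.ROOT)
        if (shapeFileExtensions.contains(ext)) {
          Some((path.substring(0, slash + 1) + stem, ext -> path))
        } else None
      }
    }
    keyed
      .groupBy(_._1)
      .map { case (key, entries) => key -> entries.map(_._2).toMap }
      .filter { case (_, group) => mandatoryFileExtensions.subsetOf(group.keySet) }
  }

  def main(args: Array[String]): Unit = {
    val files =
      Seq("/data/a.shp", "/data/a.dbf", "/data/a.prj", "/data/b.dbf", "/data/c.shp", "/data/notes.txt")
    // Expected: groups for /data/a and /data/c; /data/b is dropped because it has no .shp.
    groupByStem(files).foreach(println)
  }
}
```

This is also why, per the source above, a shapefile missing only `.dbf` or `.shx` is still read, while a group with no `.shp` is excluded from the scan entirely.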
+ def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = { + copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + } +} + +object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 0000000000..80c431f97b --- /dev/null +++ b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + Array.empty, + Seq.empty, + Seq.empty) + } +} diff --git a/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 0000000000..7db6bb8d1f --- /dev/null +++ b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 0000000000..31f746db49 --- /dev/null +++ b/spark/spark-3.0/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromDecimal( + new java.math.BigDecimal(new String(bytes, StandardCharsets.ISO_8859_1).trim)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 0000000000..b1764e6e21 --- /dev/null +++ b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,739 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + 
assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "ID").get.dataType == StringType) + assert(schema.find(_.name == "LOD").get.dataType == LongType) + assert(schema.find(_.name == "Parent_ID").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 20) + var nonNullLods = 0 + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + 
assert(row.getAs[String]("ID").nonEmpty) + val lodIndex = row.fieldIndex("LOD") + if (!row.isNullAt(lodIndex)) { + assert(row.getAs[Long]("LOD") == 2) + nonNullLods += 1 + } + assert(row.getAs[String]("Parent_ID").nonEmpty) + } + assert(nonNullLods == 17) + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = 
row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + 
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + 
Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + var rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + // Using partition filters + rows = shapefileDf.where("part = 2").collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val 
shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + 
.option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + } +} diff --git a/spark/spark-3.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e5f994e203..d2f1d03406 100644 --- a/spark/spark-3.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/spark/spark-3.1/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,3 @@ org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource +org.apache.sedona.sql.datasources.shapefile.ShapefileDataSource diff --git a/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala new file mode 100644 index 0000000000..7cd6d03a6d --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +/** + * A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the + * following components of shapefiles: + * + *

+ *   • .shp: the main file
+ *   • .dbf: (optional) the attribute file
+ *   • .shx: (optional) the index file
+ *   • .cpg: (optional) the code page file
+ *   • .prj: (optional) the projection file
+ *

The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 0000000000..3fc5b41eb9 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = new Path(file.filePath).getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> new Path(file.filePath) + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. 
+ val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 0000000000..ba25c92dad --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: FilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 0000000000..ebc02fae85 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 0000000000..081bc623db --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = new Path(partitionedFile.filePath) + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(FilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } + 
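The grouping above keys every component file by its path minus the extension, so `.shp`, `.dbf`, `.shx`, `.cpg` and `.prj` files sharing a main name land in the same `FilePartition`. A minimal standalone sketch of that keying step (hypothetical helper name, outside Spark, not part of the patch):

```scala
import java.util.Locale
import org.apache.hadoop.fs.Path

// Strip the extension and keep the parent directory, mirroring the group key
// used by planInputPartitions: "/data/roads.shp" and "/data/roads.DBF" both
// map to the key "/data/roads".
def shapefileGroupKey(pathString: String): Option[(String, String)] = {
  val path = new Path(pathString)
  val fileName = path.getName
  val pos = fileName.lastIndexOf('.')
  if (pos == -1) None
  else {
    val mainName = fileName.substring(0, pos)
    val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT)
    Some(new Path(path.getParent, mainName).toString -> extension)
  }
}

// shapefileGroupKey("/data/roads.shp") == Some(("/data/roads", "shp"))
// shapefileGroupKey("/data/roads.DBF") == Some(("/data/roads", "dbf"))
```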
+ def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = { + copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + } +} + +object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 0000000000..80c431f97b --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + Array.empty, + Seq.empty, + Seq.empty) + } +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 0000000000..7db6bb8d1f --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 0000000000..31f746db49 --- /dev/null +++ b/spark/spark-3.1/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromDecimal( + new java.math.BigDecimal(new String(bytes, StandardCharsets.ISO_8859_1).trim)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 0000000000..b1764e6e21 --- /dev/null +++ b/spark/spark-3.1/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,739 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + 
assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "ID").get.dataType == StringType) + assert(schema.find(_.name == "LOD").get.dataType == LongType) + assert(schema.find(_.name == "Parent_ID").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 20) + var nonNullLods = 0 + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + 
assert(row.getAs[String]("ID").nonEmpty) + val lodIndex = row.fieldIndex("LOD") + if (!row.isNullAt(lodIndex)) { + assert(row.getAs[Long]("LOD") == 2) + nonNullLods += 1 + } + assert(row.getAs[String]("Parent_ID").nonEmpty) + } + assert(nonNullLods == 17) + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = 
row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + 
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + 
Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + var rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + // Using partition filters + rows = shapefileDf.where("part = 2").collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val 
shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + 
.option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + } +} diff --git a/spark/spark-3.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e5f994e203..d2f1d03406 100644 --- a/spark/spark-3.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/spark/spark-3.2/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,3 @@ org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource +org.apache.sedona.sql.datasources.shapefile.ShapefileDataSource diff --git a/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala new file mode 100644 index 0000000000..7cd6d03a6d --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +/** + * A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the + * following components of shapefiles: + * + *

+ *   - .shp: the main file
+ *   - .dbf: (optional) the attribute file
+ *   - .shx: (optional) the index file
+ *   - .cpg: (optional) the code page file
+ *   - .prj: (optional) the projection file
+ *
The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 0000000000..3fc5b41eb9 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
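The path rewriting in `transformPaths` above turns a concrete `.shp` path into a three-character glob so that the sibling `.dbf`, `.shx`, `.cpg` and `.prj` files are picked up by the file index as well. A simplified sketch of the rewrite (hypothetical helper name, without the directory check performed in the real code):

```scala
import java.util.Locale

// "/data/roads.shp" -> "/data/roads.???"; anything else is left untouched.
def toShapefileGlob(pathString: String): String =
  if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp"))
    pathString.substring(0, pathString.length - 3) + "???"
  else pathString

// toShapefileGlob("/data/roads.shp") == "/data/roads.???"
// toShapefileGlob("/data/roads")     == "/data/roads"
```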
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = new Path(file.filePath).getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> new Path(file.filePath) + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. 
+ val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 0000000000..ba25c92dad --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: FilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 0000000000..ebc02fae85 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 0000000000..081bc623db --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
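Note: to illustrate how `ShapefileReadOptions.parse` resolves defaults, a small usage sketch (assuming the class above is on the classpath; the option values here are arbitrary):

```scala
import org.apache.sedona.sql.datasources.shapefile.ShapefileReadOptions
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import scala.collection.JavaConverters._

// Keys are case-insensitive; anything not supplied falls back to the defaults in parse().
val opts = new CaseInsensitiveStringMap(
  Map("GEOMETRY.NAME" -> "geom", "charset" -> "GB2312").asJava)

val parsed = ShapefileReadOptions.parse(opts)
// parsed == ShapefileReadOptions(geometryFieldName = "geom", keyFieldName = None, charset = Some("GB2312"))
```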
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = new Path(partitionedFile.filePath) + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(FilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } + 
+ def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = { + copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + } +} + +object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 0000000000..80c431f97b --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + Array.empty, + Seq.empty, + Seq.empty) + } +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 0000000000..7db6bb8d1f --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
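Note: `planInputPartitions` in `ShapefileScan` above regroups the flat file list produced by Spark's default planner by shapefile stem, and drops incomplete groups that lack a `.shp` member. A simplified sketch of that grouping, using plain path strings instead of `PartitionedFile` (names here are illustrative only):

```scala
import java.util.Locale

object ShapefileGroupingSketch {
  // Group component files by "<parent>/<stem>" and keep only groups that contain
  // the mandatory .shp member, mirroring the grouping in planInputPartitions.
  def group(paths: Seq[String]): Map[String, Seq[String]] =
    paths
      .flatMap { p =>
        val dot = p.lastIndexOf('.')
        if (dot < 0) None
        else {
          val ext = p.substring(dot + 1).toLowerCase(Locale.ROOT)
          if (Set("shp", "shx", "dbf", "cpg", "prj").contains(ext)) Some(p.substring(0, dot) -> p)
          else None
        }
      }
      .groupBy(_._1)
      .collect {
        case (stem, files) if files.exists(_._2.toLowerCase(Locale.ROOT).endsWith(".shp")) =>
          stem -> files.map(_._2)
      }

  def main(args: Array[String]): Unit = {
    val files = Seq("/d/a.shp", "/d/a.dbf", "/d/b.dbf", "/d/readme.txt")
    // keeps /d/a (it has a .shp); /d/b is dropped, readme.txt is ignored
    println(group(files))
  }
}
```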
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 0000000000..31f746db49 --- /dev/null +++ b/spark/spark-3.2/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromDecimal( + new java.math.BigDecimal(new String(bytes, StandardCharsets.ISO_8859_1).trim)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 0000000000..b1764e6e21 --- /dev/null +++ b/spark/spark-3.2/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,739 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + 
assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "ID").get.dataType == StringType) + assert(schema.find(_.name == "LOD").get.dataType == LongType) + assert(schema.find(_.name == "Parent_ID").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 20) + var nonNullLods = 0 + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + 
assert(row.getAs[String]("ID").nonEmpty) + val lodIndex = row.fieldIndex("LOD") + if (!row.isNullAt(lodIndex)) { + assert(row.getAs[Long]("LOD") == 2) + nonNullLods += 1 + } + assert(row.getAs[String]("Parent_ID").nonEmpty) + } + assert(nonNullLods == 17) + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = 
row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + 
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + 
Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + var rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + // Using partition filters + rows = shapefileDf.where("part = 2").collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val 
shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + 
.option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + } +} diff --git a/spark/spark-3.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e5f994e203..d2f1d03406 100644 --- a/spark/spark-3.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/spark/spark-3.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,3 @@ org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource +org.apache.sedona.sql.datasources.shapefile.ShapefileDataSource diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala new file mode 100644 index 0000000000..7cd6d03a6d --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +/** + * A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the + * following components of shapefiles: + * + *

+ *   - .shp: the main file
+ *   - .dbf: (optional) the attribute file
+ *   - .shx: (optional) the index file
+ *   - .cpg: (optional) the code page file
+ *   - .prj: (optional) the projection file
+ *
+ *
The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 0000000000..3fc5b41eb9 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = new Path(file.filePath).getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> new Path(file.filePath) + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. 
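+        // Taken together, the effective resolution order is: the "charset" read option,
+        // then the first line of the .cpg file, then UTF-8 when -Dsedona.global.charset=utf8
+        // is set, and finally ISO-8859-1 (the DBF default applied later in
+        // ShapefileUtils.fieldValueConverter).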
+ val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 0000000000..ba25c92dad --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: FilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 0000000000..ebc02fae85 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 0000000000..526b6cbee4 --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = new Path(partitionedFile.filePath) + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(FilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } 
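+  // For example (illustrative paths only): files /data/a.shp, /data/a.dbf, /data/b.shp and
+  // /data/b.dbf are grouped above into two partitions keyed by "/data/a" and "/data/b".
+  // A group that lacks its mandatory .shp file is logged and skipped rather than failing
+  // the scan.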
+} + +object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 0000000000..e5135e381d --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 0000000000..7db6bb8d1f --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 0000000000..12238870be --- /dev/null +++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromString(UTF8String.fromBytes(bytes)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 0000000000..5f1e34bbe2 --- /dev/null +++ b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,727 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + 
assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "ID").get.dataType == StringType) + assert(schema.find(_.name == "LOD").get.dataType == LongType) + assert(schema.find(_.name == "Parent_ID").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 20) + var nonNullLods = 0 + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + 
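+        // The geometry is null for every record here because the shape type in this test
+        // fixture is not supported; the reader logs a warning and yields null instead of
+        // failing the whole read.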
assert(row.getAs[String]("ID").nonEmpty) + val lodIndex = row.fieldIndex("LOD") + if (!row.isNullAt(lodIndex)) { + assert(row.getAs[Long]("LOD") == 2) + nonNullLods += 1 + } + assert(row.getAs[String]("Parent_ID").nonEmpty) + } + assert(nonNullLods == 17) + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = 
row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + 
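+      // Only the .shp and .dbf files are copied: the .shx index is optional, so all records
+      // are still read, and the SRID below is 0 because the .prj file is absent as well.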
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + 
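+      // The "part=1" directory above and "part=2" below follow Hive-style partition naming,
+      // so partition discovery adds an integer "part" column to the resulting DataFrame.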
Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == 
LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + 
"/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + } +} diff --git a/spark/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e5f994e203..d2f1d03406 100644 --- a/spark/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/spark/spark-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,3 @@ org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource +org.apache.sedona.sql.datasources.shapefile.ShapefileDataSource diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala new file mode 100644 index 0000000000..7cd6d03a6d --- /dev/null +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +/** + * A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the + * following components of shapefiles: + * + *

  • .shp: the main file
  • .dbf: (optional) the attribute file
  • .shx: (optional) the + * index file
  • .cpg: (optional) the code page file
  • .prj: (optional) the projection file + *
+ * + *

The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 0000000000..301d63296f --- /dev/null +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = file.filePath.toPath.getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> file.filePath.toPath + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. 
+ val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 0000000000..ba25c92dad --- /dev/null +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: FilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 0000000000..ebc02fae85 --- /dev/null +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 0000000000..f8f4cac2f0 --- /dev/null +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = partitionedFile.filePath.toPath + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(FilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } +} + 
+object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 0000000000..e5135e381d --- /dev/null +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 0000000000..7db6bb8d1f --- /dev/null +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 0000000000..12238870be --- /dev/null +++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromString(UTF8String.fromBytes(bytes)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 0000000000..5f1e34bbe2 --- /dev/null +++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,727 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + 
assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "ID").get.dataType == StringType) + assert(schema.find(_.name == "LOD").get.dataType == LongType) + assert(schema.find(_.name == "Parent_ID").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 20) + var nonNullLods = 0 + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + 
assert(row.getAs[String]("ID").nonEmpty) + val lodIndex = row.fieldIndex("LOD") + if (!row.isNullAt(lodIndex)) { + assert(row.getAs[Long]("LOD") == 2) + nonNullLods += 1 + } + assert(row.getAs[String]("Parent_ID").nonEmpty) + } + assert(nonNullLods == 17) + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = 
row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + 
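+ // Copy only the .shp and .dbf files, so the reader has to read the shapefile without a .shx index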
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + 
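+ // Hive-style "part=<value>" directory names make Spark expose an extra "part" partition column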
Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == 
LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + 
"/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + } +} diff --git a/spark/spark-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e5f994e203..d2f1d03406 100644 --- a/spark/spark-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/spark/spark-3.5/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,2 +1,3 @@ org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat org.apache.spark.sql.execution.datasources.v2.geoparquet.metadata.GeoParquetMetadataDataSource +org.apache.sedona.sql.datasources.shapefile.ShapefileDataSource diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala new file mode 100644 index 0000000000..7cd6d03a6d --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileDataSource.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.util.Try + +/** + * A Spark SQL data source for reading ESRI Shapefiles. This data source supports reading the + * following components of shapefiles: + * + *

+ *   • .shp: the main file
+ *   • .dbf: (optional) the attribute file
+ *   • .shx: (optional) the index file
+ *   • .cpg: (optional) the code page file
+ *   • .prj: (optional) the projection file
+ *
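+ * A minimal usage sketch (assuming a Sedona-enabled SparkSession named `sparkSession`; the
+ * load path below is a placeholder):
+ * {{{
+ *   val df = sparkSession.read.format("shapefile").load("/path/to/shapefile")
+ * }}}
+ *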

The load path can be a directory containing the shapefiles, or a path to the .shp file. If + * the path refers to a .shp file, the data source will also read other components such as .dbf + * and .shx files in the same directory. + */ +class ShapefileDataSource extends FileDataSourceV2 with DataSourceRegister { + + override def shortName(): String = "shapefile" + + override def fallbackFileFormat: Class[_ <: FileFormat] = null + + override protected def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable(tableName, sparkSession, optionsWithoutPaths, paths, None, fallbackFileFormat) + } + + override protected def getTable( + options: CaseInsensitiveStringMap, + schema: StructType): Table = { + val paths = getTransformedPath(options) + val tableName = getTableName(options, paths) + val optionsWithoutPaths = getOptionsWithoutPaths(options) + ShapefileTable( + tableName, + sparkSession, + optionsWithoutPaths, + paths, + Some(schema), + fallbackFileFormat) + } + + private def getTransformedPath(options: CaseInsensitiveStringMap): Seq[String] = { + val paths = getPaths(options) + transformPaths(paths, options) + } + + private def transformPaths( + paths: Seq[String], + options: CaseInsensitiveStringMap): Seq[String] = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + paths.map { pathString => + if (pathString.toLowerCase(Locale.ROOT).endsWith(".shp")) { + // If the path refers to a file, we need to change it to a glob path to support reading + // .dbf and .shx files as well. For example, if the path is /path/to/file.shp, we need to + // change it to /path/to/file.??? + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + val isDirectory = Try(fs.getFileStatus(path).isDirectory).getOrElse(false) + if (isDirectory) { + pathString + } else { + pathString.substring(0, pathString.length - 3) + "???" + } + } else { + pathString + } + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala new file mode 100644 index 0000000000..301d63296f --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReader.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.fs.Path +import org.apache.sedona.common.FunctionsGeoTools +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.DbfFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.PrimitiveShape +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShapeFileReader +import org.apache.sedona.core.formatMapper.shapefileParser.shapes.ShxFileReader +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.logger +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.openStream +import org.apache.sedona.sql.datasources.shapefile.ShapefilePartitionReader.tryOpenStream +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.baseSchema +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.locationtech.jts.geom.GeometryFactory +import org.locationtech.jts.geom.PrecisionModel +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +import java.nio.charset.StandardCharsets +import scala.collection.JavaConverters._ +import java.util.Locale +import scala.util.Try + +class ShapefilePartitionReader( + configuration: Configuration, + partitionedFiles: Array[PartitionedFile], + readDataSchema: StructType, + options: ShapefileReadOptions) + extends PartitionReader[InternalRow] { + + private val partitionedFilesMap: Map[String, Path] = partitionedFiles.map { file => + val fileName = file.filePath.toPath.getName + val extension = FilenameUtils.getExtension(fileName).toLowerCase(Locale.ROOT) + extension -> file.filePath.toPath + }.toMap + + private val cpg = options.charset.orElse { + // No charset option or sedona.global.charset system property specified, infer charset + // from the cpg file. + tryOpenStream(partitionedFilesMap, "cpg", configuration) + .flatMap { stream => + try { + val lineIter = IOUtils.lineIterator(stream, StandardCharsets.UTF_8) + if (lineIter.hasNext) { + Some(lineIter.next().trim()) + } else { + None + } + } finally { + stream.close() + } + } + .orElse { + // Cannot infer charset from cpg file. If sedona.global.charset is set to "utf8", use UTF-8 as + // the default charset. This is for compatibility with the behavior of the RDD API. 
+ val charset = System.getProperty("sedona.global.charset", "default") + val utf8flag = charset.equalsIgnoreCase("utf8") + if (utf8flag) Some("UTF-8") else None + } + } + + private val prj = tryOpenStream(partitionedFilesMap, "prj", configuration).map { stream => + try { + IOUtils.toString(stream, StandardCharsets.UTF_8) + } finally { + stream.close() + } + } + + private val shpReader: ShapeFileReader = { + val reader = tryOpenStream(partitionedFilesMap, "shx", configuration) match { + case Some(shxStream) => + try { + val index = ShxFileReader.readAll(shxStream) + new ShapeFileReader(index) + } finally { + shxStream.close() + } + case None => new ShapeFileReader() + } + val stream = openStream(partitionedFilesMap, "shp", configuration) + reader.initialize(stream) + reader + } + + private val dbfReader = + tryOpenStream(partitionedFilesMap, "dbf", configuration).map { stream => + val reader = new DbfFileReader() + reader.initialize(stream) + reader + } + + private val geometryField = readDataSchema.filter(_.dataType.isInstanceOf[GeometryUDT]) match { + case Seq(geoField) => Some(geoField) + case Seq() => None + case _ => throw new IllegalArgumentException("Only one geometry field is allowed") + } + + private val shpSchema: StructType = { + val dbfFields = dbfReader + .map { reader => + ShapefileUtils.fieldDescriptorsToStructFields(reader.getFieldDescriptors.asScala.toSeq) + } + .getOrElse(Seq.empty) + StructType(baseSchema(options).fields ++ dbfFields) + } + + // projection from shpSchema to readDataSchema + private val projection = { + val expressions = readDataSchema.map { field => + val index = Try(shpSchema.fieldIndex(field.name)).getOrElse(-1) + if (index >= 0) { + val sourceField = shpSchema.fields(index) + val refExpr = BoundReference(index, sourceField.dataType, sourceField.nullable) + if (sourceField.dataType == field.dataType) refExpr + else { + Cast(refExpr, field.dataType) + } + } else { + if (field.nullable) { + Literal(null) + } else { + // This usually won't happen, since all fields of readDataSchema are nullable for most + // of the time. See org.apache.spark.sql.execution.datasources.v2.FileTable#dataSchema + // for more details. 
+ val dbfPath = partitionedFilesMap.get("dbf").orNull + throw new IllegalArgumentException( + s"Field ${field.name} not found in shapefile $dbfPath") + } + } + } + UnsafeProjection.create(expressions) + } + + // Convert DBF field values to SQL values + private val fieldValueConverters: Seq[Array[Byte] => Any] = dbfReader + .map { reader => + reader.getFieldDescriptors.asScala.map { field => + val index = Try(readDataSchema.fieldIndex(field.getFieldName)).getOrElse(-1) + if (index >= 0) { + ShapefileUtils.fieldValueConverter(field, cpg) + } else { (_: Array[Byte]) => + null + } + }.toSeq + } + .getOrElse(Seq.empty) + + private val geometryFactory = prj match { + case Some(wkt) => + val srid = + try { + FunctionsGeoTools.wktCRSToSRID(wkt) + } catch { + case e: Throwable => + val prjPath = partitionedFilesMap.get("prj").orNull + logger.warn(s"Failed to parse SRID from .prj file $prjPath", e) + 0 + } + new GeometryFactory(new PrecisionModel, srid) + case None => new GeometryFactory() + } + + private var currentRow: InternalRow = _ + + override def next(): Boolean = { + if (shpReader.nextKeyValue()) { + val key = shpReader.getCurrentKey + val id = key.getIndex + + val attributesOpt = dbfReader.flatMap { reader => + if (reader.nextKeyValue()) { + val value = reader.getCurrentFieldBytes + Option(value) + } else { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Shape record loses attributes in .dbf file {} at ID={}", dbfPath, id) + None + } + } + + val value = shpReader.getCurrentValue + val geometry = geometryField.flatMap { _ => + if (value.getType.isSupported) { + val shape = new PrimitiveShape(value) + Some(shape.getShape(geometryFactory)) + } else { + logger.warn( + "Shape type {} is not supported, geometry value will be null", + value.getType.name()) + None + } + } + + val attrValues = attributesOpt match { + case Some(fieldBytesList) => + // Convert attributes to SQL values + fieldBytesList.asScala.zip(fieldValueConverters).map { case (fieldBytes, converter) => + converter(fieldBytes) + } + case None => + // No attributes, fill with nulls + Seq.fill(fieldValueConverters.length)(null) + } + + val serializedGeom = geometry.map(GeometryUDT.serialize).orNull + val shpRow = if (options.keyFieldName.isDefined) { + InternalRow.fromSeq(serializedGeom +: key.getIndex +: attrValues.toSeq) + } else { + InternalRow.fromSeq(serializedGeom +: attrValues.toSeq) + } + currentRow = projection(shpRow) + true + } else { + dbfReader.foreach { reader => + if (reader.nextKeyValue()) { + val dbfPath = partitionedFilesMap.get("dbf").orNull + logger.warn("Redundant attributes in {} exists", dbfPath) + } + } + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + dbfReader.foreach(_.close()) + shpReader.close() + } +} + +object ShapefilePartitionReader { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefilePartitionReader]) + + private def openStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): FSDataInputStream = { + tryOpenStream(partitionedFilesMap, extension, configuration).getOrElse { + val path = partitionedFilesMap.head._2 + val baseName = FilenameUtils.getBaseName(path.getName) + throw new IllegalArgumentException( + s"No $extension file found for shapefile $baseName in ${path.getParent}") + } + } + + private def tryOpenStream( + partitionedFilesMap: Map[String, Path], + extension: String, + configuration: Configuration): Option[FSDataInputStream] = { + 
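+ // Open the component file with the given extension if it is present in this partition; return None otherwise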
partitionedFilesMap.get(extension).map { path => + val fs = path.getFileSystem(configuration) + fs.open(path) + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala new file mode 100644 index 0000000000..ba25c92dad --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefilePartitionReaderFactory.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.PartitionReaderWithPartitionValues +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration + +case class ShapefilePartitionReaderFactory( + sqlConf: SQLConf, + broadcastedConf: Broadcast[SerializableConfiguration], + dataSchema: StructType, + readDataSchema: StructType, + partitionSchema: StructType, + options: ShapefileReadOptions, + filters: Seq[Filter]) + extends PartitionReaderFactory { + + private def buildReader( + partitionedFiles: Array[PartitionedFile]): PartitionReader[InternalRow] = { + val fileReader = + new ShapefilePartitionReader( + broadcastedConf.value.value, + partitionedFiles, + readDataSchema, + options) + new PartitionReaderWithPartitionValues( + fileReader, + readDataSchema, + partitionSchema, + partitionedFiles.head.partitionValues) + } + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + partition match { + case filePartition: FilePartition => buildReader(filePartition.files) + case _ => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala new file mode 100644 index 0000000000..ebc02fae85 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileReadOptions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Options for reading Shapefiles. + * @param geometryFieldName + * The name of the geometry field. + * @param keyFieldName + * The name of the shape key field. + * @param charset + * The charset of non-spatial attributes. + */ +case class ShapefileReadOptions( + geometryFieldName: String, + keyFieldName: Option[String], + charset: Option[String]) + +object ShapefileReadOptions { + def parse(options: CaseInsensitiveStringMap): ShapefileReadOptions = { + val geometryFieldName = options.getOrDefault("geometry.name", "geometry") + val keyFieldName = + if (options.containsKey("key.name")) Some(options.get("key.name")) else None + val charset = if (options.containsKey("charset")) Some(options.get("charset")) else None + ShapefileReadOptions(geometryFieldName, keyFieldName, charset) + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala new file mode 100644 index 0000000000..f8f4cac2f0 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScan.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.InputPartition +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.execution.datasources.FilePartition +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.sedona.sql.datasources.shapefile.ShapefileScan.logger +import org.apache.spark.util.SerializableConfiguration +import org.slf4j.{Logger, LoggerFactory} + +import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable + +case class ShapefileScan( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression] = Seq.empty, + dataFilters: Seq[Expression] = Seq.empty) + extends FileScan { + + override def createReaderFactory(): PartitionReaderFactory = { + val caseSensitiveMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + ShapefilePartitionReaderFactory( + sparkSession.sessionState.conf, + broadcastedConf, + dataSchema, + readDataSchema, + readPartitionSchema, + ShapefileReadOptions.parse(options), + pushedFilters) + } + + override def planInputPartitions(): Array[InputPartition] = { + // Simply use the default implementation to compute input partitions for all files + val allFilePartitions = super.planInputPartitions().flatMap { + case filePartition: FilePartition => + filePartition.files + case partition => + throw new IllegalArgumentException( + s"Unexpected partition type: ${partition.getClass.getCanonicalName}") + } + + // Group shapefiles by their main path (without the extension) + val shapefileGroups: mutable.Map[String, mutable.Map[String, PartitionedFile]] = + mutable.Map.empty + allFilePartitions.foreach { partitionedFile => + val path = partitionedFile.filePath.toPath + val fileName = path.getName + val pos = fileName.lastIndexOf('.') + if (pos == -1) None + else { + val mainName = fileName.substring(0, pos) + val extension = fileName.substring(pos + 1).toLowerCase(Locale.ROOT) + if (ShapefileUtils.shapeFileExtensions.contains(extension)) { + val key = new Path(path.getParent, mainName).toString + val group = shapefileGroups.getOrElseUpdate(key, mutable.Map.empty) + group += (extension -> partitionedFile) + } + } + } + + // Create a partition for each group + shapefileGroups.zipWithIndex.flatMap { case ((key, group), index) => + // Check if the group has all the necessary files + val suffixes = group.keys.toSet + val hasMissingFiles = ShapefileUtils.mandatoryFileExtensions.exists { suffix => + if (!suffixes.contains(suffix)) { + logger.warn(s"Shapefile $key is missing a $suffix file") + true + } else false + } + if (!hasMissingFiles) { + Some(FilePartition(index, group.values.toArray)) + } else { + None + } + }.toArray + } +} + 
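+// Companion object holding the shared logger for ShapefileScan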
+object ShapefileScan { + val logger: Logger = LoggerFactory.getLogger(classOf[ShapefileScan]) +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala new file mode 100644 index 0000000000..e5135e381d --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileScanBuilder.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class ShapefileScanBuilder( + sparkSession: SparkSession, + fileIndex: PartitioningAwareFileIndex, + schema: StructType, + dataSchema: StructType, + options: CaseInsensitiveStringMap) + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + + override def build(): Scan = { + ShapefileScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedDataFilters, + partitionFilters, + dataFilters) + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala new file mode 100644 index 0000000000..7db6bb8d1f --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileTable.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.hadoop.fs.FileStatus +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.DbfParseUtil +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.TableCapability +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.sedona.sql.datasources.shapefile.ShapefileUtils.{baseSchema, fieldDescriptorsToSchema, mergeSchemas} +import org.apache.spark.sql.execution.datasources.v2.FileTable +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration + +import java.util.Locale +import scala.collection.JavaConverters._ + +case class ShapefileTable( + name: String, + sparkSession: SparkSession, + options: CaseInsensitiveStringMap, + paths: Seq[String], + userSpecifiedSchema: Option[StructType], + fallbackFileFormat: Class[_ <: FileFormat]) + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def formatName: String = "Shapefile" + + override def capabilities: java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) + + override def inferSchema(files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) None + else { + def isDbfFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".dbf") + } + + def isShpFile(file: FileStatus): Boolean = { + val name = file.getPath.getName.toLowerCase(Locale.ROOT) + name.endsWith(".shp") + } + + if (!files.exists(isShpFile)) None + else { + val readOptions = ShapefileReadOptions.parse(options) + val resolver = sparkSession.sessionState.conf.resolver + val dbfFiles = files.filter(isDbfFile) + if (dbfFiles.isEmpty) { + Some(baseSchema(readOptions, Some(resolver))) + } else { + val serializableConf = new SerializableConfiguration( + sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap)) + val partiallyMergedSchemas = sparkSession.sparkContext + .parallelize(dbfFiles) + .mapPartitions { iter => + val schemas = iter.map { stat => + val fs = stat.getPath.getFileSystem(serializableConf.value) + val stream = fs.open(stat.getPath) + try { + val dbfParser = new DbfParseUtil() + dbfParser.parseFileHead(stream) + val fieldDescriptors = dbfParser.getFieldDescriptors + fieldDescriptorsToSchema(fieldDescriptors.asScala.toSeq, readOptions, resolver) + } finally { + stream.close() + } + }.toSeq + mergeSchemas(schemas).iterator + } + .collect() + mergeSchemas(partiallyMergedSchemas) + } + } + } + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + ShapefileScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = null +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala new file mode 100644 index 0000000000..04ac3bdff9 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/shapefile/ShapefileUtils.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.datasources.shapefile + +import org.apache.sedona.core.formatMapper.shapefileParser.parseUtils.dbf.FieldDescriptor +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis.Resolver +import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.DateType +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.util.Locale + +object ShapefileUtils { + + /** + * shp: main file for storing shapes shx: index file for the main file dbf: attribute file cpg: + * code page file prj: projection file + */ + val shapeFileExtensions: Set[String] = Set("shp", "shx", "dbf", "cpg", "prj") + + /** + * The mandatory file extensions for a shapefile. 
We don't require the dbf file and shx file for + * being consistent with the behavior of the RDD API ShapefileReader.readToGeometryRDD + */ + val mandatoryFileExtensions: Set[String] = Set("shp") + + def mergeSchemas(schemas: Seq[StructType]): Option[StructType] = { + if (schemas.isEmpty) { + None + } else { + var mergedSchema = schemas.head + schemas.tail.foreach { schema => + try { + mergedSchema = mergeSchema(mergedSchema, schema) + } catch { + case cause: IllegalArgumentException => + throw new IllegalArgumentException( + s"Failed to merge schema $mergedSchema with $schema", + cause) + } + } + Some(mergedSchema) + } + } + + private def mergeSchema(schema1: StructType, schema2: StructType): StructType = { + // The field names are case insensitive when performing schema merging + val fieldMap = schema1.fields.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap + var newFields = schema1.fields + schema2.fields.foreach { f => + fieldMap.get(f.name.toLowerCase(Locale.ROOT)) match { + case Some(existingField) => + if (existingField.dataType != f.dataType) { + throw new IllegalArgumentException( + s"Failed to merge fields ${existingField.name} and ${f.name} because they have different data types: ${existingField.dataType} and ${f.dataType}") + } + case _ => + newFields :+= f + } + } + StructType(newFields) + } + + def fieldDescriptorsToStructFields(fieldDescriptors: Seq[FieldDescriptor]): Seq[StructField] = { + fieldDescriptors.map { desc => + val name = desc.getFieldName + val dataType = desc.getFieldType match { + case 'C' => StringType + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) LongType + else { + val precision = desc.getFieldLength + DecimalType(precision, scale) + } + case 'L' => BooleanType + case 'D' => DateType + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + StructField(name, dataType, nullable = true) + } + } + + def fieldDescriptorsToSchema(fieldDescriptors: Seq[FieldDescriptor]): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + StructType(structFields) + } + + def fieldDescriptorsToSchema( + fieldDescriptors: Seq[FieldDescriptor], + options: ShapefileReadOptions, + resolver: Resolver): StructType = { + val structFields = fieldDescriptorsToStructFields(fieldDescriptors) + val geometryFieldName = options.geometryFieldName + if (structFields.exists(f => resolver(f.name, geometryFieldName))) { + throw new IllegalArgumentException( + s"Field name $geometryFieldName is reserved for geometry but appears in non-spatial attributes. " + + "Please specify a different field name for geometry using the 'geometry.name' option.") + } + options.keyFieldName.foreach { name => + if (structFields.exists(f => resolver(f.name, name))) { + throw new IllegalArgumentException( + s"Field name $name is reserved for shape key but appears in non-spatial attributes. 
" + + "Please specify a different field name for shape key using the 'key.name' option.") + } + } + StructType(baseSchema(options, Some(resolver)).fields ++ structFields) + } + + def baseSchema(options: ShapefileReadOptions, resolver: Option[Resolver] = None): StructType = { + options.keyFieldName match { + case Some(name) => + if (resolver.exists(_(name, options.geometryFieldName))) { + throw new IllegalArgumentException(s"geometry.name and key.name cannot be the same") + } + StructType( + Seq(StructField(options.geometryFieldName, GeometryUDT), StructField(name, LongType))) + case _ => + StructType(StructField(options.geometryFieldName, GeometryUDT) :: Nil) + } + } + + def fieldValueConverter(desc: FieldDescriptor, cpg: Option[String]): Array[Byte] => Any = { + desc.getFieldType match { + case 'C' => + val encoding = cpg.getOrElse("ISO-8859-1") + if (encoding.toLowerCase(Locale.ROOT) == "utf-8") { (bytes: Array[Byte]) => + UTF8String.fromBytes(bytes).trimRight() + } else { (bytes: Array[Byte]) => + { + val str = new String(bytes, encoding) + UTF8String.fromString(str).trimRight() + } + } + case 'N' | 'F' => + val scale = desc.getFieldDecimalCount + if (scale == 0) { (bytes: Array[Byte]) => + try { + new String(bytes, StandardCharsets.ISO_8859_1).trim.toLong + } catch { + case _: Exception => null + } + } else { (bytes: Array[Byte]) => + try { + Decimal.fromString(UTF8String.fromBytes(bytes)) + } catch { + case _: Exception => null + } + } + case 'L' => + (bytes: Array[Byte]) => + if (bytes.isEmpty) null + else { + bytes.head match { + case 'T' | 't' | 'Y' | 'y' => true + case 'F' | 'f' | 'N' | 'n' => false + case _ => null + } + } + case 'D' => + (bytes: Array[Byte]) => { + try { + val dateString = new String(bytes, StandardCharsets.ISO_8859_1) + val formatter = DateTimeFormatter.BASIC_ISO_DATE + val date = LocalDate.parse(dateString, formatter) + date.toEpochDay.toInt + } catch { + case _: Exception => null + } + } + case _ => + throw new IllegalArgumentException(s"Unsupported field type ${desc.getFieldType}") + } + } +} diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala new file mode 100644 index 0000000000..b1764e6e21 --- /dev/null +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/ShapefileTests.scala @@ -0,0 +1,739 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sedona.sql + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.{DateType, DecimalType, LongType, StringType, StructField, StructType} +import org.locationtech.jts.geom.{Geometry, MultiPolygon, Point, Polygon} +import org.scalatest.BeforeAndAfterAll + +import java.io.File +import java.nio.file.Files + +class ShapefileTests extends TestBaseScala with BeforeAndAfterAll { + val temporaryLocation: String = resourceFolder + "shapefiles/tmp" + + override def beforeAll(): Unit = { + super.beforeAll() + FileUtils.deleteDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation).toPath) + } + + override def afterAll(): Unit = FileUtils.deleteDirectory(new File(temporaryLocation)) + + describe("Shapefile read tests") { + it("read gis_osm_pois_free_1") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + assert(shapefileDf.count == 12873) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4326) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + // with projection, selecting geometry and attribute fields + shapefileDf.select("geometry", "code").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Long]("code") > 0) + } + + // with projection, selecting geometry fields + shapefileDf.select("geometry").take(10).foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + } + + // with projection, selecting attribute fields + shapefileDf.select("code", "osm_id").take(10).foreach { row => + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("osm_id").nonEmpty) + } + + // with transformation + shapefileDf + .selectExpr("ST_Buffer(geometry, 0.001) AS geom", "code", "osm_id as id") + .take(10) + .foreach { row => + assert(row.getAs[Geometry]("geom").isInstanceOf[Polygon]) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("id").nonEmpty) + } + } + + it("read dbf") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/dbf") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + 
assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + + shapefileDf.collect().foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.getSRID == 0) + assert(geom.isInstanceOf[Polygon] || geom.isInstanceOf[MultiPolygon]) + assert(row.getAs[String]("STATEFP").nonEmpty) + assert(row.getAs[String]("COUNTYFP").nonEmpty) + assert(row.getAs[String]("COUNTYNS").nonEmpty) + assert(row.getAs[String]("AFFGEOID").nonEmpty) + assert(row.getAs[String]("GEOID").nonEmpty) + assert(row.getAs[String]("NAME").nonEmpty) + assert(row.getAs[String]("LSAD").nonEmpty) + assert(row.getAs[Long]("ALAND") > 0) + assert(row.getAs[Long]("AWATER") >= 0) + } + } + + it("read multipleshapefiles") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/multipleshapefiles") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "STATEFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYFP").get.dataType == StringType) + assert(schema.find(_.name == "COUNTYNS").get.dataType == StringType) + assert(schema.find(_.name == "AFFGEOID").get.dataType == StringType) + assert(schema.find(_.name == "GEOID").get.dataType == StringType) + assert(schema.find(_.name == "NAME").get.dataType == StringType) + assert(schema.find(_.name == "LSAD").get.dataType == StringType) + assert(schema.find(_.name == "ALAND").get.dataType == LongType) + assert(schema.find(_.name == "AWATER").get.dataType == LongType) + assert(schema.length == 10) + assert(shapefileDf.count() == 3220) + } + + it("read missing") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/missing") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "a").get.dataType == StringType) + assert(schema.find(_.name == "b").get.dataType == StringType) + assert(schema.find(_.name == "c").get.dataType == StringType) + assert(schema.find(_.name == "d").get.dataType == StringType) + assert(schema.find(_.name == "e").get.dataType == StringType) + assert(schema.length == 7) + val rows = shapefileDf.collect() + assert(rows.length == 3) + rows.foreach { row => + val a = row.getAs[String]("a") + val b = row.getAs[String]("b") + val c = row.getAs[String]("c") + val d = row.getAs[String]("d") + val e = row.getAs[String]("e") + if (a.isEmpty) { + assert(b == "First") + assert(c == "field") + assert(d == "is") + assert(e == "empty") + } else if (e.isEmpty) { + assert(a == "Last") + assert(b == "field") + assert(c == "is") + assert(d == "empty") + } else { + assert(a == "Are") + assert(b == "fields") + assert(c == "are") + assert(d == "not") + assert(e == "empty") + } + } + } + + it("read unsupported") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/unsupported") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "ID").get.dataType == StringType) + assert(schema.find(_.name == "LOD").get.dataType == LongType) + assert(schema.find(_.name == "Parent_ID").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 20) + var nonNullLods = 0 + rows.foreach { row => + assert(row.getAs[Geometry]("geometry") == null) + 
assert(row.getAs[String]("ID").nonEmpty) + val lodIndex = row.fieldIndex("LOD") + if (!row.isNullAt(lodIndex)) { + assert(row.getAs[Long]("LOD") == 2) + nonNullLods += 1 + } + assert(row.getAs[String]("Parent_ID").nonEmpty) + } + assert(nonNullLods == 17) + } + + it("read bad_shx") { + var shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/bad_shx") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "field_1").get.dataType == LongType) + var rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + + // Copy the .shp and .dbf files to temporary location, and read the same shapefiles without .shx + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.shp"), + new File(temporaryLocation + "/bad_shx.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/bad_shx/bad_shx.dbf"), + new File(temporaryLocation + "/bad_shx.dbf")) + shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + rows = shapefileDf.collect() + assert(rows.length == 2) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + if (geom == null) { + assert(row.getAs[Long]("field_1") == 3) + } else { + assert(geom.isInstanceOf[Point]) + assert(row.getAs[Long]("field_1") == 2) + } + } + } + + it("read contains_null_geom") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/contains_null_geom") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "fInt").get.dataType == LongType) + assert(schema.find(_.name == "fFloat").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "fString").get.dataType == StringType) + assert(schema.length == 4) + val rows = shapefileDf.collect() + assert(rows.length == 10) + rows.foreach { row => + val fInt = row.getAs[Long]("fInt") + val fFloat = row.getAs[java.math.BigDecimal]("fFloat").doubleValue() + val fString = row.getAs[String]("fString") + val geom = row.getAs[Geometry]("geometry") + if (fInt == 2 || fInt == 5) { + assert(geom == null) + } else { + assert(geom.isInstanceOf[Point]) + assert(geom.getCoordinate.x == fInt) + assert(geom.getCoordinate.y == fInt) + } + assert(Math.abs(fFloat - 3.14159 * fInt) < 1e-4) + assert(fString == s"str_$fInt") + } + } + + it("read test_datatypes") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 7) + + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = 
row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 4269) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + if (id < 10) { + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.isNullAt(row.fieldIndex("aDecimal2"))) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } else { + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } + } + } + } + + it("read with .shp path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes1.shp") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal").get.dataType.isInstanceOf[DecimalType]) + assert(schema.find(_.name == "aDate").get.dataType == DateType) + assert(schema.length == 6) + + val rows = shapefileDf.collect() + assert(rows.length == 5) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val idIndex = row.fieldIndex("id") + if (row.isNullAt(idIndex)) { + assert(row.isNullAt(row.fieldIndex("aInt"))) + assert(row.getAs[String]("aUnicode").isEmpty) + assert(row.isNullAt(row.fieldIndex("aDecimal"))) + assert(row.isNullAt(row.fieldIndex("aDate"))) + } else { + val id = row.getLong(idIndex) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal")).doubleValue() + assert((decimal * 10).toInt == id * 10 + id) + assert(row.getAs[java.sql.Date]("aDate").toString == s"202$id-0$id-0$id") + } + } + } + + it("read with glob path specified") { + val shapefileDf = sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/datatypes/datatypes2.*") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "aInt").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + assert(schema.find(_.name == "aDecimal2").get.dataType.isInstanceOf[DecimalType]) + assert(schema.length == 5) + + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read without shx") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + 
FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(geom.getSRID == 0) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + } + + it("read without dbf") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shp")) + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.length == 1) + + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + } + } + + it("read without shp") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.dbf"), + new File(temporaryLocation + "/gis_osm_pois_free_1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx"), + new File(temporaryLocation + "/gis_osm_pois_free_1.shx")) + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .count() + } + + intercept[Exception] { + sparkSession.read + .format("shapefile") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shx") + .count() + } + } + + it("read directory containing missing .shp files") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + // Missing .shp file for datatypes1 + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read partitioned directory") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part=1").toPath) + 
Files.createDirectory(new File(temporaryLocation + "/part=2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part=1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part=1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part=1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part=2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part=2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part=2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .load(temporaryLocation) + .select("part", "id", "aInt", "aUnicode", "geometry") + var rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id < 10) { + assert(row.getAs[Int]("part") == 1) + } else { + assert(row.getAs[Int]("part") == 2) + } + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + // Using partition filters + rows = shapefileDf.where("part = 2").collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + assert(row.getAs[Int]("part") == 2) + val id = row.getAs[Long]("id") + assert(id > 10) + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + + it("read with recursiveFileLookup") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + Files.createDirectory(new File(temporaryLocation + "/part1").toPath) + Files.createDirectory(new File(temporaryLocation + "/part2").toPath) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.shp"), + new File(temporaryLocation + "/part1/datatypes1.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.dbf"), + new File(temporaryLocation + "/part1/datatypes1.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes1.cpg"), + new File(temporaryLocation + "/part1/datatypes1.cpg")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/part2/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/part2/datatypes2.dbf")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.cpg"), + new File(temporaryLocation + "/part2/datatypes2.cpg")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("recursiveFileLookup", "true") + .load(temporaryLocation) + .select("id", "aInt", "aUnicode", "geometry") + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + if (id > 0) { + assert(row.getAs[String]("aUnicode") == s"测试$id") + } + } + } + + it("read with custom geometry column name") { + val 
shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "geom") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geom").get.dataType == GeometryUDT) + assert(schema.find(_.name == "osm_id").get.dataType == StringType) + assert(schema.find(_.name == "code").get.dataType == LongType) + assert(schema.find(_.name == "fclass").get.dataType == StringType) + assert(schema.find(_.name == "name").get.dataType == StringType) + assert(schema.length == 5) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geom") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.getAs[Long]("code") > 0) + assert(row.getAs[String]("fclass").nonEmpty) + assert(row.getAs[String]("name") != null) + } + + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "osm_id") + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + } + assert( + exception.getMessage.contains( + "osm_id is reserved for geometry but appears in non-spatial attributes")) + } + + it("read with shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "geometry", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "geometry").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with both custom geometry column and shape key column") { + val shapefileDf = sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "fid") + .load(resourceFolder + "shapefiles/datatypes") + .select("id", "fid", "g", "aUnicode") + val schema = shapefileDf.schema + assert(schema.find(_.name == "g").get.dataType == GeometryUDT) + assert(schema.find(_.name == "id").get.dataType == LongType) + assert(schema.find(_.name == "fid").get.dataType == LongType) + assert(schema.find(_.name == "aUnicode").get.dataType == StringType) + val rows = shapefileDf.collect() + assert(rows.length == 9) + rows.foreach { row => + val geom = row.getAs[Geometry]("g") + assert(geom.isInstanceOf[Point]) + val id = row.getAs[Long]("id") + if (id > 0) { + assert(row.getAs[Long]("fid") == id % 10) + assert(row.getAs[String]("aUnicode") == s"测试$id") + } else { + assert(row.getAs[Long]("fid") == 5) + } + } + } + + it("read with invalid shape key column") { + val exception = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + .option("key.name", "aDate") + .load(resourceFolder + "shapefiles/datatypes") + } + assert( + exception.getMessage.contains( + "aDate is reserved for shape key but appears in non-spatial attributes")) + + val exception2 = intercept[Exception] { + sparkSession.read + .format("shapefile") + .option("geometry.name", "g") + 
.option("key.name", "g") + .load(resourceFolder + "shapefiles/datatypes") + } + assert(exception2.getMessage.contains("geometry.name and key.name cannot be the same")) + } + + it("read with custom charset") { + FileUtils.cleanDirectory(new File(temporaryLocation)) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.shp"), + new File(temporaryLocation + "/datatypes2.shp")) + FileUtils.copyFile( + new File(resourceFolder + "shapefiles/datatypes/datatypes2.dbf"), + new File(temporaryLocation + "/datatypes2.dbf")) + + val shapefileDf = sparkSession.read + .format("shapefile") + .option("charset", "GB2312") + .load(temporaryLocation) + val rows = shapefileDf.collect() + assert(rows.length == 4) + rows.foreach { row => + assert(row.getAs[Geometry]("geometry").isInstanceOf[Point]) + val id = row.getAs[Long]("id") + assert(row.getAs[Long]("aInt") == id) + assert(row.getAs[String]("aUnicode") == s"测试$id") + val decimal = row.getDecimal(row.fieldIndex("aDecimal2")).doubleValue() + assert((decimal * 100).toInt == id * 100 + id) + } + } + + it("read with custom schema") { + val customSchema = StructType( + Seq( + StructField("osm_id", StringType), + StructField("code2", LongType), + StructField("geometry", GeometryUDT))) + val shapefileDf = sparkSession.read + .format("shapefile") + .schema(customSchema) + .load(resourceFolder + "shapefiles/gis_osm_pois_free_1") + assert(shapefileDf.schema == customSchema) + val rows = shapefileDf.collect() + assert(rows.length == 12873) + rows.foreach { row => + val geom = row.getAs[Geometry]("geometry") + assert(geom.isInstanceOf[Point]) + assert(row.getAs[String]("osm_id").nonEmpty) + assert(row.isNullAt(row.fieldIndex("code2"))) + } + } + } +}