Skip to content

Commit

Permalink
updated RST_PixelCount to take extra params
Browse files Browse the repository at this point in the history
  • Loading branch information
sllynn committed Sep 24, 2024
1 parent fc7b527 commit 9a0b087
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 30 deletions.
26 changes: 23 additions & 3 deletions docs/source/api/raster-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1737,16 +1737,36 @@ rst_numbands
+---------------------+

rst_pixelcount
***************
**************

.. function:: rst_pixelcount(tile)
.. function:: rst_pixelcount(tile, count_nodata, count_all)

Returns an array containing valid pixel count values for each band.
Returns an array containing pixel count values for each band; default excludes mask and nodata pixels.

:param tile: A column containing the raster tile.
:type tile: Column (RasterTileType)
:param count_nodata: A column to specify whether to count nodata pixels.
:type count_nodata: Column (BooleanType)
:param count_all: A column to specify whether to count all pixels.
:type count_all: Column (BooleanType)
:rtype: Column: ArrayType(LongType)

.. note::

Notes:

If pixel value is noData or mask value is 0.0, the pixel is not counted by default.

:code:`count_nodata`
- This is an optional param.
- if specified as true, include the noData (not mask) pixels in the count (default is false).

:code:`count_all`
- This is an optional param; as a positional arg, must also pass :code:`count_nodata`
(value of :code:`count_nodata` is ignored).
- if specified as true, simply return bandX * bandY in the count (default is false).
..
:example:

.. tabs::
Expand Down
19 changes: 16 additions & 3 deletions python/mosaic/api/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,21 +755,34 @@ def rst_numbands(raster_tile: ColumnOrName) -> Column:
)


def rst_pixelcount(raster_tile: ColumnOrName) -> Column:
def rst_pixelcount(raster_tile: ColumnOrName, count_nodata: Any = False, count_all: Any = False) -> Column:
"""
Parameters
----------
raster_tile : Column (RasterTileType)
Mosaic raster tile struct column.
count_nodata : Column(BooleanType)
If false do not include noData pixels in count (default is false).
count_all : Column(BooleanType)
If true, simply return bandX * bandY (default is false).
Returns
-------
Column (ArrayType(LongType))
Array containing valid pixel count values for each band.
"""

if type(count_nodata) == bool:
count_nodata = lit(count_nodata)

if type(count_all) == bool:
count_all = lit(count_all)

return config.mosaic_context.invoke_function(
"rst_pixelcount", pyspark_to_java_column(raster_tile)
"rst_pixelcount",
pyspark_to_java_column(raster_tile),
pyspark_to_java_column(count_nodata),
pyspark_to_java_column(count_all),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,25 +242,36 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) {
/**
* Counts the number of pixels in the band. The mask is used to determine
* if a pixel is valid. If pixel value is noData or mask value is 0.0, the
* pixel is not counted.
*
* pixel is not counted by default.
* @param countNoData
* If specified as true, include the noData (default is false).
* @param countAll
* If specified as true, simply return bandX * bandY (default is false).
* @return
* Returns the band's pixel count.
*/
def pixelCount: Int = {
val line = Array.ofDim[Double](band.GetXSize())
val maskLine = Array.ofDim[Double](band.GetXSize())
var count = 0
for (y <- 0 until band.GetYSize()) {
band.ReadRaster(0, y, band.GetXSize(), 1, line)
val maskRead = band.GetMaskBand().ReadRaster(0, y, band.GetXSize(), 1, maskLine)
if (maskRead != gdalconstConstants.CE_None) {
count = count + line.count(_ != noDataValue)
} else {
count = count + line.zip(maskLine).count { case (pixel, mask) => pixel != noDataValue && mask != 0.0 }
def pixelCount(countNoData: Boolean = false, countAll: Boolean = false): Int = {
if (countAll) {
// all pixels returned
band.GetXSize() * band.GetYSize()
} else {
// nodata not included (default)
val line = Array.ofDim[Double](band.GetXSize())
var count = 0
for (y <- 0 until band.GetYSize()) {
band.ReadRaster(0, y, band.GetXSize(), 1, line)
val maskLine = Array.ofDim[Double](band.GetXSize())
val maskRead = band.GetMaskBand().ReadRaster(0, y, band.GetXSize(), 1, maskLine)
if (maskRead != gdalconstConstants.CE_None) {
count = count + line.count(pixel => countNoData || pixel != noDataValue)
} else {
count = count + line.zip(maskLine).count {
case (pixel, mask) => mask != 0.0 && (countNoData || pixel != noDataValue)
}
}
}
count
}
count
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster

import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression
import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
Expand All @@ -11,17 +11,31 @@ import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._

/** Returns an array containing valid pixel count values for each band. */
case class RST_PixelCount(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig)
extends RasterExpression[RST_PixelCount](rasterExpr, returnsRaster = false, expressionConfig)
case class RST_PixelCount(
rasterExpr: Expression,
noDataExpr: Expression,
allExpr: Expression,
expressionConfig: MosaicExpressionConfig)
extends Raster2ArgExpression[RST_PixelCount](rasterExpr, noDataExpr, allExpr, returnsRaster = false, expressionConfig)
with NullIntolerant
with CodegenFallback {

override def dataType: DataType = ArrayType(LongType)

/** Returns an array containing valid pixel count values for each band. */
override def rasterTransform(tile: MosaicRasterTile): Any = {
/**
* Returns an array containing valid pixel count values for each band.
* - default is to exclude nodata and mask pixels.
* - if countNoData specified as true, include the noData (not mask) pixels in the count (default is false).
* - if countAll specified as true, simply return bandX * bandY in the count (default is false). countAll ignores
* countNodData
*/
override def rasterTransform(tile: MosaicRasterTile, arg1: Any, arg2: Any): Any = {
val bandCount = tile.raster.raster.GetRasterCount()
val pixelCount = (1 to bandCount).map(tile.raster.getBand(_).pixelCount)
val countNoData = arg1.asInstanceOf[Boolean]
val countAll = arg2.asInstanceOf[Boolean]
val pixelCount = (1 to bandCount).map(
tile.raster.getBand(_).pixelCount(countNoData, countAll)
)
ArrayData.toArrayData(pixelCount.toArray)
}

Expand All @@ -32,7 +46,7 @@ object RST_PixelCount extends WithExpressionInfo {

override def name: String = "rst_pixelcount"

override def usage: String = "_FUNC_(expr1) - Returns an array containing valid pixel count values for each band."
override def usage: String = "_FUNC_(expr1) - Returns an array containing pixel count values for each band (default excludes nodata and mask)."

override def example: String =
"""
Expand All @@ -42,7 +56,7 @@ object RST_PixelCount extends WithExpressionInfo {
| """.stripMargin

override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = {
GenericExpressionFactory.getBaseBuilder[RST_PixelCount](1, expressionConfig)
GenericExpressionFactory.getBaseBuilder[RST_PixelCount](3, expressionConfig)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
def rst_convolve(raster: Column, kernel: Column): Column = ColumnAdapter(RST_Convolve(raster.expr, kernel.expr, expressionConfig))
def rst_dtmfromgeoms(pointsArray: Column, linesArray: Column, mergeTol: Column, snapTol: Column, origin: Column, xWidth: Column, yWidth: Column, xSize: Column, ySize: Column): Column =
ColumnAdapter(RST_DTMFromGeoms(pointsArray.expr, linesArray.expr, mergeTol.expr, snapTol.expr, origin.expr, xWidth.expr, yWidth.expr, xSize.expr, ySize.expr, expressionConfig))
def rst_pixelcount(raster: Column): Column = ColumnAdapter(RST_PixelCount(raster.expr, expressionConfig))
def rst_pixelcount(raster: Column): Column = ColumnAdapter(RST_PixelCount(raster.expr, lit(false).expr, lit(false).expr, expressionConfig))
def rst_pixelcount(raster: Column, countNoData: Column): Column = ColumnAdapter(RST_PixelCount(raster.expr, countNoData.expr, lit(false).expr, expressionConfig))
def rst_pixelcount(raster: Column, countNoData: Column, countAll: Column): Column = ColumnAdapter(RST_PixelCount(raster.expr, countNoData.expr, countAll.expr, expressionConfig))
def rst_combineavg(rasterArray: Column): Column = ColumnAdapter(RST_CombineAvg(rasterArray.expr, expressionConfig))
def rst_derivedband(raster: Column, pythonFunc: Column, funcName: Column): Column =
ColumnAdapter(RST_DerivedBand(raster.expr, pythonFunc.expr, funcName.expr, expressionConfig))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ trait RST_PixelCountBehaviors extends QueryTest {
.withColumn("tile", rst_tessellate($"tile", lit(3)))
.createOrReplaceTempView("source")

// TODO: modified to 3 args... should this be revisited?
noException should be thrownBy spark.sql("""
|select rst_pixelcount(tile) from source
|select rst_pixelcount(tile,false,false) from source
|""".stripMargin)

noException should be thrownBy rastersInMemory
Expand Down

0 comments on commit 9a0b087

Please sign in to comment.