diff --git a/odc/geo/_xr_interop.py b/odc/geo/_xr_interop.py index bbbb32a..9c7ad0b 100644 --- a/odc/geo/_xr_interop.py +++ b/odc/geo/_xr_interop.py @@ -51,8 +51,11 @@ from .masking import ( bits_to_bool, enum_to_bool, + mask_invalid_data, + mask_clouds, + mask_ls, + mask_s2, scale_and_offset, - scale_and_offset_dataset, ) from .overlap import compute_output_geobox from .roi import roi_is_empty @@ -1065,6 +1068,8 @@ def nodata(self, value: Nodata): enum_to_bool = _wrap_op(enum_to_bool) + mask_invalid_data = _wrap_op(mask_invalid_data) + if have.rasterio: write_cog = _wrap_op(write_cog) to_cog = _wrap_op(to_cog) @@ -1105,7 +1110,15 @@ def to_rgba( ) -> xarray.DataArray: return to_rgba(self._xx, bands=bands, vmin=vmin, vmax=vmax) - scale_and_offset = _wrap_op(scale_and_offset_dataset) + scale_and_offset = _wrap_op(scale_and_offset) + + mask_invalid_data = _wrap_op(mask_invalid_data) + + mask_clouds = _wrap_op(mask_clouds) + + mask_ls = _wrap_op(mask_ls) + + mask_s2 = _wrap_op(mask_s2) ODCExtensionDs.to_rgba.__doc__ = to_rgba.__doc__ diff --git a/odc/geo/masking.py b/odc/geo/masking.py index 820c698..23a23a7 100644 --- a/odc/geo/masking.py +++ b/odc/geo/masking.py @@ -6,11 +6,67 @@ Functions around supporting cloud masking. """ +from typing import Annotated, Any, Callable, Sequence +import numpy as np from xarray import DataArray, Dataset +from enum import Enum + + +class SENTINEL2_L2A_SCL(Enum): + """ + Sentinel-2 Scene Classification Layer (SCL) values. + """ + + NO_DATA = 0 + SATURATED_OR_DEFECTIVE = 1 + DARK_AREA_PIXELS = 2 + CLOUD_SHADOWS = 3 + VEGETATION = 4 + NOT_VEGETATED = 5 + WATER = 6 + UNCLASSIFIED = 7 + CLOUD_MEDIUM_PROBABILITY = 8 + CLOUD_HIGH_PROBABILITY = 9 + THIN_CIRRUS = 10 + SNOW = 11 + + +SENTINEL2_L2A_SCALE = 0.0001 +SENTINEL2_L2A_OFFSET = -0.1 + + +class LANDSAT_C2L2_PIXEL_QA(Enum): + """ + Landsat Collection 2 Surface Reflectance Pixel Quality values. + """ + + NO_DATA = 0 + DILATED_CLOUD = 1 + CIRRUS = 2 + CLOUD = 3 + CLOUD_SHADOW = 4 + SNOW = 5 + CLEAR = 6 + WATER = 7 + # Not sure how to implement these yet... + # CLOUD_CONFIDENCE = [8, 9] + # CLOUD_SHADOW_CONFIDENCE = [10, 11] + # SNOW_ICE_CONFIDENCE = [12, 13] + # CIRRUS_CONFIDENCE = [14, 15] + + +LANDSAT_C2L2_SCALE = 0.0000275 +LANDSAT_C2L2_OFFSET = -0.2 + +# TODO: QA_RADSAT and QA_AEROSOL for Landsat Collection 2 Surface Reflectance + def bits_to_bool( - xx: DataArray, bits: list[int] | None, bitflags: int | None, invert: bool = False + xx: DataArray, + bits: Sequence[int] | None, + bitflags: int | None, + invert: bool = False, ) -> DataArray: """ Convert integer array into boolean array using bitmasks. @@ -43,7 +99,9 @@ def bits_to_bool( return mask -def enum_to_bool(xx: DataArray, values: list, invert: bool = False) -> DataArray: +def enum_to_bool( + xx: DataArray, values: Sequence[Any], invert: bool = False +) -> DataArray: """ Convert array into boolean array using a list of invalid values. @@ -62,11 +120,11 @@ def enum_to_bool(xx: DataArray, values: list, invert: bool = False) -> DataArray def scale_and_offset( - xx: DataArray, + xx: DataArray | Dataset, scale: float | None, offset: float | None, - ignore_missing: bool = False, -) -> DataArray: + clip: Annotated[Sequence[int | float], 2] | None = None, +) -> DataArray | Dataset: """ Apply scale and offset to the DataArray. Leave scale and offset blank to use the values from the DataArray's attrs. @@ -77,7 +135,14 @@ def scale_and_offset( :return: DataArray with scaled and offset values """ - # Scales and offsets is used by GDAL. + # For the Dataset case, we do this recursively for all variables. + if type(xx) is Dataset: + for var in xx.data_vars: + xx[var] = scale_and_offset(xx[var], scale, offset, clip=clip) + + return xx + + # "Scales" and "offsets" is used by GDAL. if scale is None: scale = xx.attrs.get("scales") @@ -91,31 +156,186 @@ def scale_and_offset( if offset is None and scale is not None: offset = 0.0 + # Store the nodata values to apply to the result + nodata = xx.odc.nodata + + # Stash the attributes + attrs = {k: v for k, v in xx.attrs.items()} + + if nodata is not None: + nodata_mask = xx == nodata + + # If both are missing, we can just return the original array. if scale is not None and offset is not None: - xx = xx * scale + offset - else: - if not ignore_missing: - raise ValueError( - "Scale and offset not provided and not found in attrs.scales and attrs.offset" - ) + xx = (xx * scale) + offset + + if clip is not None: + assert len(clip) == 2, "Clip must be a list of two values" + xx = xx.clip(clip[0], clip[1]) + + # Re-attach nodata + if nodata is not None: + xx = xx.where(~nodata_mask, other=nodata) + + xx.attrs = attrs # Not sure if this is required + + return xx + + +def mask_invalid_data( + xx: DataArray | Dataset, + nodata: int | float | None = None, + skip_bands: Sequence[str] = [], +) -> DataArray | Dataset: + """ + Mask out invalid data values. + + :param xx: DataArray + :return: DataArray with invalid data values converted to np.nan. Note this will change the dtype to float. + """ + if type(xx) is Dataset: + for var in xx.data_vars: + if var not in skip_bands: + xx[var] = mask_invalid_data(xx[var], nodata) + return xx + + if nodata is None: + nodata = xx.odc.nodata + + assert nodata is not None, "Nodata value must be provided or available in attrs" + + xx = xx.where(xx != nodata) + xx.odc.nodata = np.nan return xx -def scale_and_offset_dataset( - xx: Dataset, scale: float | None, offset: float | None +def mask_clouds( + xx: Dataset, + qa_name: str, + scale: float, + offset: float, + clip: tuple, + mask_func: Callable = enum_to_bool, # Pass the function for enum-based masks (bits_to_bool or enum_to_bool) + mask_func_args: dict = {}, # Pass the arguments for the mask function + apply_mask: bool = True, + keep_qa: bool = False, + return_mask: bool = False, ) -> Dataset: """ - Apply scale and offset to the Dataset. Leave scale and offset blank to use - the values from each DataArray's attrs. + General cloud masking function for both Landsat and Sentinel-2 products. - :param xx: Dataset with integer values - :param scale: Scale factor - :param offset: Offset - :return: Dataset with scaled and offset values + :param xx: Dataset or DataArray + :param qa_name: QA band to use for masking + :param mask_classes: List of mask class values (e.g., cloud, cloud shadow) + :param scale: Scale value for the dataset + :param offset: Offset value for the dataset + :param clip: Clip range for the data + :param includ_cirrus: Whether to include cirrus in the mask + :param apply_mask: Apply the cloud mask to the data, erasing data where clouds are present + :param keep_qa: Keep the QA band in the output + :param return_mask: Return the mask as a variable called "mask" + :param enum_to_bool_func: Function to convert bit values to boolean mask (either bits_to_bool or enum_to_bool) + :return: Dataset or DataArray with invalid data values converted to np.nan. Note this will change the dtype to float. """ + attrs = {k: v for k, v in xx.attrs.items()} + + # Retrieve the QA band + try: + qa = xx[qa_name] + except KeyError: + raise KeyError(f"QA band '{qa_name}' not found in dataset.") + + # Drop the QA band and apply other preprocessing steps + xx = xx.drop_vars(qa_name) + xx = mask_invalid_data(xx) + xx = scale_and_offset(xx, scale=scale, offset=offset, clip=clip) + # Generate the mask + mask = mask_func(qa, **mask_func_args) + + # Apply the mask if required + if apply_mask: + xx = xx.where(~mask) + + # Set 'nodata' to np.nan for all variables for var in xx.data_vars: - xx[var] = scale_and_offset(xx[var], scale, offset, ignore_missing=True) + xx[var].odc.nodata = np.nan - return xx + # Optionally keep the QA band + if keep_qa: + xx[qa_name] = qa + + # Optionally return the mask + if return_mask: + xx["mask"] = mask + + xx.attrs = attrs + + return xx # type: ignore + + +def mask_ls( + xx: Dataset, + qa_name: str = "pixel_qa", + include_cirrus: bool = False, + apply_mask: bool = True, + keep_qa: bool = False, + return_mask: bool = False, +) -> Dataset: + """ + Perform cloud masking for Landsat Collection 2 products. + """ + mask_bits = [ + LANDSAT_C2L2_PIXEL_QA.CLOUD.value, + LANDSAT_C2L2_PIXEL_QA.CLOUD_SHADOW.value, + ] + if include_cirrus: + mask_bits.append(LANDSAT_C2L2_PIXEL_QA.CIRRUS.value) + + return mask_clouds( + xx=xx, + qa_name=qa_name, + scale=LANDSAT_C2L2_SCALE, + offset=LANDSAT_C2L2_OFFSET, + clip=(0.0, 1.0), + mask_func=bits_to_bool, + mask_func_args={"bits": mask_bits}, + apply_mask=apply_mask, + keep_qa=keep_qa, + return_mask=return_mask, + ) + + +def mask_s2( + xx: Dataset, + qa_name: str = "scl", + include_cirrus: bool = False, + apply_mask: bool = True, + keep_qa: bool = False, + return_mask: bool = False, +) -> Dataset: + """ + Perform cloud masking for Sentinel-2 L2A products. + """ + mask_values = [ + SENTINEL2_L2A_SCL.SATURATED_OR_DEFECTIVE.value, + SENTINEL2_L2A_SCL.CLOUD_MEDIUM_PROBABILITY.value, + SENTINEL2_L2A_SCL.CLOUD_HIGH_PROBABILITY.value, + SENTINEL2_L2A_SCL.CLOUD_SHADOWS.value, + ] + if include_cirrus: + mask_values.append(SENTINEL2_L2A_SCL.THIN_CIRRUS.value) + + return mask_clouds( + xx=xx, + qa_name=qa_name, + scale=SENTINEL2_L2A_SCALE, + offset=SENTINEL2_L2A_OFFSET, + mask_func=enum_to_bool, + mask_func_args={"values": mask_values}, + clip=(0.0, 1.0), + apply_mask=apply_mask, + keep_qa=keep_qa, + return_mask=return_mask, + ) diff --git a/tests/test_masking.py b/tests/test_masking.py index 920add6..9d1ce71 100644 --- a/tests/test_masking.py +++ b/tests/test_masking.py @@ -1,4 +1,10 @@ -from odc.geo.masking import bits_to_bool, enum_to_bool, scale_and_offset +import numpy as np +from odc.geo.masking import ( + bits_to_bool, + enum_to_bool, + scale_and_offset, + mask_invalid_data, +) from xarray import DataArray @@ -12,6 +18,9 @@ # values set to 3 (shadow), 9 (high confidence cloud). xx_values = DataArray([[3, 9], [3, 0]], dims=("y", "x")) +# Array with some zeros +xx_with_nodata = DataArray([[1, 2], [0, 0]], dims=("y", "x"), attrs={"nodata": 0}) + # Test bits_to_bool def test_bits_to_bool(): @@ -50,3 +59,12 @@ def test_scale_and_offset(): mask = scale_and_offset(xx_values, scale=2.0, offset=1.0) assert mask.equals(DataArray([[7, 19], [7, 1]], dims=("y", "x"))) + + +# Test mask_invalid +def test_mask_invalid_data(): + mask = mask_invalid_data(xx_with_nodata) + assert mask.equals(DataArray([[1.0, 2.0], [np.nan, np.nan]], dims=("y", "x"))) + + mask = mask_invalid_data(xx_with_nodata, nodata=1) + assert mask.equals(DataArray([[np.nan, 2], [0, 0]], dims=("y", "x")))