Skip to content

Commit

Permalink
fix(raster_processing): fix unsupported masking case by combining par…
Browse files Browse the repository at this point in the history
…ameters
  • Loading branch information
nmaarnio committed Aug 13, 2024
1 parent 9e56ea7 commit 6efc64a
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 30 deletions.
15 changes: 11 additions & 4 deletions eis_toolkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ class ThresholdCriteria(str, Enum):
outside = "outside"


class MaskingMode(str, Enum):
"""Masking modes for raster unification."""

extents = "extents"
full = "full"
none = "none"


INPUT_FILE_OPTION = Annotated[
Path,
typer.Option(
Expand Down Expand Up @@ -1342,8 +1350,7 @@ def unify_rasters_cli(
base_raster: INPUT_FILE_OPTION,
output_directory: OUTPUT_DIR_OPTION,
resampling_method: Annotated[ResamplingMethods, typer.Option(case_sensitive=False)] = ResamplingMethods.nearest,
unify_extents: bool = True,
mask_nodata: bool = False,
masking: Annotated[MaskingMode, typer.Option(case_sensitive=False)] = MaskingMode.extents,
):
"""Unify rasters to match the base raster."""
from eis_toolkit.raster_processing.unifying import unify_raster_grids
Expand All @@ -1354,12 +1361,12 @@ def unify_rasters_cli(
to_unify = [rasterio.open(rstr) for rstr in rasters_to_unify] # Open all rasters to be unified
typer.echo("Progress: 25%")

masking_param = get_enum_values(masking)
unified = unify_raster_grids(
base_raster=raster,
rasters_to_unify=to_unify,
resampling_method=get_enum_values(resampling_method),
unify_extents=unify_extents,
mask_nodata=mask_nodata,
masking=None if masking_param == "none" else masking_param,
)
[rstr.close() for rstr in to_unify] # Close all rasters
typer.echo("Progress: 75%")
Expand Down
6 changes: 1 addition & 5 deletions eis_toolkit/raster_processing/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ def mask_raster(
raster_arr = raster.read()
else:
out_rasters = unify_raster_grids(
base_raster=base_raster,
rasters_to_unify=[raster],
resampling_method="nearest",
unify_extents=True,
mask_nodata=False,
base_raster=base_raster, rasters_to_unify=[raster], resampling_method="nearest", masking="extents"
)
raster_arr = out_rasters[1][0]

Expand Down
21 changes: 9 additions & 12 deletions eis_toolkit/raster_processing/unifying.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def _unify_raster_grids(
base_raster: rasterio.io.DatasetReader,
rasters_to_unify: Sequence[rasterio.io.DatasetReader],
resampling_method: Resampling,
unify_extens: bool,
mask_nodata: bool,
masking: Optional[Literal["extents", "extents_and_nodata"]],
) -> List[Tuple[np.ndarray, Profile]]:

dst_crs = base_raster.crs
Expand All @@ -89,7 +88,7 @@ def _unify_raster_grids(

# If we unify without clipping, things are more complicated and we need to
# calculate corner coordinates, width and height, and snap the grid to nearest corner
if not unify_extens:
if not masking:
dst_transform, dst_width, dst_height = _calculate_snapped_grid(raster, dst_crs, dst_resolution)

dst_array = np.empty((base_raster.count, dst_height, dst_width))
Expand All @@ -115,7 +114,7 @@ def _unify_raster_grids(
resampling=resampling_method,
)[0]

if mask_nodata:
if masking == "full":
_mask_nodata(out_image, nodata, base_raster_arr, base_raster_profile)

out_profile.update({"transform": dst_transform, "width": dst_width, "height": dst_height, "crs": dst_crs})
Expand All @@ -130,8 +129,7 @@ def unify_raster_grids(
base_raster: rasterio.io.DatasetReader,
rasters_to_unify: Sequence[rasterio.io.DatasetReader],
resampling_method: Literal["nearest", "bilinear", "cubic", "average", "gauss", "max", "min"] = "nearest",
unify_extents: bool = True,
mask_nodata: bool = False,
masking: Optional[Literal["extents", "full"]] = "extents",
) -> List[Tuple[np.ndarray, Profile]]:
"""Unifies given rasters with the base raster.
Expand All @@ -147,11 +145,10 @@ def unify_raster_grids(
rasters_to_unify: Rasters to be unified with the base raster.
resampling_method: Resampling method. Most suitable method depends on the dataset and context.
`nearest`, `bilinear` and `cubic` are some common choices. This parameter defaults to `nearest`.
unify_extents: Controls if the bounds of rasters to-be-unified should be modified. If True, will clip
larger rasters and expand smaller rasters (with nodata) to match extents of the base raster.
Defaults to True.
mask_nodata: Whether nodata cells should be copied from the baster raster. If True, copies nodata pixels
to all bands of all rasters to-be-unified. Defaults to False.
masking: Controls if and how masking should be handled. If `extents`, the bounds of rasters to-be-unified
are matched with the base raster. Larger rasters are clipped and smaller rasters expanded (with nodata).
If `full`, copies nodata pixel locations from the base raster additionally. If None,
extents are not matched and nodata not copied. Defaults to `extents`.
Returns:
List of unified rasters' data and profiles. First element is the base raster.
Expand All @@ -163,5 +160,5 @@ def unify_raster_grids(
raise InvalidParameterValueException("Rasters to unify is empty.")

method = RESAMPLE_METHOD_MAP[resampling_method]
out_rasters = _unify_raster_grids(base_raster, rasters_to_unify, method, unify_extents, mask_nodata)
out_rasters = _unify_raster_grids(base_raster, rasters_to_unify, method, masking)
return out_rasters
12 changes: 3 additions & 9 deletions tests/raster_processing/unify_raster_grids_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def test_unify_raster_grids():

with rasterio.open(raster_to_unify_path_1) as raster_to_unify:
with rasterio.open(base_raster_path_1) as base_raster:
out_rasters = unify_raster_grids(
base_raster, [raster_to_unify], "nearest", unify_extents=False, mask_nodata=False
)
out_rasters = unify_raster_grids(base_raster, [raster_to_unify], "nearest", masking=None)
out_image, out_meta = out_rasters[1]

assert len(out_rasters) == 2
Expand Down Expand Up @@ -62,9 +60,7 @@ def test_unify_raster_grids_extent():

with rasterio.open(raster_to_unify_path_2) as raster_to_unify:
with rasterio.open(base_raster_path_2) as base_raster:
out_rasters = unify_raster_grids(
base_raster, [raster_to_unify], "bilinear", unify_extents=True, mask_nodata=False
)
out_rasters = unify_raster_grids(base_raster, [raster_to_unify], "bilinear", masking="extents")
out_image, out_meta = out_rasters[1]

assert len(out_rasters) == 2
Expand Down Expand Up @@ -94,9 +90,7 @@ def test_unify_raster_grids_full_masking():

with rasterio.open(raster_to_unify_path_1) as raster_to_unify:
with rasterio.open(base_raster_path_3) as base_raster:
out_rasters = unify_raster_grids(
base_raster, [raster_to_unify], "bilinear", unify_extents=True, mask_nodata=True
)
out_rasters = unify_raster_grids(base_raster, [raster_to_unify], "bilinear", masking="full")
out_image, out_profile = out_rasters[1]

assert len(out_rasters) == 2
Expand Down

0 comments on commit 6efc64a

Please sign in to comment.