Skip to content

Commit

Permalink
Merge pull request #413 from GispoCoding/404-add-mask-raster-tool
Browse files Browse the repository at this point in the history
404 add mask raster tool
  • Loading branch information
nmaarnio committed Aug 13, 2024
2 parents 16ebee0 + de19c14 commit a4ff8d1
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/raster_processing/masking.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Masking

::: eis_toolkit.raster_processing.masking
66 changes: 66 additions & 0 deletions eis_toolkit/raster_processing/masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
import rasterio
from beartype import beartype
from beartype.typing import Tuple

from eis_toolkit.raster_processing.unifying import unify_raster_grids
from eis_toolkit.utilities.checks.raster import check_raster_grids


@beartype
def mask_raster(
raster: rasterio.io.DatasetReader,
base_raster: rasterio.io.DatasetReader,
) -> Tuple[np.ndarray, rasterio.profiles.Profile]:
"""
Mask input raster using the nodata locations from base raster.
Only the first band of base raster is used to scan for nodata cells. Masking is performed to all
bands of input raster.
If input rasters have mismatching grid properties, unifies rasters before masking (uses `nearest`
resampling, unify separately first if you need control over the resampling method).
Args:
raster: The raster to be masked.
base_raster: The base raster used to determine nodata locations.
Returns:
The masked raster data.
The raster profile.
"""
raster_profile = raster.profile
base_raster_profile = base_raster.profile
profiles = [raster_profile, base_raster_profile]

# Unify if the rasters have different grids
if check_raster_grids(profiles, same_extent=True):
raster_arr = raster.read()
else:
out_rasters = unify_raster_grids(
base_raster=base_raster, rasters_to_unify=[raster], resampling_method="nearest", same_extent=True
)
raster_arr = out_rasters[1][0]

# Update profiles
raster_profile = out_rasters[1][1]
profiles[0] = raster_profile

# Extract nodata info
raster_nodata = raster_profile.get("nodata", np.nan)
base_raster_nodata = base_raster_profile.get("nodata", np.nan)

# Create mask to apply
base_raster_arr = base_raster.read(1)
base_raster_nodata_mask = (base_raster_arr == base_raster_nodata) | np.isnan(base_raster_arr)

# Apply mask to all bands of input raster
bands = raster.count
out_image = np.empty((bands, raster_profile["height"], raster_profile["width"]), dtype=raster_profile["dtype"])
for i in range(bands):
band_arr = raster_arr[i]
band_arr[base_raster_nodata_mask] = raster_nodata
out_image[i] = band_arr

out_profile = raster_profile.copy()
return out_image, out_profile
32 changes: 32 additions & 0 deletions tests/raster_processing/masking_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import tempfile

import geopandas as gpd
import numpy as np
import rasterio

from eis_toolkit.raster_processing.clipping import clip_raster
from eis_toolkit.raster_processing.masking import mask_raster
from tests.raster_processing.clip_test import polygon_path as SMALL_VECTOR_PATH
from tests.raster_processing.clip_test import raster_path as SMALL_RASTER_PATH


def test_mask_raster():
"""Test that masking raster works as intended."""
with rasterio.open(SMALL_RASTER_PATH) as raster:
geodataframe = gpd.read_file(SMALL_VECTOR_PATH)
out_image, out_meta = clip_raster(raster, geodataframe)
with tempfile.NamedTemporaryFile() as tmpfile:
with rasterio.open(tmpfile.name, "w", **out_meta) as dest:
dest.write(out_image)
with rasterio.open(tmpfile.name) as base_raster:
old_nodata_count = np.count_nonzero(raster.read(1) == raster.nodata)
out_image, out_profile = mask_raster(raster, base_raster)

new_nodata_count = np.count_nonzero(out_image == out_profile["nodata"])

# Check nodata count has increased
assert new_nodata_count > old_nodata_count
# Check that nodata exists now in identical locations in input raster and base raster
np.testing.assert_array_equal(
base_raster.read(1) == base_raster.nodata, out_image[0] == out_profile["nodata"]
)

0 comments on commit a4ff8d1

Please sign in to comment.