Skip to content

Commit

Permalink
refactor: reproject dtype and nodata handling
Browse files Browse the repository at this point in the history
- allow dtype change as part of reprojection
- more consistent handling of destination fill
  value defaults between different implementations
  • Loading branch information
Kirill888 committed Jun 16, 2024
1 parent b6b9ce3 commit d9bd9c9
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 38 deletions.
38 changes: 20 additions & 18 deletions odc/geo/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,13 @@
from ._blocks import BlockAssembler
from .gcp import GCPGeoBox
from .geobox import GeoBox, GeoboxTiles
from .warp import Nodata, Resampling, _rio_reproject, resampling_s2rio


def resolve_fill_value(dst_nodata, src_nodata, dtype):
    """
    Pick the value used to pre-fill a destination array.

    The first explicitly supplied nodata value wins, destination before
    source. When neither is given, floating point outputs are filled with
    ``nan`` and every other dtype with ``0``.

    :param dst_nodata: Nodata requested for the destination, or ``None``.
    :param src_nodata: Nodata of the source, or ``None``.
    :param dtype: Anything accepted by :class:`numpy.dtype`.
    :returns: Fill value as a numpy scalar of the requested ``dtype``.
    """
    scalar = np.dtype(dtype).type

    # explicit values take priority, in order: destination, then source
    for candidate in (dst_nodata, src_nodata):
        if candidate is not None:
            return scalar(candidate)

    # nothing supplied: fall back to a dtype-appropriate default
    return scalar("nan") if np.issubdtype(scalar, np.floating) else scalar(0)
from .warp import (
Nodata,
Resampling,
_rio_reproject,
resampling_s2rio,
resolve_fill_value,
)


def _do_chunked_reproject(
Expand Down Expand Up @@ -50,7 +44,11 @@ def _do_chunked_reproject(
dtype = ba.dtype

dst_shape = ba.with_yx(ba.shape, dst_gbox.shape)
dst = np.zeros(dst_shape, dtype=dtype)
dst = np.full(
dst_shape,
resolve_fill_value(dst_nodata, src_nodata, dtype),
dtype=dtype,
)

for src_roi in ba.planes_yx():
src = ba.extract(src_nodata, dtype=dtype, casting=casting, roi=src_roi)
Expand Down Expand Up @@ -79,6 +77,7 @@ def dask_rio_reproject(
dst_nodata: Nodata = None,
ydim: int = 0,
chunks: Optional[Tuple[int, int]] = None,
dtype=None,
**kwargs,
) -> da.Array:
# pylint: disable=too-many-arguments, too-many-locals
Expand All @@ -92,6 +91,9 @@ def dask_rio_reproject(
def with_yx(a, yx):
return (*a[:ydim], *yx, *a[ydim + 2 :])

if dtype is None:
dtype = src.dtype

name: str = kwargs.pop("name", "reproject")

gbt_src = GeoboxTiles(s_gbox, src.chunks[ydim : ydim + 2])
Expand All @@ -100,6 +102,7 @@ def with_yx(a, yx):

dst_shape = with_yx(src.shape, d_gbox.shape.yx)
dst_chunks: Tuple[Tuple[int, ...], ...] = with_yx(src.chunks, gbt_dst.chunks)
fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype)

tk = uuid4().hex
name = f"{name}-{tk}"
Expand All @@ -111,15 +114,14 @@ def with_yx(a, yx):
gbt_src,
gbt_dst,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
dst_nodata=fill_value,
axis=ydim,
resampling=resampling,
dtype=dtype,
**kwargs,
)
src_block_keys = src.__dask_keys__()

fill_value = resolve_fill_value(dst_nodata, src_nodata, src.dtype)

def _src(idx):
a = src_block_keys
for i in idx:
Expand All @@ -141,4 +143,4 @@ def _src(idx):

dsk = HighLevelGraph.from_collections(name, dsk, dependencies=(src,))

return da.Array(dsk, name, chunks=dst_chunks, dtype=src.dtype, shape=dst_shape)
return da.Array(dsk, name, chunks=dst_chunks, dtype=dtype, shape=dst_shape)
43 changes: 29 additions & 14 deletions odc/geo/_xr_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .overlap import compute_output_geobox
from .roi import roi_is_empty
from .types import Resolution, SomeResolution, SomeShape, xy_
from .warp import resolve_fill_value

# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-lines
Expand All @@ -63,6 +64,8 @@

# these attributes are pruned during reproject
SPATIAL_ATTRIBUTES = ("crs", "crs_wkt", "grid_mapping", "gcps", "epsg")
# attribute names that carry the nodata value; rewritten with the resolved
# fill value after reproject when that value is finite
NODATA_ATTRIBUTES = ("nodata", "_FillValue")
# union of the two groups above: attributes not copied verbatim from the
# source array onto the reprojected output
REPROJECT_SKIP_ATTRS: set[str] = set(SPATIAL_ATTRIBUTES + NODATA_ATTRIBUTES)

# dimensions with these names are considered spatial
STANDARD_SPATIAL_DIMS = [
Expand Down Expand Up @@ -654,6 +657,7 @@ def xr_reproject(
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dtype=None,
resolution: Union[SomeResolution, Literal["auto", "fit", "same"]] = "auto",
shape: Union[SomeShape, int, None] = None,
tight: bool = False,
Expand Down Expand Up @@ -728,10 +732,10 @@ def xr_reproject(
}
if isinstance(src, xarray.DataArray):
return _xr_reproject_da(
src, how, resampling=resampling, dst_nodata=dst_nodata, **kw
src, how, resampling=resampling, dst_nodata=dst_nodata, dtype=dtype, **kw
)
return _xr_reproject_ds(
src, how, resampling=resampling, dst_nodata=dst_nodata, **kw
src, how, resampling=resampling, dst_nodata=dst_nodata, dtype=dtype, **kw
)


Expand All @@ -750,6 +754,7 @@ def _xr_reproject_ds(
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dtype=None,
**kw,
) -> xarray.Dataset:
assert isinstance(src, xarray.Dataset)
Expand All @@ -776,7 +781,12 @@ def _maybe_reproject(dv: xarray.DataArray):
dv = dv.drop_vars(strip_coords)
return dv
return _xr_reproject_da(
dv, how=dst_geobox, resampling=resampling, dst_nodata=dst_nodata, **kw
dv,
how=dst_geobox,
resampling=resampling,
dst_nodata=dst_nodata,
dtype=dtype,
**kw,
)

return src.map(_maybe_reproject)
Expand All @@ -788,6 +798,7 @@ def _xr_reproject_da(
*,
resampling: Union[str, int] = "nearest",
dst_nodata: Optional[float] = None,
dtype=None,
**kw,
) -> xarray.DataArray:
# pylint: disable=too-many-locals
Expand All @@ -809,6 +820,9 @@ def _xr_reproject_da(
else:
dst_geobox = src.odc.output_geobox(how, **kw_gbox)

if dtype is None:
dtype = src.dtype

# compute destination shape by replacing spatial dimensions shape
ydim = src.odc.ydim
assert ydim + 1 == src.odc.xdim
Expand All @@ -817,8 +831,8 @@ def _xr_reproject_da(
src_nodata = kw.pop("src_nodata", None)
if src_nodata is None:
src_nodata = src.odc.nodata
if dst_nodata is None:
dst_nodata = src_nodata

fill_value = resolve_fill_value(dst_nodata, src_nodata, dtype)

if is_dask_collection(src):
from ._dask import dask_rio_reproject
Expand All @@ -829,12 +843,13 @@ def _xr_reproject_da(
dst_geobox,
resampling=resampling,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
dst_nodata=fill_value,
ydim=ydim,
dtype=dtype,
**kw,
)
else:
dst = numpy.empty(dst_shape, dtype=src.dtype)
dst = numpy.full(dst_shape, fill_value, dtype=dtype)

dst = rio_reproject(
src.values,
Expand All @@ -843,17 +858,17 @@ def _xr_reproject_da(
dst_geobox,
resampling=resampling,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
dst_nodata=fill_value,
ydim=ydim,
dtype=dtype,
**kw,
)

attrs = {k: v for k, v in src.attrs.items() if k not in SPATIAL_ATTRIBUTES}
if dst_nodata is None:
attrs.pop("nodata", None)
attrs.pop("_FillValue", None)
else:
attrs.update(nodata=maybe_int(dst_nodata, 1e-6))
attrs = {k: v for k, v in src.attrs.items() if k not in REPROJECT_SKIP_ATTRS}
if numpy.isfinite(fill_value) and (
dst_nodata is not None or src_nodata is not None
):
attrs.update({k: maybe_int(float(fill_value), 1e-6) for k in NODATA_ATTRIBUTES})

# new set of coords (replace x,y dims)
# discard all coords that reference spatial dimensions
Expand Down
32 changes: 26 additions & 6 deletions odc/geo/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
# Type alias for a nodata value: an int, a float, or None when unset.
Nodata = Optional[Union[int, float]]
# NOTE(review): usage is not visible in this excerpt; presumably a
# placeholder CRS handed to rasterio when warping -- confirm.
_WRP_CRS = "epsg:3857"

# Names exported by ``from odc.geo.warp import *``.
__all__ = [
"resampling_s2rio",
"is_resampling_nn",
"resolve_fill_value",
"warp_affine",
"warp_affine_rio",
"rio_reproject",
]


def resampling_s2rio(name: str) -> rasterio.warp.Resampling:
"""
Expand All @@ -38,6 +47,18 @@ def is_resampling_nn(resampling: Resampling) -> bool:
return resampling == rasterio.warp.Resampling.nearest


def resolve_fill_value(dst_nodata, src_nodata, dtype):
    """
    Work out the fill value for destination pixels of a given ``dtype``.

    Order of precedence:

    1. ``dst_nodata`` when supplied
    2. ``nan`` when ``dtype`` is a floating point type
    3. ``src_nodata`` when supplied
    4. ``0``

    :param dst_nodata: Nodata requested for the destination, or ``None``.
    :param src_nodata: Nodata of the source, or ``None``.
    :param dtype: Anything accepted by :class:`numpy.dtype`.
    :returns: Fill value as a numpy scalar of the requested ``dtype``.
    """
    scalar = np.dtype(dtype).type

    if dst_nodata is not None:
        return scalar(dst_nodata)

    # floats default to nan even when the source declares a nodata value
    if np.issubdtype(scalar, np.floating):
        return scalar("nan")

    return scalar(0 if src_nodata is None else src_nodata)


def warp_affine_rio(
src: np.ndarray,
dst: np.ndarray,
Expand Down Expand Up @@ -129,15 +150,14 @@ def rio_reproject(
:returns: dst
"""
assert src.ndim == dst.ndim
if dst_nodata is None:
if dst.dtype.kind == "f":
dst_nodata = np.nan

if src.ndim == 2:
return _rio_reproject(
src, dst, s_gbox, d_gbox, resampling, src_nodata, dst_nodata, **kwargs
)

fill_value = resolve_fill_value(dst_nodata, src_nodata, dst.dtype)

if ydim is None:
# Assume last two dimensions are Y/X
ydim = src.ndim - 2
Expand All @@ -154,9 +174,9 @@ def rio_reproject(
dst[roi],
s_gbox,
d_gbox,
resampling,
src_nodata,
dst_nodata,
resampling=resampling,
src_nodata=src_nodata,
dst_nodata=fill_value,
**kwargs,
)
return dst
Expand Down
3 changes: 3 additions & 0 deletions tests/test_xr_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,9 @@ def test_xr_reproject(xx_epsg4326: xr.DataArray):
assert xx.odc.geobox == dst_gbox
assert xx.encoding["grid_mapping"] == "spatial_ref"
assert "crs" not in xx.attrs
assert xx.dtype == xx_epsg4326.dtype

assert xx_epsg4326.odc.reproject(3857, dtype="float32").dtype == "float32"

yy = xr.Dataset({"a": xx0, "b": xx0 + 1, "c": xr.DataArray([2, 3, 4])})
assert isinstance(yy.odc, ODCExtensionDs)
Expand Down

0 comments on commit d9bd9c9

Please sign in to comment.