Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xarray Dataset Support #1490

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/datasets.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ radiant-mlhub==0.4.1
rarfile==4.0
scikit-image==0.21.0
scipy==1.11.1
xarray
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
zipfile-deflate64==0.2.0
1 change: 1 addition & 0 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ radiant-mlhub==0.3.0
rarfile==4.0
scikit-image==0.18.0
scipy==1.6.2
xarray
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need to determine the minimum version that works before merging

zipfile-deflate64==0.2.0

# docs
Expand Down
65 changes: 65 additions & 0 deletions tests/data/rioxarray/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the MIT License.

import os
import shutil

import cftime
import numpy as np
import pandas as pd
import xarray as xr

SIZE = 32

LATS: list[tuple[float]] = [(40, 42), (60, 62), (80, 82)]

LONS: list[tuple[float]] = [(-55, -50), (-60, -55), (-85, 80)]

VAR_NAMES = ["zos", "tos"]

DIR = "data"

CF_TIME = [True, False, True]

NUM_TIME_STEPS = 3


def create_rioxr_dataset(
    lat_min: float,
    lat_max: float,
    lon_min: float,
    lon_max: float,
    cf_time: bool,
    var_name: str,
    filename: str,
) -> None:
    """Write a single-variable NetCDF test file spanning the given lat/lon box.

    Args:
        lat_min: lower bound of the latitude range
        lat_max: upper bound of the latitude range
        lon_min: lower bound of the longitude range
        lon_max: upper bound of the longitude range
        cf_time: if True, encode the time axis with ``cftime`` datetimes;
            otherwise use a pandas ``DatetimeIndex``
        var_name: name of the data variable written to the dataset
        filename: output path for the NetCDF file
    """
    # Generate SIZE evenly spaced coordinates along each spatial axis
    lats = np.linspace(lat_min, lat_max, SIZE)
    lons = np.linspace(lon_min, lon_max, SIZE)

    # Two time encodings are exercised: cftime objects and pandas datetime64
    if cf_time:
        times = [cftime.datetime(2000, 1, i + 1) for i in range(NUM_TIME_STEPS)]
    else:
        times = pd.date_range(start="2000-01-01", periods=NUM_TIME_STEPS, freq="D")

    # Random data with shape (time, x, y)
    data = np.random.rand(len(times), len(lats), len(lons))

    # Create the xarray dataset; latitudes are stored under the "x" coordinate
    # and longitudes under "y", matching what the dataset under test expects
    ds = xr.Dataset(
        data_vars={var_name: (("time", "x", "y"), data)},
        coords={"x": lats, "y": lons, "time": times},
    )
    ds.to_netcdf(path=filename)


if __name__ == "__main__":
    # Start from a clean output directory on every run
    if os.path.isdir(DIR):
        shutil.rmtree(DIR)
    os.makedirs(DIR)

    # One file per (variable, spatial extent, time encoding) combination
    for name in VAR_NAMES:
        for lat_range, lon_range, use_cftime in zip(LATS, LONS, CF_TIME):
            out_path = os.path.join(DIR, f"{name}_{lat_range}_{lon_range}.nc")
            lat_lo, lat_hi = lat_range
            lon_lo, lon_hi = lon_range
            create_rioxr_dataset(
                lat_lo, lat_hi, lon_lo, lon_hi, use_cftime, name, out_path
            )
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
43 changes: 43 additions & 0 deletions tests/datasets/test_rioxarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
import torch

from torchgeo.datasets import (
BoundingBox,
IntersectionDataset,
RioXarrayDataset,
UnionDataset,
)

pytest.importorskip("rioxarray")


class TestRioXarrayDataset:
    @pytest.fixture(scope="class")
    def dataset(self) -> RioXarrayDataset:
        """Dataset backed by the NetCDF files under tests/data/rioxarray."""
        data_root = os.path.join("tests", "data", "rioxarray", "data")
        return RioXarrayDataset(root=data_root, data_variables=["zos", "tos"])

    def test_getitem(self, dataset: RioXarrayDataset) -> None:
        sample = dataset[dataset.bounds]
        assert isinstance(sample, dict)
        assert isinstance(sample["image"], torch.Tensor)

    def test_and(self, dataset: RioXarrayDataset) -> None:
        intersection = dataset & dataset
        assert isinstance(intersection, IntersectionDataset)

    def test_or(self, dataset: RioXarrayDataset) -> None:
        union = dataset | dataset
        assert isinstance(union, UnionDataset)

    def test_invalid_query(self, dataset: RioXarrayDataset) -> None:
        empty_query = BoundingBox(0, 0, 0, 0, 0, 0)
        error_pattern = "query: .* not found in index with bounds:"
        with pytest.raises(IndexError, match=error_pattern):
            dataset[empty_query]
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from .potsdam import Potsdam2D
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .rioxarray import RioXarrayDataset
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel1, Sentinel2
Expand Down Expand Up @@ -229,6 +230,7 @@
"NonGeoClassificationDataset",
"NonGeoDataset",
"RasterDataset",
"RioXarrayDataset",
"UnionDataset",
"VectorDataset",
# Utilities
Expand Down
156 changes: 156 additions & 0 deletions torchgeo/datasets/rioxarray.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend renaming this file. Otherwise the following become very different things:

import rioxarray
import .rioxarray

Maybe call it rioxr.py? Or just throw it in geo.py with the other base classes.

Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""In-memory geographical xarray.DataArray."""
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

import glob
import os
import re
import sys
from datetime import datetime
from typing import Any, Callable, Optional, cast

import netCDF4 # noqa: F401
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't used at the moment and could probably be removed

Copy link
Collaborator Author

@nilsleh nilsleh Aug 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was running into this issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather ignore that warning in pyproject.toml than add a fake import

import numpy as np
import torch
import xarray as xr
from rasterio.crs import CRS
from rtree.index import Index, Property

from .geo import GeoDataset
from .utils import BoundingBox


class RioXarrayDataset(GeoDataset):
    """Wrapper for geographical datasets stored as Xarray Datasets.

    Relies on rioxarray for CRS metadata, bounds, and spatial clipping.

    NOTE(review): files in differing CRSs are not reprojected to ``self._crs``;
    only files missing a CRS have it written. Overlapping files are stacked,
    not merged — confirm this is acceptable for the intended datasets.
    """

    # Default to indexing every file in ``root``; subclasses may narrow these
    filename_glob = "*"
    filename_regex = ".*"

    def __init__(
        self,
        root: str,
        data_variables: list[str],
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
    ) -> None:
        """Initialize a new Dataset instance.

        Args:
            root: directory with NetCDF files
            data_variables: data variables that should be gathered from the
                xarray datasets
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file that defines one)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file)
            transforms: a function/transform that takes an input sample
                and returns a transformed version

        Raises:
            FileNotFoundError: if no files are found in ``root``
        """
        super().__init__(transforms)

        self.root = root
        self.data_variables = data_variables
        self.transforms = transforms

        # R-tree over (x, y, time) for fast spatiotemporal intersection
        self.index = Index(interleaved=False, properties=Property(dimension=3))

        # Populate the dataset index from every file matching glob + regex
        i = 0
        pathname = os.path.join(root, self.filename_glob)
        filename_regex = re.compile(self.filename_regex, re.VERBOSE)
        for filepath in glob.iglob(pathname, recursive=True):
            match = re.match(filename_regex, os.path.basename(filepath))
            if match is not None:
                with xr.open_dataset(filepath, decode_times=True) as ds:
                    # First file encountered supplies the CRS/resolution
                    # defaults when the caller did not specify them
                    if crs is None:
                        crs = ds.rio.crs
                    if res is None:
                        res = ds.rio.resolution()[0]

                    (minx, miny, maxx, maxy) = ds.rio.bounds()

                    if hasattr(ds, "time"):
                        # cftime-backed indexes must be converted to a
                        # DatetimeIndex first; a plain DatetimeIndex has no
                        # to_datetimeindex and raises AttributeError
                        try:
                            indices = ds.indexes["time"].to_datetimeindex()
                        except AttributeError:
                            indices = ds.indexes["time"]

                        mint = indices.min().to_pydatetime().timestamp()
                        maxt = indices.max().to_pydatetime().timestamp()
                    else:
                        # No time axis: cover the full representable range
                        mint = 0
                        maxt = sys.maxsize
                coords = (minx, maxx, miny, maxy, mint, maxt)
                self.index.insert(i, coords, filepath)
                i += 1

        if i == 0:
            msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"
            raise FileNotFoundError(msg)

        if not crs:
            # Fall back to WGS 84. Bug fix: store a real CRS object rather
            # than the bare string "EPSG:4326", so self._crs has a consistent
            # type for rioxarray calls and GeoDataset machinery.
            self._crs = CRS.from_epsg(4326)
        else:
            self._crs = cast(CRS, crs)
        # NOTE(review): res may still be None here if no file reported a
        # resolution — the cast only silences the type checker; confirm
        # downstream samplers tolerate that.
        self.res = cast(float, res)

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve image/mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of image/mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        hits = self.index.intersection(tuple(query), objects=True)
        items = [hit.object for hit in hits]

        if not items:
            raise IndexError(
                f"query: {query} not found in index with bounds: {self.bounds}"
            )

        data_arrays: list["np.typing.NDArray[np.float32]"] = []
        for item in items:
            with xr.open_dataset(item, decode_cf=True) as ds:
                # Files without CRS metadata are assumed to be in the
                # dataset-wide CRS; files with a *different* CRS are not
                # reprojected (see class note)
                if not ds.rio.crs:
                    ds.rio.write_crs(self._crs, inplace=True)

                # clip_box operates on the spatial dims only and ignores time
                clipped = ds.rio.clip_box(
                    minx=query.minx, miny=query.miny, maxx=query.maxx, maxy=query.maxy
                )

                # Select the requested time slice, normalizing cftime indexes
                # to a DatetimeIndex first (same pattern as in __init__)
                if hasattr(ds, "time"):
                    try:
                        clipped["time"] = clipped.indexes["time"].to_datetimeindex()
                    except AttributeError:
                        clipped["time"] = clipped.indexes["time"]
                    clipped = clipped.sel(
                        time=slice(
                            datetime.fromtimestamp(query.mint),
                            datetime.fromtimestamp(query.maxt),
                        )
                    )

                # Gather only the variables this dataset was asked for
                for variable in self.data_variables:
                    if hasattr(clipped, variable):
                        data_arrays.append(clipped[variable].data.squeeze())

        sample = {"image": torch.from_numpy(np.stack(data_arrays)), "bbox": query}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample