Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/boundary dataloader #90

Open
wants to merge 108 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 88 commits
Commits
Show all changes
108 commits
Select commit Hold shift + click to select a range
5df1bff
add datastore_boundary to neural_lam
sadamov Nov 18, 2024
46590ef
complete integration of boundary in weatherDataset
sadamov Nov 18, 2024
b990f49
Add test to check timestep length and spacing
sadamov Nov 18, 2024
3fd1d6b
setting default mdp boundary to 0 gridcells
sadamov Nov 18, 2024
1f2499c
implement time-based slicing
sadamov Nov 18, 2024
1af1481
remove all interior_mask and boundary_mask
sadamov Nov 19, 2024
d545cb7
added gcsfs dependency for era5 weatherbench download
sadamov Nov 19, 2024
5c1a7d7
added new era5 datastore config for boundary
sadamov Nov 19, 2024
30e4f05
removed left-over boundary-mask references
sadamov Nov 19, 2024
6a8c593
make check for existing category in datastore more flexible (for boun…
sadamov Nov 19, 2024
17c920d
implement xarray based (mostly) time slicing and windowing
sadamov Nov 20, 2024
7919995
cleanup analysis based time-slicing
sadamov Nov 21, 2024
9bafcee
implement datastore_boundary in existing tests
sadamov Nov 19, 2024
ce06bbc
allow for grid shape retrieval from forcing data
sadamov Nov 21, 2024
884b5c6
rearrange time slicing, boundary first
sadamov Nov 21, 2024
5904cbe
identified issue, cleanup next
leifdenby Nov 25, 2024
efe0302
use xarray plot only
leifdenby Nov 26, 2024
a489c2e
don't reraise
leifdenby Nov 26, 2024
242d08b
remove debug plot
leifdenby Nov 26, 2024
c1f706c
remove extent calc used in diagnosing issue
leifdenby Nov 26, 2024
cf8e3e4
add type annotation
leifdenby Nov 29, 2024
85160ce
ensure tensor copy to cpu mem before data-array creation
leifdenby Nov 29, 2024
52c4528
apply time-indexing to support ar_steps_val > 1
leifdenby Nov 29, 2024
b96d8eb
renaming test datastores
sadamov Nov 30, 2024
72da25f
adding num_past/future_boundary_step args
sadamov Nov 30, 2024
244f1cc
using combined config file
sadamov Nov 30, 2024
a9cc36e
proper handling of state/forcing/boundary in dataset
sadamov Nov 30, 2024
dcc0b46
datastore_boundars=None introduced
sadamov Nov 30, 2024
a3b3bde
bug fix for file retrieval per member
sadamov Nov 30, 2024
3ffc413
rename datastore for tests
sadamov Nov 30, 2024
85aad66
aligned time with danra for easier boundary testing
sadamov Nov 30, 2024
64f057f
Fixed test for temporal embedding
sadamov Nov 30, 2024
6205dbd
pin dataclass-wizard <0.31.0 to avoid bug in dataclass-wizard
leifdenby Dec 2, 2024
551cd26
allow boundary as input to ar_model.common_step
sadamov Dec 2, 2024
fc95350
linting
sadamov Dec 2, 2024
01fa807
improved docstrings and added some assertions
sadamov Dec 2, 2024
5a749f3
update mdp dependency
sadamov Dec 2, 2024
45ba607
remove boundary datastore from tests that don't need it
sadamov Dec 2, 2024
f36f360
fix scope of _get_slice_time
sadamov Dec 2, 2024
105108e
fix scope of _get_time_step
sadamov Dec 2, 2024
d760145
Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov…
sadamov Dec 2, 2024
ae0cf76
added information about optional boundary datastore
sadamov Dec 2, 2024
9af27e0
add datastore_boundary to neural_lam
sadamov Nov 18, 2024
c25fb30
complete integration of boundary in weatherDataset
sadamov Nov 18, 2024
505ceeb
Add test to check timestep length and spacing
sadamov Nov 18, 2024
e733066
setting default mdp boundary to 0 gridcells
sadamov Nov 18, 2024
d8349a4
implement time-based slicing
sadamov Nov 18, 2024
fd791bf
remove all interior_mask and boundary_mask
sadamov Nov 19, 2024
ae82cdb
added gcsfs dependency for era5 weatherbench download
sadamov Nov 19, 2024
34a6cc7
added new era5 datastore config for boundary
sadamov Nov 19, 2024
2dc67a0
removed left-over boundary-mask references
sadamov Nov 19, 2024
9f8628e
make check for existing category in datastore more flexible (for boun…
sadamov Nov 19, 2024
388c79d
implement xarray based (mostly) time slicing and windowing
sadamov Nov 20, 2024
2529969
cleanup analysis based time-slicing
sadamov Nov 21, 2024
179a035
implement datastore_boundary in existing tests
sadamov Nov 19, 2024
2daeb16
allow for grid shape retrieval from forcing data
sadamov Nov 21, 2024
cbcdcae
rearrange time slicing, boundary first
sadamov Nov 21, 2024
e6ace27
renaming test datastores
sadamov Nov 30, 2024
42818f0
adding num_past/future_boundary_step args
sadamov Nov 30, 2024
0103b6e
using combined config file
sadamov Nov 30, 2024
0896344
proper handling of state/forcing/boundary in dataset
sadamov Nov 30, 2024
355423c
datastore_boundars=None introduced
sadamov Nov 30, 2024
121d460
bug fix for file retrieval per member
sadamov Nov 30, 2024
7e82eef
rename datastore for tests
sadamov Nov 30, 2024
320d7c4
aligned time with danra for easier boundary testing
sadamov Nov 30, 2024
f18dcc2
Fixed test for temporal embedding
sadamov Nov 30, 2024
e6327d8
allow boundary as input to ar_model.common_step
sadamov Dec 2, 2024
1374a19
linting
sadamov Dec 2, 2024
779f3e9
improved docstrings and added some assertions
sadamov Dec 2, 2024
f126ec2
remove boundary datastore from tests that don't need it
sadamov Dec 2, 2024
4b656da
fix scope of _get_time_step
sadamov Dec 2, 2024
75db4b8
added information about optional boundary datastore
sadamov Dec 2, 2024
58b4af6
Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov…
sadamov Dec 2, 2024
4c17545
moved gcsfs to dev group
sadamov Dec 3, 2024
a700350
linting
sadamov Dec 3, 2024
16d5d04
Fixed issue with temporal encoding dimensions
sadamov Dec 3, 2024
f1f3f73
format docstrings
sadamov Dec 3, 2024
8fd7a10
introduced time slicing test for forecast type data
sadamov Dec 3, 2024
252a33c
bugfix temporal embedding dimension
sadamov Dec 3, 2024
8a9114a
linting
sadamov Dec 3, 2024
8c7709a
switched to low-res data
sadamov Dec 3, 2024
24cbf13
add datastore_boundary as explicit attribute
sadamov Dec 3, 2024
1d53ce7
fixing up forecast type data tests,
sadamov Dec 5, 2024
cfe1e27
time step can and should be retrieved in __init__
sadamov Dec 5, 2024
e4e4e37
Fix dataset issue in npy stat script
joeloskarsson Dec 4, 2024
3df3fcb
Merge remote-tracking branch 'mllam/main' into feat/boundary_dataloader
sadamov Dec 5, 2024
f8613da
added static feature to era5 boundary test datastore
sadamov Dec 5, 2024
f0a7046
Merge remote-tracking branch 'mllam/main' into feat/boundary_dataloader
sadamov Dec 6, 2024
8cc608d
rename function to represent multiple datastores
sadamov Dec 20, 2024
857f748
streamline da_grid_reference variable naming
sadamov Dec 20, 2024
d0a6f24
updated docstring of WeatherDataset
sadamov Dec 20, 2024
ef40a39
renamed da_boundary -> da_boundary_forcing
sadamov Dec 20, 2024
71b52b2
updated docstrings of get_dataarray()
sadamov Dec 20, 2024
b690563
check times in stateless functions from utils.py
sadamov Dec 20, 2024
a37dc3c
add num_ensemble_members property to BaseDatastore
sadamov Dec 20, 2024
8d1bec6
Update neural_lam/weather_dataset.py
sadamov Dec 20, 2024
47370f9
renaming time_diff_steps to time_deltas
sadamov Dec 20, 2024
7e1a246
Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov…
sadamov Dec 20, 2024
d524377
add num_ensemble_members to mdp store
sadamov Dec 20, 2024
98c54d9
Rename temporal embeddings and diffs to time deltas
sadamov Dec 20, 2024
4a278fd
Adding some comments about analysis_time indexing
sadamov Dec 20, 2024
c82d22b
moved comments around
sadamov Dec 20, 2024
6e3f3bd
Make hotfix to make boundary dataset created with mdp work
joeloskarsson Dec 19, 2024
20ca263
Bugfixes
sadamov Dec 20, 2024
c0c50d5
sadamov Dec 20, 2024
94de240
Add missing check if boundary_forcing is None
sadamov Dec 20, 2024
1d14a15
bugfix typo in time check
sadamov Dec 20, 2024
7e5797e
introduce crop_time_if_needed to align interior with boundary data
sadamov Dec 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th
interface that provides the data in a data-structure that can be used within
neural-lam. A datastore is used to create a `pytorch.Dataset`-derived
class that samples the data in time to create individual samples for
training, validation and testing.
training, validation and testing. A secondary datastore can be provided
for the boundary data. Currently, boundary datastore must be of type `mdp`
and only contain forcing features. This can easily be expanded in the future.

2. **The graph structure** is used to define message-passing GNN layers,
that are trained to emulate fluid flow in the atmosphere over time. The
Expand All @@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model.

The path you provide to the neural-lam config (`config.yaml`) also sets the
root directory relative to which all other paths are resolved, as in the parent
directory of the config becomes the root directory. Both the datastore and
directory of the config becomes the root directory. Both the datastores and
graphs you generate are then stored in subdirectories of this root directory.
Exactly how and where a specific datastore expects its source data to be stored
and where it stores its derived data is up to the implementation of the
Expand All @@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`):
data/
├── config.yaml - Configuration file for neural-lam
├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml
├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml
└── graphs/ - Directory containing graphs for training
```

Expand All @@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like:
datastore:
kind: mdp
config_path: danra.datastore.yaml
datastore_boundary:
kind: mdp
config_path: era5.datastore.yaml
training:
state_feature_weighting:
__config_class__: ManualStateFeatureWeighting
values:
weights:
u100m: 1.0
v100m: 1.0
```

For now the neural-lam config only defines two things: 1) the kind of data
store and the path to its config, and 2) the weighting of different features in
the loss function. If you don't define the state feature weighting it will default
to weighting all features equally.
For now the neural-lam config only defines two things:
1) the kind of datastores and the path to their config
2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally.

(This example is taken from the `tests/datastore_examples/mdp` directory.)

Expand Down Expand Up @@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha

# Contact
If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch.
There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots).
You can also open a github issue on this page, or (if more suitable) send an email to [[email protected]](mailto:[email protected]).
There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [[email protected]](mailto:[email protected]).
17 changes: 16 additions & 1 deletion neural_lam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,15 @@ class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
----------
datastore : DatastoreSelection
The configuration for the datastore to use.
datastore_boundary : Union[DatastoreSelection, None]
The configuration for the boundary datastore to use, if any. If None,
no boundary datastore is used.
training : TrainingConfig
The configuration for training the model.
"""

datastore: DatastoreSelection
datastore_boundary: Union[DatastoreSelection, None] = None
training: TrainingConfig = dataclasses.field(default_factory=TrainingConfig)

class _(dataclass_wizard.JSONWizard.Meta):
Expand Down Expand Up @@ -168,4 +172,15 @@ def load_config_and_datastore(
datastore_kind=config.datastore.kind, config_path=datastore_config_path
)

return config, datastore
if config.datastore_boundary is not None:
sadamov marked this conversation as resolved.
Show resolved Hide resolved
datastore_boundary_config_path = (
Path(config_path).parent / config.datastore_boundary.config_path
)
datastore_boundary = init_datastore(
datastore_kind=config.datastore_boundary.kind,
config_path=datastore_boundary_config_path,
)
else:
datastore_boundary = None

return config, datastore, datastore_boundary
17 changes: 0 additions & 17 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,23 +228,6 @@ def get_dataarray(
"""
pass

@cached_property
@abc.abstractmethod
def boundary_mask(self) -> xr.DataArray:
"""
Return the boundary mask for the dataset, with spatial dimensions
stacked. Where the value is 1, the grid point is a boundary point, and
where the value is 0, the grid point is not a boundary point.

Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions
`('grid_index',)`.

"""
pass

@abc.abstractmethod
def get_xy(self, category: str) -> np.ndarray:
"""
Expand Down
70 changes: 22 additions & 48 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ class MDPDatastore(BaseRegularGridDatastore):

SHORT_NAME = "mdp"

def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
def __init__(self, config_path, reuse_existing=True):
"""
Construct a new MDPDatastore from the configuration file at
`config_path`. A boundary mask is created with `n_boundary_points`
boundary points. If `reuse_existing` is True, the dataset is loaded
`config_path`. If `reuse_existing` is True, the dataset is loaded
from a zarr file if it exists (unless the config has been modified
since the zarr was created), otherwise it is created from the
configuration file.
Expand All @@ -42,8 +41,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
The path to the configuration file, this will be fed to the
`mllam_data_prep.Config.from_yaml_file` method to then call
`mllam_data_prep.create_dataset` to create the dataset.
n_boundary_points : int
The number of boundary points to use in the boundary mask.
reuse_existing : bool
Whether to reuse an existing dataset zarr file if it exists and its
creation date is newer than the configuration file.
Expand All @@ -70,7 +67,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
if self._ds is None:
self._ds = mdp.create_dataset(config=self._config)
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points

print("The loaded datastore contains the following features:")
for category in ["state", "forcing", "static"]:
Expand Down Expand Up @@ -158,8 +154,8 @@ def get_vars_units(self, category: str) -> List[str]:
The units of the variables in the given category.

"""
if category not in self._ds and category == "forcing":
warnings.warn("no forcing data found in datastore")
if category not in self._ds:
warnings.warn(f"no {category} data found in datastore")
return []
return self._ds[f"{category}_feature_units"].values.tolist()

Expand All @@ -177,8 +173,8 @@ def get_vars_names(self, category: str) -> List[str]:
The names of the variables in the given category.

"""
if category not in self._ds and category == "forcing":
warnings.warn("no forcing data found in datastore")
if category not in self._ds:
warnings.warn(f"no {category} data found in datastore")
return []
return self._ds[f"{category}_feature"].values.tolist()

Expand All @@ -197,8 +193,8 @@ def get_vars_long_names(self, category: str) -> List[str]:
The long names of the variables in the given category.

"""
if category not in self._ds and category == "forcing":
warnings.warn("no forcing data found in datastore")
if category not in self._ds:
warnings.warn(f"no {category} data found in datastore")
return []
return self._ds[f"{category}_feature_long_name"].values.tolist()

Expand Down Expand Up @@ -253,9 +249,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
The xarray DataArray object with processed dataset.

"""
if category not in self._ds and category == "forcing":
warnings.warn("no forcing data found in datastore")
return None
if category not in self._ds:
warnings.warn(f"no {category} data found in datastore")
return []

da_category = self._ds[category]

Expand Down Expand Up @@ -319,37 +315,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)
return ds_stats

@cached_property
def boundary_mask(self) -> xr.DataArray:
"""
Produce a 0/1 mask for the boundary points of the dataset, these will
sit at the edges of the domain (in x/y extent) and will be used to mask
out the boundary points from the loss function and to overwrite the
boundary points from the prediction. For now this is created when the
mask is requested, but in the future this could be saved to the zarr
file.

Returns
-------
xr.DataArray
A 0/1 mask for the boundary points of the dataset, where 1 is a
boundary point and 0 is not.

"""
ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
da_state_variable = (
ds_unstacked["state"].isel(time=0).isel(state_feature=0)
)
da_domain_allzero = xr.zeros_like(da_state_variable)
ds_unstacked["boundary_mask"] = da_domain_allzero.isel(
x=slice(self._n_boundary_points, -self._n_boundary_points),
y=slice(self._n_boundary_points, -self._n_boundary_points),
)
ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(
1
).astype(int)
return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)

@property
def coords_projection(self) -> ccrs.Projection:
"""
Expand Down Expand Up @@ -415,8 +380,17 @@ def grid_shape_state(self):
The shape of the cartesian grid for the state variables.

"""
ds_state = self.unstack_grid_coords(self._ds["state"])
da_x, da_y = ds_state.x, ds_state.y
# Boundary data often has no state features
if "state" not in self._ds:
warnings.warn(
"no state data found in datastore"
"returning grid shape from forcing data"
)
ds_forcing = self.unstack_grid_coords(self._ds["forcing"])
sadamov marked this conversation as resolved.
Show resolved Hide resolved
da_x, da_y = ds_forcing.x, ds_forcing.y
else:
ds_state = self.unstack_grid_coords(self._ds["state"])
da_x, da_y = ds_state.x, ds_state.y
assert da_x.ndim == da_y.ndim == 1
return CartesianGridShape(x=da_x.size, y=da_y.size)

Expand Down
sadamov marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def main(
ar_steps = 63
ds = WeatherDataset(
datastore=datastore,
datastore_boundary=None,
split="train",
ar_steps=ar_steps,
standardize=False,
Expand Down Expand Up @@ -201,7 +202,7 @@ def main(
print("Computing mean and std.-dev. for parameters...")
means, squares, flux_means, flux_squares = [], [], [], []

for init_batch, target_batch, forcing_batch, _ in tqdm(loader):
for init_batch, target_batch, forcing_batch, _, _ in tqdm(loader):
if distributed:
init_batch, target_batch, forcing_batch = (
init_batch.to(device),
Expand Down Expand Up @@ -275,6 +276,7 @@ def main(
print("Computing mean and std.-dev. for one-step differences...")
ds_standard = WeatherDataset(
datastore=datastore,
datastore_boundary=None,
split="train",
ar_steps=ar_steps,
standardize=True,
Expand Down Expand Up @@ -303,7 +305,7 @@ def main(

diff_means, diff_squares = [], []

for init_batch, target_batch, _, _ in tqdm(
for init_batch, target_batch, _, _, _ in tqdm(
loader_standard, disable=rank != 0
):
if distributed:
Expand Down
42 changes: 11 additions & 31 deletions neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,9 @@ def _get_single_timeseries_dataarray(
* np.timedelta64(1, "h")
)
elif d == "analysis_time":
coord_values = self._get_analysis_times(split=split)
coord_values = self._get_analysis_times(
split=split, member_id=member
)
elif d == "y":
coord_values = y
elif d == "x":
Expand Down Expand Up @@ -505,23 +507,29 @@ def _get_single_timeseries_dataarray(

return da

def _get_analysis_times(self, split) -> List[np.datetime64]:
def _get_analysis_times(self, split, member_id) -> List[np.datetime64]:
"""Get the analysis times for the given split by parsing the filenames
of all the files found for the given split.

Parameters
----------
split : str
The dataset split to get the analysis times for.
member_id : int
The ensemble member to get the analysis times for.

Returns
-------
List[dt.datetime]
The analysis times for the given split.

"""
if member_id is None:
# Only interior state data files have member_id, to avoid duplicates
# we only look at the first member for all other categories
member_id = 0
pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern)

sample_dir = self.root_path / "samples" / split
sample_files = sample_dir.glob(pattern)
Expand Down Expand Up @@ -668,34 +676,6 @@ def grid_shape_state(self) -> CartesianGridShape:
ny, nx = self.config.grid_shape_state
return CartesianGridShape(x=nx, y=ny)

@cached_property
def boundary_mask(self) -> xr.DataArray:
"""The boundary mask for the dataset. This is a binary mask that is 1
where the grid cell is on the boundary of the domain, and 0 otherwise.

Returns
-------
xr.DataArray
The boundary mask for the dataset, with dimensions `[grid_index]`.

"""
xy = self.get_xy(category="state", stacked=False)
xs = xy[:, :, 0]
ys = xy[:, :, 1]
# Check if x-coordinates are constant along columns
assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant"
# Check if y-coordinates are constant along rows
assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant"
# Extract unique x and y coordinates
x = xs[:, 0] # Unique x-coordinates (changes along the first axis)
y = ys[0, :] # Unique y-coordinates (changes along the second axis)
values = np.load(self.root_path / "static" / "border_mask.npy")
da_mask = xr.DataArray(
values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
)
da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int)
return da_mask_stacked_xy

def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""Return the standardization dataarray for the given category. This
should contain a `{category}_mean` and `{category}_std` variable for
Expand Down
Loading
Loading