Skip to content

Commit

Permalink
add check and print of train/test/val split in MDPDatastore
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Oct 2, 2024
1 parent 8e7b2e6 commit e0300fb
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,24 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
print("The loaded datastore contains the following features:")
for category in ["state", "forcing", "static"]:
if len(self.get_vars_names(category)) > 0:
print(f"{category}: {' '.join(self.get_vars_names(category))}")
var_names = self.get_vars_names(category)
print(f" {category:<8s}: {' '.join(var_names)}")

# check that all three train/val/test splits are available
required_splits = ["train", "val", "test"]
available_splits = list(self._ds.splits.split_name.values)
if not all(split in available_splits for split in required_splits):
raise ValueError(
f"Missing required splits: {required_splits} in available "
f"splits: {available_splits}"
)

print("With the following splits (over time):")
for split in required_splits:
da_split = self._ds.splits.sel(split_name=split)
da_split_start = da_split.sel(split_part="start").load().item()
da_split_end = da_split.sel(split_part="end").load().item()
print(f" {split:<8s}: {da_split_start} to {da_split_end}")

# find out the dimension order for the stacking to grid-index
dim_order = None
Expand Down

0 comments on commit e0300fb

Please sign in to comment.