bugfixes after real-life testcase
sadamov committed Jun 6, 2024
1 parent 26f069c commit 1f1cbcc
Showing 14 changed files with 173 additions and 121 deletions.
12 changes: 6 additions & 6 deletions calculate_statistics.py
@@ -30,9 +30,9 @@ def main():
     )
     args = parser.parse_args()
 
-    config_loader = config.Config.from_file(args.data_config)
-    state_data = config_loader.process_dataset("state", split="train")
-    forcing_data = config_loader.process_dataset(
+    data_config = config.Config.from_file(args.data_config)
+    state_data = data_config.process_dataset("state", split="train")
+    forcing_data = data_config.process_dataset(
         "forcing", split="train", apply_windowing=False
     )
 
@@ -41,7 +41,7 @@ def main():
 
     if forcing_data is not None:
         forcing_mean, forcing_std = compute_stats(forcing_data)
-        combined_stats = config_loader["utilities"]["normalization"][
+        combined_stats = data_config["utilities"]["normalization"][
             "combined_stats"
         ]
 
@@ -58,15 +58,15 @@ def main():
             dict(variable=vars_to_combine)
         ] = combined_mean
         forcing_std.loc[dict(variable=vars_to_combine)] = combined_std
-        window = config_loader["forcing"]["window"]
+        window = data_config["forcing"]["window"]
         forcing_mean = xr.concat([forcing_mean] * window, dim="window").stack(
             forcing_variable=("variable", "window")
         )
         forcing_std = xr.concat([forcing_std] * window, dim="window").stack(
             forcing_variable=("variable", "window")
         )
         vars = forcing_data["variable"].values.tolist()
-        window = config_loader["forcing"]["window"]
+        window = data_config["forcing"]["window"]
         forcing_vars = [f"{var}_{i}" for var in vars for i in range(window)]
 
     print(
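[Note] The stats stacking above can be hard to read in diff form. A minimal, self-contained sketch of the same xr.concat/stack pattern on toy data (the variable names here are made up, not from the repository):

```python
import numpy as np
import xarray as xr

window = 3
# Hypothetical per-variable means for two forcing variables.
forcing_mean = xr.DataArray(
    [0.5, 2.0], dims="variable", coords={"variable": ["t2m", "u10"]}
)

# Repeat the stats once per window position, then flatten
# (variable, window) into a single "forcing_variable" dimension.
stacked = xr.concat([forcing_mean] * window, dim="window").stack(
    forcing_variable=("variable", "window")
)
print(dict(stacked.sizes))  # {'forcing_variable': 6}
```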
6 changes: 3 additions & 3 deletions create_boundary_mask.py
@@ -31,16 +31,16 @@ def main():
         help="Number of grid-cells to set to True along each boundary",
     )
     args = parser.parse_args()
-    config_loader = config.Config.from_file(args.data_config)
-    mask = np.zeros(list(config_loader.grid_shape_state.values.values()))
+    data_config = config.Config.from_file(args.data_config)
+    mask = np.zeros(list(data_config.grid_shape_state.values.values()))
 
     # Set the args.boundaries grid-cells closest to each boundary to True
     mask[: args.boundaries, :] = True  # top boundary
     mask[-args.boundaries :, :] = True  # noqa bottom boundary
     mask[:, : args.boundaries] = True  # left boundary
     mask[:, -args.boundaries :] = True  # noqa right boundary
 
-    mask = xr.Dataset({"mask": (["x", "y"], mask)})
+    mask = xr.Dataset({"mask": (["y", "x"], mask)})
 
     print(f"Saving mask to {args.zarr_path}...")
     mask.to_zarr(args.zarr_path, mode="w")
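[Note] The ["x", "y"] to ["y", "x"] fix matters because xarray assigns dimension names positionally. Assuming grid_shape_state yields a (N_y, N_x)-shaped array, the old labels were transposed relative to the data; a toy illustration (not repository code):

```python
import numpy as np
import xarray as xr

mask = np.zeros((4, 6))  # assume shape (N_y, N_x), here N_y=4, N_x=6

wrong = xr.Dataset({"mask": (["x", "y"], mask)})  # old: labels transposed
fixed = xr.Dataset({"mask": (["y", "x"], mask)})  # new: labels match axes

print(dict(wrong.sizes))  # {'x': 4, 'y': 6} -- x wrongly gets the y length
print(dict(fixed.sizes))  # {'y': 4, 'x': 6}
```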
4 changes: 2 additions & 2 deletions create_forcings.py
@@ -59,8 +59,8 @@ def main():
     parser.add_argument("--zarr_path", type=str, default="forcings.zarr")
     args = parser.parse_args()
 
-    config_loader = config.Config.from_file(args.data_config)
-    dataset = config_loader.open_zarrs("state")
+    data_config = config.Config.from_file(args.data_config)
+    dataset = data_config.open_zarrs("state")
     datetime_forcing = calculate_datetime_forcing(timesteps=dataset.time)
 
     # Expand dimensions to match the target dataset
4 changes: 2 additions & 2 deletions create_mesh.py
@@ -193,11 +193,11 @@ def main(input_args=None):
     args = parser.parse_args(input_args)
 
     # Load grid positions
-    config_loader = config.Config.from_file(args.data_config)
+    data_config = config.Config.from_file(args.data_config)
     graph_dir_path = os.path.join("graphs", args.graph)
     os.makedirs(graph_dir_path, exist_ok=True)
 
-    xy = config_loader.get_xy("static")  # (2, N_y, N_x)
+    xy = data_config.get_xy("static")  # (2, N_y, N_x)
     grid_xy = torch.tensor(xy)
     pos_max = torch.max(torch.abs(grid_xy))
5 changes: 3 additions & 2 deletions docs/download_danra.py
@@ -1,3 +1,4 @@
+# Third-party
 import xarray as xr
 
 data_urls = [
@@ -18,8 +19,8 @@
 ds = ds.chunk(chunk_dict)
 
 for var in ds.variables:
-    if 'chunks' in ds[var].encoding:
-        del ds[var].encoding['chunks']
+    if "chunks" in ds[var].encoding:
+        del ds[var].encoding["chunks"]
 
 ds.to_zarr(path, mode="w")
 print("DONE")
63 changes: 41 additions & 22 deletions neural_lam/config.py
@@ -56,6 +56,15 @@ def coords_projection(self):
         proj_params = proj_config.get("kwargs", {})
         return proj_class(**proj_params)
 
+    @functools.cached_property
+    def step_length(self):
+        """Return the step length of the dataset in hours."""
+        dataset = self.open_zarrs("state")
+        time = dataset.time.isel(time=slice(0, 2)).values
+        step_length_ns = time[1] - time[0]
+        step_length_hours = step_length_ns / np.timedelta64(1, "h")
+        return int(step_length_hours)
+
     @functools.lru_cache()
     def vars_names(self, category):
         """Return the names of the variables in the dataset."""
@@ -191,10 +200,10 @@ def filter_dimensions(self, dataset, transpose_array=True):
             if isinstance(dataset, xr.Dataset)
             else dataset["variable"].values.tolist()
         )
-        print(
-            "\033[94mYour Dataarray has the following variables: ",
-            dataset_vars,
-            "\033[0m",
+
+        print(  # noqa
+            f"\033[94mYour {dataset.attrs['category']} xr.Dataarray has the "
+            f"following variables: {dataset_vars} \033[0m",
         )
 
         return dataset
@@ -366,29 +375,19 @@ def filter_dataset_by_time(self, dataset, split="train"):
             self.values["splits"][split]["start"],
             self.values["splits"][split]["end"],
         )
-        return dataset.sel(time=slice(start, end))
-
-    def process_dataset(self, category, split="train", apply_windowing=True):
-        """Process the dataset for the given category."""
-        dataset = self.open_zarrs(category)
-        dataset = self.extract_vars(category, dataset)
-        dataset = self.filter_dataset_by_time(dataset, split)
-        dataset = self.stack_grid(dataset)
-        dataset = self.rename_dataset_dims_and_vars(category, dataset)
-        dataset = self.filter_dimensions(dataset)
-        dataset = self.convert_dataset_to_dataarray(dataset)
-        if "window" in self.values[category] and apply_windowing:
-            dataset = self.apply_window(category, dataset)
-        if category == "static" and "time" in dataset.dims:
-            dataset = dataset.isel(time=0, drop=True)
-
+        dataset = dataset.sel(time=slice(start, end))
+        dataset.attrs["split"] = split
         return dataset
 
     def apply_window(self, category, dataset=None):
         """Apply the forcing window to the forcing dataset."""
         if dataset is None:
             dataset = self.open_zarrs(category)
-        state_time = self.open_zarrs("state").time.values
         if isinstance(dataset, xr.Dataset):
             dataset = self.convert_dataset_to_dataarray(dataset)
+        state = self.open_zarrs("state")
+        state = self.filter_dataset_by_time(state, dataset.attrs["split"])
+        state_time = state.time.values
         window = self[category].window
         dataset = (
             dataset.sel(time=state_time, method="nearest")
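[Note] The fix above addresses a subtle alignment bug: apply_window previously matched forcing times against the full state zarr, so for validation/test splits the nearest-neighbour selection could align against times outside the split. Storing the split name in attrs lets apply_window filter state to the same period first. The rolling/construct/stack windowing itself, on toy data:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(10.0).reshape(5, 2),
    dims=("time", "variable"),
    coords={"time": np.arange(5), "variable": ["a", "b"]},
)
window = 3
windowed = (
    da.rolling(time=window, center=True)  # centered window, NaN at edges
    .construct("window")                  # materialize it as a new dim
    .stack(variable_window=("variable", "window"))
)
print(dict(windowed.sizes))  # {'time': 5, 'variable_window': 6}
```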
@@ -397,9 +396,29 @@ def apply_window(self, category, dataset=None):
             .construct("window")
             .stack(variable_window=("variable", "window"))
         )
+        dataset = dataset.isel(time=slice(window // 2, -window // 2 + 1))
         return dataset
 
     def load_boundary_mask(self):
         """Load the boundary mask for the dataset."""
         boundary_mask = xr.open_zarr(self.values["boundary"]["mask"]["path"])
-        return torch.tensor(boundary_mask.to_array().values)
+        return torch.tensor(
+            boundary_mask.mask.stack(grid=("y", "x")).values,
+            dtype=torch.float32,
+        ).unsqueeze(1)
+
+    def process_dataset(self, category, split="train", apply_windowing=True):
+        """Process the dataset for the given category."""
+        dataset = self.open_zarrs(category)
+        dataset = self.extract_vars(category, dataset)
+        dataset = self.filter_dataset_by_time(dataset, split)
+        dataset = self.stack_grid(dataset)
+        dataset = self.rename_dataset_dims_and_vars(category, dataset)
+        dataset = self.filter_dimensions(dataset)
+        dataset = self.convert_dataset_to_dataarray(dataset)
+        if "window" in self.values[category] and apply_windowing:
+            dataset = self.apply_window(category, dataset)
+        if category == "static" and "time" in dataset.dims:
+            dataset = dataset.isel(time=0, drop=True)
+
+        return dataset
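[Note] Two details of this hunk, in isolation. The new isel trims the NaN-padded half-windows that the centered rolling window leaves at each end (for window=3, slice(1, -1)), and load_boundary_mask now returns a flattened (N_grid, 1) float32 tensor. A sketch of the mask reshaping on a toy 3x3 grid:

```python
import numpy as np
import torch
import xarray as xr

window = 3
print(slice(window // 2, -window // 2 + 1))  # slice(1, -1, None)

mask = xr.Dataset({"mask": (["y", "x"], np.eye(3))})
tensor = torch.tensor(
    mask.mask.stack(grid=("y", "x")).values,  # flatten (y, x) -> grid
    dtype=torch.float32,
).unsqueeze(1)  # one value per grid node
print(tensor.shape)  # torch.Size([9, 1])
```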
16 changes: 12 additions & 4 deletions neural_lam/data_config.yaml
@@ -1,7 +1,7 @@
 name: danra
 state:
   zarrs:
-  - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+  - path: "data/danra/single_levels.zarr"
     dims:
       time: time
       level: null
@@ -11,7 +11,7 @@ state:
     lat_lon_names:
       lon: lon
       lat: lat
-  - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr"
+  - path: "data/danra/height_levels.zarr"
     dims:
       time: time
       level: altitude
@@ -41,7 +41,7 @@ state:
   - 100
 forcing:
   zarrs:
-  - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+  - path: "data/danra/single_levels.zarr"
     dims:
       time: time
       level: null
@@ -82,7 +82,7 @@ forcing:
   window: 3 # Number of time steps to use for forcing (odd)
 static:
   zarrs:
-  - path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
+  - path: "data/danra/single_levels.zarr"
     dims:
       level: null
       x: x
@@ -106,6 +106,7 @@ boundary:
       level: level
       x: longitude
       y: latitude
+      grid: null
     lat_lon_names:
       lon: longitude
       lat: latitude
@@ -114,6 +115,13 @@
     dims:
       x: x
       y: y
+  surface_vars:
+  - t2m
+  surface_units:
+  - K
+  atmosphere_vars: null
+  atmosphere_units: null
+  levels: null
   window: 3
 utilities:
   normalization:
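[Note] The boundary category now carries the same variable/unit keys as state and forcing (mostly null for now), so downstream code can treat all categories uniformly. A sketch of reading the new keys with a plain YAML load (pyyaml here, rather than neural_lam's Config wrapper):

```python
import yaml

with open("neural_lam/data_config.yaml") as f:
    cfg = yaml.safe_load(f)

boundary = cfg["boundary"]
print(boundary["surface_vars"])     # ['t2m']
print(boundary["atmosphere_vars"])  # None (YAML null)
print(boundary["window"])           # 3
```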
(The remaining 7 changed files are not shown.)
