cleanup: boundary_mask, zarr-opening, utils
sadamov committed Jun 4, 2024
1 parent a748903 commit 74b4a10
Showing 7 changed files with 32 additions and 143 deletions.
2 changes: 1 addition & 1 deletion create_forcings.py
@@ -60,7 +60,7 @@ def main():
     args = parser.parse_args()
 
     config_loader = config.Config.from_file(args.data_config)
-    dataset = config_loader.open_zarr("state")
+    dataset = config_loader.open_zarrs("state")
     datetime_forcing = calculate_datetime_forcing(timesteps=dataset.time)
 
     # Expand dimensions to match the target dataset
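For context, calculate_datetime_forcing consumes only the time coordinate of the merged state dataset. A minimal sketch of the kind of cyclic encoding such a helper typically produces — an illustrative toy version, not the repository's implementation:

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Toy time axis standing in for open_zarrs("state").time
    time = pd.date_range("2024-06-04", periods=4, freq="6h")
    hours = xr.DataArray(time.hour, coords={"time": time}, dims="time")

    # Encode hour-of-day as smooth, periodic features
    hour_angle = 2 * np.pi * hours / 24
    datetime_forcing = xr.Dataset(
        {"hour_sin": np.sin(hour_angle), "hour_cos": np.cos(hour_angle)}
    )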
25 changes: 15 additions & 10 deletions neural_lam/config.py
@@ -95,15 +95,15 @@ def num_data_vars(self, category):

         return surface_vars_count + atmosphere_vars_count * levels_count
 
-    def open_zarr(self, category):
+    def open_zarrs(self, category):
         """Open the zarr dataset for the given category."""
         zarr_configs = self.values[category]["zarrs"]
 
         try:
             datasets = []
             for config in zarr_configs:
                 dataset_path = config["path"]
                 dataset = xr.open_zarr(dataset_path, consolidated=True)
                 datasets.append(dataset)
             merged_dataset = xr.merge(datasets)
             merged_dataset.attrs["category"] = category
@@ -223,7 +223,7 @@ def reshape_grid_to_2d(self, dataset, grid_shape=None):
     @functools.lru_cache()
     def get_xy(self, category):
         """Return the x, y coordinates of the dataset."""
-        dataset = self.open_zarr(category)
+        dataset = self.open_zarrs(category)
         x, y = dataset.x.values, dataset.y.values
         if x.ndim == 1:
             x, y = np.meshgrid(x, y)
@@ -244,7 +244,7 @@ def load_normalization_stats(self, category, datatype="torch"):
f"{stats_path}"
)
return None
stats = xr.open_zarr(stats_path, consolidated=True)
stats = xr.open_zarrs(stats_path, consolidated=True)
if i == 0:
combined_stats = stats
else:
@@ -294,7 +294,7 @@ def load_normalization_stats(self, category, datatype="torch"):
     # def assign_lat_lon_coords(self, category, dataset=None):
     #     """Process the latitude and longitude names of the dataset."""
     #     if dataset is None:
-    #         dataset = self.open_zarr(category)
+    #         dataset = self.open_zarrs(category)
     #     lat_lon_names = {}
     #     for zarr_config in self.values[category]["zarrs"]:
     #         lat_lon_names.update(zarr_config["lat_lon_names"])
@@ -311,7 +311,7 @@ def load_normalization_stats(self, category, datatype="torch"):
     def extract_vars(self, category, dataset=None):
         """Extract the variables from the dataset."""
         if dataset is None:
-            dataset = self.open_zarr(category)
+            dataset = self.open_zarrs(category)
         surface_vars = (
             dataset[self[category].surface_vars]
             if self[category].surface_vars
@@ -354,7 +354,7 @@ def rename_dataset_dims_and_vars(self, category, dataset=None):
"""Rename the dimensions and variables of the dataset."""
convert = False
if dataset is None:
dataset = self.open_zarr(category)
dataset = self.open_zarrs(category)
elif isinstance(dataset, xr.DataArray):
convert = True
dataset = dataset.to_dataset("variable")
@@ -387,7 +387,7 @@ def filter_dataset_by_time(self, dataset, split="train"):

     def process_dataset(self, category, split="train", apply_windowing=True):
         """Process the dataset for the given category."""
-        dataset = self.open_zarr(category)
+        dataset = self.open_zarrs(category)
         dataset = self.extract_vars(category, dataset)
         dataset = self.filter_dataset_by_time(dataset, split)
         dataset = self.stack_grid(dataset)
@@ -402,8 +402,8 @@ def process_dataset(self, category, split="train", apply_windowing=True):
     def apply_window(self, category, dataset=None):
         """Apply the forcing window to the forcing dataset."""
         if dataset is None:
-            dataset = self.open_zarr(category)
-        state_time = self.open_zarr("state").time.values
+            dataset = self.open_zarrs(category)
+        state_time = self.open_zarrs("state").time.values
         window = self[category].window
         dataset = (
             dataset.sel(time=state_time, method="nearest")
@@ -413,3 +413,12 @@ def apply_window(self, category, dataset=None):
             .stack(variable_window=("variable", "window"))
         )
         return dataset
+
+    def load_boundary_mask(self):
+        """Load the boundary mask for the dataset."""
+        boundary_mask = xr.open_zarr(self.values["boundary"]["mask"]["path"])
+        # As a torch tensor of shape (num_grid_nodes, 1) so that ar_model.py
+        # can register it directly as a buffer (assumes torch is imported)
+        return torch.tensor(
+            boundary_mask.to_array().values, dtype=torch.float32
+        ).flatten(0, 2).unsqueeze(1)
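To make the new mask plumbing concrete, here is a sketch of how one could generate the boundary_mask.zarr that load_boundary_mask and the yaml entry below expect: 1.0 on a strip of border nodes, 0.0 in the interior. The grid shape and border width are invented values for illustration:

    import numpy as np
    import xarray as xr

    ny, nx, border = 268, 238, 10  # hypothetical grid shape and border width
    mask = np.zeros((ny, nx), dtype="float32")
    mask[:border, :] = mask[-border:, :] = 1.0  # 1.0 on the border...
    mask[:, :border] = mask[:, -border:] = 1.0  # ...0.0 in the interior

    ds = xr.Dataset({"mask": (("y", "x"), mask)})
    ds.to_zarr("boundary_mask.zarr", mode="w")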
8 changes: 6 additions & 2 deletions neural_lam/data_config.yaml
@@ -95,7 +95,7 @@ static:
   atmosphere_units: null
   levels: null
 boundary:
-  zarrs:
+  zarrs: # This is not used currently, but soon ERA5 boundaries will be used
     - path: "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"
       dims:
         time: time
@@ -105,7 +105,11 @@ boundary:
       lat_lon_names:
         lon: longitude
         lat: latitude
-  mask: boundary_mask
+  mask:
+    path: "boundary_mask.zarr"
+    dims:
+      x: x
+      y: y
   window: 3
 utilities:
   normalization:
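As a sanity check on what window: 3 means for the forcing, a toy run of the rolling(...).construct(...) pattern that Config.apply_window uses on the time axis:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(5.0), dims="time")
    windowed = da.rolling(time=3, center=True).construct("window")
    print(windowed.values)
    # [[nan  0.  1.]
    #  [ 0.  1.  2.]
    #  [ 1.  2.  3.]
    #  [ 2.  3.  4.]
    #  [ 3.  4. nan]]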
55 changes: 9 additions & 46 deletions neural_lam/models/ar_model.py
@@ -36,22 +36,18 @@ def __init__(self, args):

         # Double grid output dim. to also output std.-dev.
         self.output_std = bool(args.output_std)
+        self.grid_output_dim = self.config_loader.num_data_vars("state")
         if self.output_std:
             # Pred. dim. in grid cell
-            self.grid_output_dim = 2 * self.config_loader.num_data_vars(
-                "state"
-            )
-        else:
-            # Pred. dim. in grid cell
-            self.grid_output_dim = self.config_loader.num_data_vars("state")
+            self.grid_output_dim = 2 * self.grid_output_dim
 
         # grid_dim from data + static
         (
             self.num_grid_nodes,
             grid_static_dim,
         ) = self.grid_static_features.shape
         self.grid_dim = (
-            2 * self.config_loader.num_data_vars("state")
+            2 * self.grid_output_dim
             + grid_static_dim
             + self.config_loader.num_data_vars("forcing")
             * self.config_loader.forcing.window
@@ -60,14 +56,15 @@ def __init__(self, args):
         # Instantiate loss function
         self.loss = metrics.get_metric(args.loss)
 
-        border_mask = torch.zeros(self.num_grid_nodes, 1)
-        self.register_buffer("border_mask", border_mask, persistent=False)
+        boundary_mask = self.config_loader.load_boundary_mask()
+        self.register_buffer("boundary_mask", boundary_mask, persistent=False)
         # Pre-compute interior mask for use in loss function
         self.register_buffer(
-            "interior_mask", 1.0 - self.border_mask, persistent=False
+            "interior_mask", 1.0 - self.boundary_mask, persistent=False
         )  # (num_grid_nodes, 1), 1 for non-border
 
-        self.step_length = args.step_length  # Number of hours per pred. step
+        # Number of hours per pred. step
+        self.step_length = self.config_loader.step_length
         self.val_metrics = {
             "mse": [],
         }
@@ -88,21 +85,6 @@ def __init__(self, args):
         # For storing spatial loss maps during evaluation
         self.spatial_loss_maps = []
 
-        # Load normalization statistics
-        self.normalization_stats = (
-            self.config_loader.load_normalization_stats()
-        )
-        if self.normalization_stats is not None:
-            for (
-                var_name,
-                var_data,
-            ) in self.normalization_stats.data_vars.items():
-                self.register_buffer(
-                    f"{var_name}",
-                    torch.tensor(var_data.values),
-                    persistent=False,
-                )
-
     def configure_optimizers(self):
         opt = torch.optim.AdamW(
             self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
@@ -157,7 +139,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):

             # Overwrite border with true state
             new_state = (
-                self.border_mask * border_state
+                self.boundary_mask * border_state
                 + self.interior_mask * pred_state
             )

@@ -203,25 +185,6 @@ def common_step(self, batch):

         return prediction, target_states, pred_std
 
-    def on_after_batch_transfer(self, batch, dataloader_idx):
-        """Normalize Batch data after transferring to the device."""
-        if self.normalization_stats is not None:
-            init_states, target_states, forcing_features, _, _ = batch
-            init_states = (init_states - self.mean) / self.std
-            target_states = (target_states - self.mean) / self.std
-            forcing_features = (
-                forcing_features - self.forcing_mean
-            ) / self.forcing_std
-            # boundary_features = ( boundary_features - self.boundary_mean ) /
-            # self.boundary_std
-            batch = (
-                init_states,
-                target_states,
-                forcing_features,
-                # boundary_features,
-            )
-        return batch
-
     def training_step(self, batch):
         """
         Train on single batch
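The two buffers drive the border-overwrite step in unroll_prediction; a toy example of that blending (five invented grid nodes — in the model both masks are (num_grid_nodes, 1) buffers):

    import torch

    boundary_mask = torch.tensor([[1.0], [1.0], [0.0], [0.0], [1.0]])  # 1 = border
    interior_mask = 1.0 - boundary_mask

    pred_state = torch.full((5, 1), 2.0)    # model prediction everywhere
    border_state = torch.full((5, 1), 9.0)  # true state at the boundary

    new_state = boundary_mask * border_state + interior_mask * pred_state
    print(new_state.squeeze())  # tensor([9., 9., 2., 2., 9.])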
75 changes: 0 additions & 75 deletions neural_lam/utils.py
@@ -2,86 +2,11 @@
 import os
 
 # Third-party
-import numpy as np
 import torch
 from torch import nn
 from tueplots import bundles, figsizes
-
-
-def load_dataset_stats(dataset_name, device="cpu"):
-    """
-    Load arrays with stored dataset statistics from pre-processing
-    """
-    static_dir_path = os.path.join("data", dataset_name, "static")
-
-    def loads_file(fn):
-        return torch.load(
-            os.path.join(static_dir_path, fn), map_location=device
-        )
-
-    data_mean = loads_file("parameter_mean.pt")  # (d_features,)
-    data_std = loads_file("parameter_std.pt")  # (d_features,)
-
-    flux_stats = loads_file("flux_stats.pt")  # (2,)
-    flux_mean, flux_std = flux_stats
-
-    return {
-        "data_mean": data_mean,
-        "data_std": data_std,
-        "flux_mean": flux_mean,
-        "flux_std": flux_std,
-    }
-
-
-def load_static_data(dataset_name, device="cpu"):
-    """
-    Load static files related to dataset
-    """
-    static_dir_path = os.path.join("data", dataset_name, "static")
-
-    def loads_file(fn):
-        return torch.load(
-            os.path.join(static_dir_path, fn), map_location=device
-        )
-
-    # Load border mask, 1. if node is part of border, else 0.
-    border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy"))
-    border_mask = (
-        torch.tensor(border_mask_np, dtype=torch.float32, device=device)
-        .flatten(0, 1)
-        .unsqueeze(1)
-    )  # (N_grid, 1)
-
-    grid_static_features = loads_file(
-        "grid_features.pt"
-    )  # (N_grid, d_grid_static)
-
-    # Load step diff stats
-    step_diff_mean = loads_file("diff_mean.pt")  # (d_f,)
-    step_diff_std = loads_file("diff_std.pt")  # (d_f,)
-
-    # Load parameter std for computing validation errors in original data scale
-    data_mean = loads_file("parameter_mean.pt")  # (d_features,)
-    data_std = loads_file("parameter_std.pt")  # (d_features,)
-
-    # Load loss weighting vectors
-    param_weights = torch.tensor(
-        np.load(os.path.join(static_dir_path, "parameter_weights.npy")),
-        dtype=torch.float32,
-        device=device,
-    )  # (d_f,)
-
-    return {
-        "border_mask": border_mask,
-        "grid_static_features": grid_static_features,
-        "step_diff_mean": step_diff_mean,
-        "step_diff_std": step_diff_std,
-        "data_mean": data_mean,
-        "data_std": data_std,
-        "param_weights": param_weights,
-    }
 
 
 class BufferList(nn.Module):
     """
     A list of torch buffer tensors that sit together as a Module with no
2 changes: 1 addition & 1 deletion neural_lam/vis.py
@@ -8,7 +8,7 @@


 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, data_config, title=None, step_length=3):
+def plot_error_map(errors, data_config, title=None, step_length=1):
     """
     Plot a heatmap of errors of different variables at different
     prediction horizons
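The new default matches a one-hour prediction step; step_length only rescales the error map's x-axis from step index to lead time in hours. Roughly (a guess at the labeling logic — the actual code lives in plot_error_map):

    import numpy as np

    step_length = 1                # hours per prediction step (new default)
    pred_steps = np.arange(1, 20)  # prediction-step indices
    lead_time_h = step_length * pred_steps  # x-tick labels in hours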
8 changes: 0 additions & 8 deletions train_model.py
@@ -149,13 +149,6 @@ def main():
default="wmse",
help="Loss function to use, see metric.py (default: wmse)",
)
parser.add_argument(
"--step_length",
type=int,
default=1,
help="Step length in hours to consider single time step 1-3 "
"(default: 1)",
)
parser.add_argument(
"--lr", type=float, default=1e-3, help="learning rate (default: 0.001)"
)
@@ -222,7 +215,6 @@ def main():

     # Asserts for arguments
     assert args.model in MODELS, f"Unknown model: {args.model}"
-    assert args.step_length <= 3, "Too high step length"
     assert args.eval in (
         None,
         "val",
Expand Down
