training working with mllam datastore!
Leif Denby committed Jul 25, 2024
1 parent 58f5d99 commit 64d43a6
Showing 5 changed files with 118 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -82,6 +82,7 @@ tags
 
 # pdm (https://pdm-project.org/en/stable/)
 .pdm-python
+.venv
 
 # exclude pdm.lock file so that both cpu and gpu versions of torch will be accepted by pdm
 pdm.lock
19 changes: 11 additions & 8 deletions neural_lam/models/ar_model.py
@@ -28,6 +28,7 @@ def __init__(
         super().__init__()
         self.save_hyperparameters()
         self.args = args
+        self._datastore = datastore
         # XXX: should this be somewhere else?
         split = "train"
         num_state_vars = datastore.get_num_data_vars(category="state")
@@ -429,18 +430,18 @@ def plot_examples(self, batch, n_examples, prediction=None):
                # Create one figure per variable at this time step
                var_figs = [
                    vis.plot_prediction(
-                        pred_t[:, var_i],
-                        target_t[:, var_i],
-                        self.interior_mask[:, 0],
-                        self.datastore,
+                        pred=pred_t[:, var_i],
+                        target=target_t[:, var_i],
+                        obs_mask=self.interior_mask[:, 0],
+                        datastore=self.datastore,
                        title=f"{var_name} ({var_unit}), "
                        f"t={t_i} ({self.step_length * t_i} h)",
                        vrange=var_vrange,
                    )
                    for var_i, (var_name, var_unit, var_vrange) in enumerate(
                        zip(
-                            self.data_config.vars_names("state"),
-                            self.data_config.vars_units("state"),
+                            self._datastore.get_vars_names("state"),
+                            self._datastore.get_vars_units("state"),
                            var_vranges,
                        )
                    )
@@ -451,7 +452,7 @@
                    {
                        f"{var_name}_example_{example_i}": wandb.Image(fig)
                        for var_name, fig in zip(
-                            self.data_config.vars_names("state"), var_figs
+                            self._datastore.get_vars_names("state"), var_figs
                        )
                    }
                )
@@ -485,7 +486,9 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
        """
        log_dict = {}
        metric_fig = vis.plot_error_map(
-            metric_tensor, self.data_config, step_length=self.step_length
+            errors=metric_tensor,
+            datastore=self._datastore,
+            step_length=self.step_length,
        )
        full_log_name = f"{prefix}_{metric_name}"
        log_dict[full_log_name] = wandb.Image(metric_fig)
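
As a reading aid, here is a minimal sketch of the datastore accessor surface that ar_model.py now leans on, inferred only from the calls in the diff above; the class name and bodies are illustrative, not the actual implementation.

# Illustrative sketch only: the accessor surface ar_model.py now expects,
# inferred from the calls in this diff (not the real datastore class).
class SketchDatastore:
    def get_num_data_vars(self, category: str) -> int:
        """Number of variables in a category such as "state"."""
        raise NotImplementedError

    def get_vars_names(self, category: str) -> list:
        """Variable names, in the same order as the data tensor."""
        raise NotImplementedError

    def get_vars_units(self, category: str) -> list:
        """Units aligned one-to-one with get_vars_names()."""
        raise NotImplementedError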
30 changes: 18 additions & 12 deletions neural_lam/vis.py
@@ -5,10 +5,13 @@
 
 # Local
 from . import utils
+from .datastore.base import BaseCartesianDatastore
 
 
 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, data_config, title=None, step_length=1):
+def plot_error_map(
+    errors, datastore: BaseCartesianDatastore, title=None, step_length=1
+):
     """
     Plot a heatmap of errors of different variables at different
     prediction horizons
@@ -48,11 +51,10 @@ def plot_error_map(errors, data_config, title=None, step_length=1):
     ax.set_xlabel("Lead time (h)", size=label_size)
 
     ax.set_yticks(np.arange(d_f))
+    var_names = datastore.get_vars_names(category="state")
+    var_units = datastore.get_vars_units(category="state")
     y_ticklabels = [
-        f"{name} ({unit})"
-        for name, unit in zip(
-            data_config.vars_names("state"), data_config.vars_units("state")
-        )
+        f"{name} ({unit})" for name, unit in zip(var_names, var_units)
     ]
     ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
 
@@ -64,7 +66,12 @@
 
 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
 def plot_prediction(
-    pred, target, obs_mask, data_config, title=None, vrange=None
+    pred,
+    target,
+    obs_mask,
+    datastore: BaseCartesianDatastore,
+    title=None,
+    vrange=None,
 ):
     """Plot example prediction and ground truth.
@@ -77,12 +84,11 @@ def plot_prediction(
     else:
         vmin, vmax = vrange
 
-    extent = data_config.get_xy_extent("state")
+    extent = datastore.get_xy_extent("state")
 
     # Set up masking of border region
-    mask_reshaped = obs_mask.reshape(
-        list(data_config.grid_shape_state.values.values())
-    )
+    da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
+    mask_reshaped = da_mask.values
     pixel_alpha = (
         mask_reshaped.clamp(0.7, 1).cpu().numpy()
     )  # Faded border region
@@ -91,14 +97,14 @@
         1,
         2,
         figsize=(13, 7),
-        subplot_kw={"projection": data_config.coords_projection},
+        subplot_kw={"projection": datastore.coords_projection},
     )
 
     # Plot pred and target
     for ax, data in zip(axes, (target, pred)):
         ax.coastlines()  # Add coastline outlines
         data_grid = (
-            data.reshape(list(data_config.grid_shape_state.values.values()))
+            data.reshape(list(datastore.grid_shape_state.values.values()))
             .cpu()
             .numpy()
         )
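
For orientation, a hypothetical call site for the updated vis signatures; `datastore` is assumed to be some BaseCartesianDatastore implementation, and the tensor shape below is a placeholder.

# Hypothetical usage of the new datastore-based signature (illustrative only);
# assumes `datastore` is a BaseCartesianDatastore implementation.
import torch

from neural_lam import vis

errors = torch.rand(5, 3)  # placeholder (pred_steps, n_vars) tensor
fig = vis.plot_error_map(errors=errors, datastore=datastore, step_length=3)
fig.savefig("error_map.png")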
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,6 @@
+# Standard library
+import os
+
+# Disable weights and biases to avoid unnecessary logging
+# and to avoid having to deal with authentication
+os.environ["WANDB_DISABLED"] = "true"
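
Setting WANDB_DISABLED before wandb initialises keeps its logging inert for the whole test session; a per-run alternative, if environment variables are awkward, is wandb's built-in disabled mode.

# Per-run alternative to the environment variable:
# wandb's disabled mode makes logging calls no-ops for this run.
import wandb

run = wandb.init(mode="disabled")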
82 changes: 82 additions & 0 deletions tests/test_training.py
@@ -0,0 +1,82 @@
+# Standard library
+from pathlib import Path
+
+# Third-party
+import pytest
+import pytorch_lightning as pl
+import torch
+import wandb
+from test_datastores import DATASTORES, init_datastore
+
+# First-party
+from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.models.graph_lam import GraphLAM
+from neural_lam.weather_dataset import WeatherDataModule
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_training(datastore_name):
+    datastore = init_datastore(datastore_name)
+
+    if torch.cuda.is_available():
+        device_name = "cuda"
+        torch.set_float32_matmul_precision(
+            "high"
+        )  # Allows using Tensor Cores on A100s
+    else:
+        device_name = "cpu"
+
+    trainer = pl.Trainer(
+        max_epochs=3,
+        deterministic=True,
+        strategy="ddp",
+        accelerator=device_name,
+        log_every_n_steps=1,
+    )
+
+    graph_name = "1level"
+
+    graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
+
+    if not graph_dir_path.exists():
+        create_graph_from_datastore(
+            datastore=datastore,
+            output_root_path=str(graph_dir_path),
+            n_max_levels=1,
+        )
+
+    data_module = WeatherDataModule(
+        datastore=datastore,
+        ar_steps_train=3,
+        ar_steps_eval=5,
+        standardize=True,
+        batch_size=2,
+        num_workers=1,
+        forcing_window_size=3,
+    )
+
+    class ModelArgs:
+        output_std = False
+        loss = "mse"
+        restore_opt = False
+        n_example_pred = 1
+        # XXX: this should be superfluous when we have already defined the
+        # model object no?
+        graph = graph_name
+        hidden_dim = 8
+        hidden_layers = 1
+        processor_layers = 4
+        mesh_aggr = "sum"
+        lr = 1.0e-3
+        val_steps_to_log = [1]
+        metrics_watch = []
+
+    model_args = ModelArgs()
+
+    model = GraphLAM(  # noqa
+        args=model_args,
+        forcing_window_size=data_module.forcing_window_size,
+        datastore=datastore,
+    )
+    wandb.init()
+    trainer.fit(model=model, datamodule=data_module)
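
For completeness, a sketch of driving this test from Python rather than the shell (assuming it is run from the repository root, so that tests/test_datastores.py is importable).

# Sketch: invoke just this test programmatically, equivalent to running
# `python -m pytest tests/test_training.py -v` from the repository root.
import pytest

exit_code = pytest.main(["tests/test_training.py", "-v"])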
