Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: add tests for meps dataset #38

Merged
merged 31 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
57edd74
added testing of loading data, creating graphs, and training model.
SimonKamuk May 22, 2024
92aa490
Merge branch 'mllam:main' into main
SimonKamuk May 22, 2024
4e17efb
added test to test name
SimonKamuk May 22, 2024
7fa7cdd
linting
SimonKamuk May 22, 2024
569d061
made create_mesh callable as python function with arguments.
SimonKamuk May 23, 2024
1ebe900
added github ci/cd for running tests with pytest
SimonKamuk May 23, 2024
0e96e88
removed coverage from test ci/cd
SimonKamuk May 23, 2024
2339ed0
fixed error in cicd
SimonKamuk May 23, 2024
5d3f834
removed astroid from requirements, causes codespell error, assuming i…
SimonKamuk May 23, 2024
8d733b7
simplified requirements
SimonKamuk May 23, 2024
7ee8821
removed commas in requirements
SimonKamuk May 23, 2024
9a5f83c
added downloading of test data from EWC using pooch
SimonKamuk May 24, 2024
c7d1d08
added pooch to requirements.txt
SimonKamuk May 24, 2024
2667b6c
updated test dataset
SimonKamuk May 24, 2024
0c7edd4
Disabled latex to enable running on github without having to install …
SimonKamuk May 27, 2024
9352949
only use latex if available
SimonKamuk May 27, 2024
4995de0
included change requests from leifdenby:
SimonKamuk May 28, 2024
d33180f
added comment
SimonKamuk May 28, 2024
fb72943
Merge branch 'mllam:main' into main
SimonKamuk May 30, 2024
e6c2c36
minor requested changes
SimonKamuk May 30, 2024
43558dc
Merge branch 'main' of github.com:SimonKamuk/neural-lam into feature_…
SimonKamuk May 30, 2024
3d77ac4
updated changelog, added cicd badges
SimonKamuk May 31, 2024
d390308
moved installation of torch-geometric from requirements to github tes…
SimonKamuk May 31, 2024
de4efba
changed name of unit test badge
SimonKamuk May 31, 2024
b0c4bed
added caching of test data
SimonKamuk Jun 3, 2024
4868db4
Merge branch 'main' of github.com:SimonKamuk/neural-lam into feature_…
SimonKamuk Jun 3, 2024
18e55a4
fix for caching
SimonKamuk Jun 3, 2024
4f75307
tried fix for caching test data
SimonKamuk Jun 3, 2024
aceb47c
updated changelog
SimonKamuk Jun 3, 2024
a6f8089
separated saving and restoring of cache
SimonKamuk Jun 3, 2024
561a26e
Merge branch 'main' of github.com:SimonKamuk/neural-lam into feature_…
SimonKamuk Jun 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# CI workflow that installs the project requirements and runs the pytest
# suite once for every Python version listed in the matrix below.
name: Run Unit Test via Pytest

on:
# trigger on pushes to any branch, but not main
push:
branches-ignore:
- main
# and also on PRs to main
pull_request:
branches:
- main

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
# each entry is substituted into ${{ matrix.python-version }} in the steps
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
# upgrade pip first; requirements are only installed when requirements.txt
# is present in the checkout
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
pytest -v -s
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 2 additions & 2 deletions create_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def prepend_node_index(graph, new_index):
return networkx.relabel_nodes(graph, to_mapping, copy=True)


def main():
def main(input_args=None):
parser = ArgumentParser(description="Graph generation arguments")
parser.add_argument(
"--data_config",
Expand Down Expand Up @@ -186,7 +186,7 @@ def main():
default=0,
help="Generate hierarchical mesh graph (default: 0, no)",
)
args = parser.parse_args()
args = parser.parse_args(input_args)

# Load grid positions
config_loader = config.Config.from_file(args.data_config)
Expand Down
4 changes: 3 additions & 1 deletion neural_lam/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Standard library
import os
import shutil

# Third-party
import numpy as np
Expand Down Expand Up @@ -250,7 +251,8 @@ def fractional_plot_bundle(fraction):
Get the tueplots bundle, but with figure width as a fraction of
the page width.
"""
bundle = bundles.neurips2023(usetex=True, family="serif")
usetex = True if shutil.which("latex") else False
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved
bundle = bundles.neurips2023(usetex=usetex, family="serif")
bundle.update(figsizes.neurips2023())
original_figsize = bundle["figure.figsize"]
bundle["figure.figsize"] = (
Expand Down
4 changes: 2 additions & 2 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def plot_prediction(
1,
2,
figsize=(13, 7),
subplot_kw={"projection": data_config.coords_projection()},
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved
subplot_kw={"projection": data_config.coords_projection},
)

# Plot pred and target
Expand Down Expand Up @@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):

fig, ax = plt.subplots(
figsize=(5, 4.8),
subplot_kw={"projection": data_config.coords_projection()},
subplot_kw={"projection": data_config.coords_projection},
)

ax.coastlines() # Add coastline outlines
Expand Down
12 changes: 12 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ Cartopy>=0.22.0
pyproj>=3.4.1
tueplots>=0.0.8
plotly>=5.15.0
torch-geometric>=2.5.2
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved
loguru>=0.7.2
xarray>=2024.3.0
zarr>=2.17.2
dask>=2024.4.2
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved

# for dev
pre-commit>=2.15.0
codespell>=2.0.0
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved
black>=21.9b0
isort>=5.9.3
flake8>=4.0.1
pylint>=3.0.3
pytest>=8.1.1
pooch>=1.8.1
Empty file added tests/__init__.py
Empty file.
131 changes: 131 additions & 0 deletions tests/test_mllam_dataset.py
leifdenby marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Standard library
import os

# Third-party
import pooch
leifdenby marked this conversation as resolved.
Show resolved Hide resolved

# First-party
from create_mesh import main as create_mesh
from neural_lam.config import Config
from neural_lam.utils import load_static_data
from neural_lam.weather_dataset import WeatherDataset
from train_model import main as train_model

# Set at import time so it is in place before any test function runs.
# NOTE(review): presumably this turns off Weights & Biases logging during the
# training test — confirm against the wandb documentation.
os.environ["WANDB_DISABLED"] = "true"
SimonKamuk marked this conversation as resolved.
Show resolved Hide resolved


def test_retrieve_data_ewc():
    """Fetch the reduced MEPS example archive and unpack it under ``data``.

    The archive is requested through ``pooch.retrieve`` with a pinned hash
    and stored as ``data/meps_example_reduced.zip``; the ``pooch.Unzip``
    processor then extracts it. The other tests in this module read
    ``data/meps_example_reduced/data_config.yaml``, so this test is
    presumably expected to run before them.
    """
    # Pieces of the object-store URL of the test archive
    endpoint_url = "https://object-store.os-api.cci1.ecmwf.int"
    bucket_name = "mllam-testdata"
    file_path = "neural-lam/npy/meps_example_reduced.v0.1.0.zip"
    archive_url = f"{endpoint_url}/{bucket_name}/{file_path}"

    # Pinned hash of the archive, passed to pooch as ``known_hash``
    expected_hash = (
        "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7"
    )

    # NOTE(review): extract_dir="" looks intended to unpack directly next to
    # the downloaded archive — confirm against the pooch documentation.
    unzip_processor = pooch.Unzip(extract_dir="")

    pooch.retrieve(
        url=archive_url,
        known_hash=expected_hash,
        fname="meps_example_reduced.zip",
        path="data",
        processor=unzip_processor,
    )


def test_load_reduced_meps_dataset():
    """Check that the reduced MEPS dataset loads with consistent shapes.

    The test loads the dataset and its data config, verifies that the
    variable metadata lists have matching lengths, checks the tensor shapes
    of the first sample, and finally checks the keys and shapes of the
    static data returned by ``load_static_data``.
    """
    data_config_file = "data/meps_example_reduced/data_config.yaml"
    dataset_name = "meps_example_reduced"

    # Use the single ``dataset_name`` variable for both loaders so they
    # cannot drift apart.
    dataset = WeatherDataset(dataset_name=dataset_name)
    config = Config.from_file(data_config_file)

    var_names = config.values["dataset"]["var_names"]
    var_units = config.values["dataset"]["var_units"]
    var_longnames = config.values["dataset"]["var_longnames"]

    assert len(var_names) == len(var_longnames)
    assert len(var_names) == len(var_units)

    # TODO: can these two variables be loaded from elsewhere?
    n_grid_static_features = 4
    n_input_steps = 2

    n_forcing_features = config.values["dataset"]["num_forcing_features"]
    n_state_features = len(var_names)
    n_prediction_timesteps = dataset.sample_length - n_input_steps

    nx, ny = config.values["grid_shape_state"]
    n_grid = nx * ny

    # check that the dataset is not empty
    assert len(dataset) > 0

    # get the first item
    init_states, target_states, forcing = dataset[0]

    # check that the shapes of the tensors are correct
    assert init_states.shape == (n_input_steps, n_grid, n_state_features)
    assert target_states.shape == (
        n_prediction_timesteps,
        n_grid,
        n_state_features,
    )
    assert forcing.shape == (
        n_prediction_timesteps,
        n_grid,
        n_forcing_features,
    )

    static_data = load_static_data(dataset_name=dataset_name)

    required_props = {
        "border_mask",
        "grid_static_features",
        "step_diff_mean",
        "step_diff_std",
        "data_mean",
        "data_std",
        "param_weights",
    }

    # Check the key set before indexing, so a missing or unexpected property
    # is reported as a clear assertion failure rather than a bare KeyError.
    assert set(static_data.keys()) == required_props

    # check the sizes of the props
    assert static_data["border_mask"].shape == (n_grid, 1)
    assert static_data["grid_static_features"].shape == (
        n_grid,
        n_grid_static_features,
    )
    assert static_data["step_diff_mean"].shape == (n_state_features,)
    assert static_data["step_diff_std"].shape == (n_state_features,)
    assert static_data["data_mean"].shape == (n_state_features,)
    assert static_data["data_std"].shape == (n_state_features,)
    assert static_data["param_weights"].shape == (n_state_features,)


def test_create_graph_reduced_meps_dataset():
    """Run ``create_mesh`` on the reduced MEPS dataset config.

    The entry point is called as a Python function with command-line style
    arguments. The test passes if the call completes without raising.
    """
    create_mesh(
        [
            "--graph=hierarchical",
            "--hierarchical=1",
            "--data_config=data/meps_example_reduced/data_config.yaml",
            "--levels=2",
        ]
    )


def test_train_model_reduced_meps_dataset():
    """Run ``train_model`` on the reduced MEPS dataset config.

    The entry point is called as a Python function with command-line style
    arguments selecting the ``hi_lam`` model and the ``hierarchical`` graph.
    The test passes if the call completes without raising.

    NOTE(review): the ``hierarchical`` graph is presumably the one produced
    by ``test_create_graph_reduced_meps_dataset`` — confirm test ordering.
    """
    train_model(
        [
            "--model=hi_lam",
            "--data_config=data/meps_example_reduced/data_config.yaml",
            "--n_workers=4",
            "--epochs=1",
            "--graph=hierarchical",
            "--hidden_dim=16",
            "--hidden_layers=1",
            "--processor_layers=1",
            "--ar_steps=1",
            "--eval=val",
            "--n_example_pred=0",
        ]
    )
4 changes: 2 additions & 2 deletions train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
}


def main():
def main(input_args=None):
"""
Main function for training and evaluating models
"""
Expand Down Expand Up @@ -206,7 +206,7 @@ def main():
default={},
help="Dict with variables and lead times to log watched metrics for",
)
args = parser.parse_args()
args = parser.parse_args(input_args)

config_loader = config.Config.from_file(args.data_config)

Expand Down
Loading