Skip to content

Commit

Permalink
Add standardization test
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Jan 16, 2025
1 parent 3f53b53 commit 1a28459
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,24 @@ def test_plot_example_from_datastore(datastore_name):

assert fig is not None
assert fig.get_axes()


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
@pytest.mark.parametrize("category", ("state", "static"))
def test_get_standardized_da(datastore_name, category):
"""Check that dataarray is actually standardized when calling
get_dataarray with standardize=True"""
datastore = init_datastore_example(datastore_name)
ds_stats = datastore.get_standardization_dataarray(category=category)

mean = ds_stats[f"{category}_mean"]
std = ds_stats[f"{category}_std"]

non_standard_da = datastore.get_dataarray(
category=category, split="train", standardize=False
)
standard_da = datastore.get_dataarray(
category=category, split="train", standardize=True
)

assert np.allclose(standard_da, (non_standard_da - mean) / std, atol=1e-6)

0 comments on commit 1a28459

Please sign in to comment.