Skip to content

Commit

Permalink
Add tests for flat graph class
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Jun 26, 2024
1 parent 258b013 commit 183b24b
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
1 change: 1 addition & 0 deletions neural_lam/graphs/flat_weather_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class FlatWeatherGraph(BaseWeatherGraph):
mesh_node_features: torch.Tensor

def __post_init__(self):
super().__post_init__()
BaseWeatherGraph.check_subgraph(
self.m2m_edge_features, self.m2m_edge_index, "m2m"
)
Expand Down
100 changes: 100 additions & 0 deletions tests/test_graph_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Standard library
import copy

# Third-party
import pytest
import torch

# First-party
from neural_lam.graphs.flat_weather_graph import FlatWeatherGraph


def create_dummy_graph_tensors():
"""
Create dummy tensors for instantiating a flat graph
"""
num_grid = 10
num_mesh = 5
feature_dim = 3

return {
"g2m_edge_index": torch.zeros(2, num_grid, dtype=torch.long),
"g2m_edge_features": (
torch.zeros(num_grid, feature_dim, dtype=torch.float32)
),
"m2g_edge_index": torch.zeros(2, num_grid, dtype=torch.long),
"m2g_edge_features": (
torch.zeros(num_grid, feature_dim, dtype=torch.float32)
),
"m2m_edge_index": torch.zeros(2, num_mesh, dtype=torch.long),
"m2m_edge_features": (
torch.zeros(num_mesh, feature_dim, dtype=torch.float32)
),
"mesh_node_features": (
torch.zeros(num_mesh, feature_dim, dtype=torch.float32)
),
}


def test_create_flat_graph():
"""
Test that a Flat weather graph can be created with correct tensors
"""
FlatWeatherGraph(**create_dummy_graph_tensors())


@pytest.mark.parametrize(
"subgraph_name,tensor_name",
[
(subgraph_name, tensor_name)
for subgraph_name in ("g2m", "m2g", "m2m")
for tensor_name in ("edge_features", "edge_index")
]
+ [("mesh", "node_features")],
)
def test_dtypes_flat_graph(subgraph_name, tensor_name):
"""
Test that wrong data types properly raises errors
"""
dummy_tensors = create_dummy_graph_tensors()

# Test non-tensor input
dummy_copy = copy.copy(dummy_tensors)
dummy_copy[f"{subgraph_name}_{tensor_name}"] = 1 # Not a torch.Tensor

with pytest.raises(AssertionError) as assertinfo:
FlatWeatherGraph(**dummy_copy)
assert subgraph_name in str(
assertinfo
), "AssertionError did not contain {subgraph_name}"

# Test wrong data type
dummy_copy = copy.copy(dummy_tensors)
tensor_key = f"{subgraph_name}_{tensor_name}"
dummy_copy[tensor_key] = dummy_copy[tensor_key].to(torch.float16)

with pytest.raises(AssertionError) as assertinfo:
FlatWeatherGraph(**dummy_copy)
assert subgraph_name in str(
assertinfo
), "AssertionError did not contain {subgraph_name}"


@pytest.mark.parametrize("subgraph_name", ["g2m", "m2g", "m2m"])
def test_shape_match_flat_graph(subgraph_name):
"""
Test that shape mismatch between features and edge index
properly raises errors
"""
dummy_tensors = create_dummy_graph_tensors()

tensor_key = f"{subgraph_name}_edge_features"
original_shape = dummy_tensors[tensor_key].shape
new_shape = (original_shape[0] + 1, original_shape[1])
dummy_tensors[tensor_key] = torch.zeros(*new_shape, dtype=torch.float32)

with pytest.raises(AssertionError) as assertinfo:
FlatWeatherGraph(**dummy_tensors)
assert subgraph_name in str(
assertinfo
), "AssertionError did not contain {subgraph_name}"

0 comments on commit 183b24b

Please sign in to comment.