Skip to content

Commit

Permalink
Turn graph classes into python dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Jun 26, 2024
1 parent dddc7a7 commit 258b013
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 71 deletions.
123 changes: 76 additions & 47 deletions neural_lam/graphs/base_weather_graph.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,97 @@
# Standard library
import os
from dataclasses import dataclass

# Third-party
import torch
import torch.nn as nn


@dataclass
class BaseWeatherGraph(nn.Module):
"""
Graph object representing weather graph consisting of grid and mesh nodes
"""

def __init__(
self,
g2m_edge_index,
g2m_edge_features,
m2g_edge_index,
m2g_edge_features,
):
g2m_edge_index: torch.Tensor
g2m_edge_features: torch.Tensor
m2g_edge_index: torch.Tensor
m2g_edge_features: torch.Tensor

def __post_init__(self):
BaseWeatherGraph.check_subgraph(
self.g2m_edge_features, self.g2m_edge_index, "g2m"
)
BaseWeatherGraph.check_subgraph(
self.m2g_edge_features, self.m2g_edge_index, "m2g"
)

# TODO Checks that node indices align
# TODO Make all node indices start at 0

@staticmethod
def check_features(features, subgraph_name):
"""
Create a new graph from tensors
Check that feature tensor has the correct format
features: (2, num_features) tensor of features
subgraph_name: name of associated subgraph, used in error messages
"""
super().__init__()
assert isinstance(
features, torch.Tensor
), f"{subgraph_name} features is not a tensor"
assert features.dtype == torch.float32, (
f"Wrong data type for {subgraph_name} feature tensor: "
f"{features.dtype}"
)
assert len(features.shape) == 2, (
f"Wrong shape of {subgraph_name} feature tensor: "
f"{features.shape}"
)

# Store edge indices
self.g2m_edge_index = g2m_edge_index
self.m2g_edge_index = m2g_edge_index
@staticmethod
def check_edge_index(edge_index, subgraph_name):
"""
Check that edge index tensor has the correct format
# Store edge features
self.g2m_edge_features = g2m_edge_features
self.m2g_edge_features = m2g_edge_features
edge_index: (2, num_edges) tensor with edge index
subgraph_name: name of associated subgraph, used in error messages
"""
assert isinstance(
edge_index, torch.Tensor
), f"{subgraph_name} edge_index is not a tensor"
assert edge_index.dtype == torch.int64, (
f"Wrong data type for {subgraph_name} edge_index "
f"tensor: {edge_index.dtype}"
)
assert len(edge_index.shape) == 2, (
f"Wrong shape of {subgraph_name} edge_index tensor: "
f"{edge_index.shape}"
)
assert edge_index.shape[0] == 2, (
"Wrong shape of {subgraph_name} edge_index tensor: "
f"{edge_index.shape}"
)

# TODO Checks that node indices align
# TODO Make all node indices start at 0
@staticmethod
def check_subgraph(edge_features, edge_index, subgraph_name):
"""
Check that tensors associated with subgraph (edge index and features)
has the correct format
edge_features: (2, num_features) tensor of edge features
edge_index: (2, num_edges) tensor with edge index
subgraph_name: name of associated subgraph, used in error messages
"""
# Check individual tensors
BaseWeatherGraph.check_features(edge_features, subgraph_name)
BaseWeatherGraph.check_edge_index(edge_index, subgraph_name)

# Check compatibility
assert edge_features.shape[0] == edge_index.shape[1], (
f"Mismatch in shape of {subgraph_name} edge_index "
f"(edge_index.shape) and features {edge_features.shape}"
)

def num_mesh_nodes(self):
# TODO use g2m
Expand Down Expand Up @@ -75,30 +134,10 @@ def _load_subgraph_from_dir(graph_dir_path, subgraph_name):
graph_dir_path, f"{subgraph_name}_edge_index.pt"
)

# Check edge_index
assert edge_index.dtype == torch.int64, (
f"Wrong data type for {subgraph_name} edge_index "
f"tensor: {edge_index.dtype}"
)
assert len(edge_index.shape) == 2, (
f"Wrong shape of {subgraph_name} edge_index tensor: "
f"{edge_index.shape}"
)
assert edge_index.shape[0] == 2, (
"Wrong shape of {subgraph_name} edge_index tensor: "
f"{edge_index.shape}"
)

edge_features = BaseWeatherGraph._load_feature_tensor(
graph_dir_path, f"{subgraph_name}_features.pt"
)

# Check compatibility
assert edge_features.shape[0] == edge_index.shape[1], (
f"Mismatch in shape of {subgraph_name} edge_index "
f"(edge_index.shape) and features {edge_features.shape}"
)

return edge_index, edge_features

@staticmethod
Expand All @@ -110,16 +149,6 @@ def _load_feature_tensor(graph_dir_path, file_name):
graph_dir_path, file_name
)

# Check features
assert features.dtype == torch.float32, (
f"Wrong data type for {file_name} graph feature tensor: "
f"{features.dtype}"
)
assert len(features.shape) == 2, (
f"Wrong shape of {file_name} graph feature tensor: "
f"{features.shape}"
)

return features

@staticmethod
Expand Down
39 changes: 15 additions & 24 deletions neural_lam/graphs/flat_weather_graph.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,28 @@
# Standard library
from dataclasses import dataclass

# Third-party
import torch

# First-party
from neural_lam.graphs.base_weather_graph import BaseWeatherGraph


@dataclass
class FlatWeatherGraph(BaseWeatherGraph):
"""
Graph object representing weather graph consisting of grid and mesh nodes
"""

def __init__(
self,
g2m_edge_index,
g2m_edge_features,
m2g_edge_index,
m2g_edge_features,
m2m_edge_index,
m2m_edge_features,
mesh_node_features,
):
"""
Create a new graph from tensors
"""
super().__init__(
g2m_edge_index,
g2m_edge_features,
m2g_edge_index,
m2g_edge_features,
)

# Store mesh tensors
self.m2m_edge_index = m2m_edge_index
self.m2m_edge_features = m2m_edge_features
self.mesh_node_features = mesh_node_features
m2m_edge_index: torch.Tensor
m2m_edge_features: torch.Tensor
mesh_node_features: torch.Tensor

def __post_init__(self):
BaseWeatherGraph.check_subgraph(
self.m2m_edge_features, self.m2m_edge_index, "m2m"
)
BaseWeatherGraph.check_features(self.mesh_node_features, "mesh nodes")
# TODO Checks that node indices align

def num_mesh_nodes(self):
Expand Down

0 comments on commit 258b013

Please sign in to comment.