Skip to content

Commit

Permalink
Get boundary static features from second datastore
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Dec 3, 2024
1 parent 4bcaa4b commit 6afc50c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 42 deletions.
71 changes: 30 additions & 41 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,56 +32,32 @@ def __init__(
args,
config: NeuralLAMConfig,
datastore: BaseDatastore,
datastore_boundary: Union[BaseDatastore, None],
):
super().__init__()
self.save_hyperparameters(ignore=["datastore"])
self.args = args
self._datastore = datastore
num_state_vars = datastore.get_num_data_vars(category="state")
num_forcing_vars = datastore.get_num_data_vars(category="forcing")
da_static_features = datastore.get_dataarray(
category="static", split=None
)
da_state_stats = datastore.get_standardization_dataarray(
category="state"
)

num_past_forcing_steps = args.num_past_forcing_steps
num_future_forcing_steps = args.num_future_forcing_steps

# TODO: Set based on existing of boundary forcing datastore
# TODO: Adjust what is stored here based on self.boundary_forced
self.boundary_forced = False

# Set up boundary mask
boundary_mask = torch.tensor(
da_boundary_mask.values, dtype=torch.float32
).unsqueeze(
1
) # add feature dim

self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
"interior_mask", 1.0 - self.boundary_mask, persistent=False
) # (num_grid_nodes, 1), 1 for non-border

# Load static features for grid/data, NB: self.predict_step assumes
# dimension order to be (grid_index, static_feature)
arr_static = da_static_features.transpose(
"grid_index", "static_feature"
).values
static_features_torch = torch.tensor(arr_static, dtype=torch.float32)
self.register_buffer(
"grid_static_features",
static_features_torch[self.interior_mask[:, 0].to(torch.bool)],
persistent=False,
# Load static features for grid
da_static_features = datastore.get_dataarray(
category="static", split=None
)
self.register_buffer(
"boundary_static_features",
static_features_torch[self.boundary_mask[:, 0].to(torch.bool)],
"grid_static_features",
torch.tensor(da_static_features.values, dtype=torch.float32),
persistent=False,
)

# Load stats for rescaling and weights
da_state_stats = datastore.get_standardization_dataarray(
category="state"
)
state_stats = {
"state_mean": torch.tensor(
da_state_stats.state_mean.values, dtype=torch.float32
Expand Down Expand Up @@ -137,16 +113,29 @@ def __init__(
* num_forcing_vars
* (num_past_forcing_steps + num_future_forcing_steps + 1)
)

# If datastore_boundary is given, the model is forced from the boundary
self.boundary_forced = datastore_boundary is not None

if self.boundary_forced:
self.boundary_dim = self.grid_dim # TODO Compute separately
# Load static features for grid
da_boundary_static_features = datastore_boundary.get_dataarray(
category="static", split=None
)
self.register_buffer(
"boundary_static_features",
torch.tensor(
da_boundary_static_features.values, dtype=torch.float32
),
persistent=False,
)

(
self.num_boundary_nodes,
boundary_static_dim, # TODO Will need for computation below
boundary_static_dim,
) = self.boundary_static_features.shape
self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes
else:
# Only interior grid nodes
self.num_input_nodes = self.num_grid_nodes
# TODO Compute boundary input dim separately
self.boundary_dim = self.grid_dim

# Instantiate loss function
self.loss = metrics.get_metric(args.loss)
Expand Down
7 changes: 6 additions & 1 deletion neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,12 @@ def main(input_args=None):

# Load model parameters Use new args for model
ModelClass = MODELS[args.model]
model = ModelClass(args, config=config, datastore=datastore)
model = ModelClass(
args,
config=config,
datastore=datastore,
datastore_boundary=datastore_boundary,
)

if args.eval:
prefix = f"eval-{args.eval}-"
Expand Down

0 comments on commit 6afc50c

Please sign in to comment.