Commit

removed loss config
gerkone committed Feb 23, 2024
1 parent fa2d604 commit ecbe8fc
Showing 4 changed files with 31 additions and 27 deletions.
7 changes: 5 additions & 2 deletions lagrangebench/case_setup/case.py
@@ -8,7 +8,7 @@
from jax_md import space
from jax_md.dataclasses import dataclass, static_field
from jax_md.partition import NeighborList, NeighborListFormat
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf

from lagrangebench.data.utils import get_dataset_stats
from lagrangebench.defaults import defaults
@@ -64,7 +64,7 @@ def case_builder(
box: Tuple[float, float, float],
metadata: Dict,
input_seq_length: int,
-cfg_neighbors: DictConfig = defaults.neighbors,
+cfg_neighbors: Union[Dict, DictConfig] = defaults.neighbors,
isotropic_norm: bool = defaults.model.isotropic_norm,
noise_std: float = defaults.train.noise_std,
external_force_fn: Optional[Callable] = None,
@@ -91,6 +91,9 @@
magnitude_features: Whether to add velocity magnitudes in the features.
dtype: Data type.
"""
+if isinstance(cfg_neighbors, Dict):
+    cfg_neighbors = OmegaConf.create(cfg_neighbors)

normalization_stats = get_dataset_stats(metadata, isotropic_norm, noise_std)

# apply PBC in all directions or not at all
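For illustration, a minimal sketch of what this change buys on the calling side: a plain dict passed as cfg_neighbors is promoted to a DictConfig before use. The key names below are illustrative, not necessarily the actual neighbor-list options.

from omegaconf import DictConfig, OmegaConf

cfg_neighbors = {"backend": "jaxmd_vmap", "multiplier": 1.25}  # plain dict; key names illustrative
if isinstance(cfg_neighbors, dict):
    cfg_neighbors = OmegaConf.create(cfg_neighbors)  # same conversion as in case_builder above
assert isinstance(cfg_neighbors, DictConfig)
print(cfg_neighbors.multiplier)  # attribute access now works -> 1.25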
9 changes: 6 additions & 3 deletions lagrangebench/evaluate/rollout.py
@@ -4,14 +4,14 @@
import pickle
import time
from functools import partial
-from typing import Callable, Iterable, Optional, Tuple
+from typing import Callable, Dict, Iterable, Optional, Tuple, Union

import haiku as hk
import jax
import jax.numpy as jnp
import jax_md.partition as partition
from jax import jit, vmap
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from lagrangebench.data import H5Dataset
@@ -315,7 +315,7 @@ def infer(
params: Optional[hk.Params] = None,
state: Optional[hk.State] = None,
load_checkpoint: Optional[str] = None,
-cfg_eval_infer: DictConfig = defaults.eval.infer,
+cfg_eval_infer: Union[Dict, DictConfig] = defaults.eval.infer,
rollout_dir: Optional[str] = defaults.eval.rollout_dir,
n_rollout_steps: int = defaults.eval.n_rollout_steps,
seed: int = defaults.main.seed,
@@ -342,6 +342,9 @@
params is not None or load_checkpoint is not None
), "Either params or a load_checkpoint directory must be provided for inference."

+if isinstance(cfg_eval_infer, Dict):
+    cfg_eval_infer = OmegaConf.create(cfg_eval_infer)

n_trajs = cfg_eval_infer.n_trajs
if n_trajs == -1:
n_trajs = data_test.num_samples

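Correspondingly, infer now also accepts a plain dict for cfg_eval_infer. A hedged sketch of a call: the leading model/case/dataset arguments and the checkpoint path are assumptions (only the keyword arguments appear in this diff), and n_trajs=-1 falls back to the full test split as in the hunk above.

# Sketch: model, case and data_test are assumed to come from the usual setup code.
infer(
    model,
    case,
    data_test,
    load_checkpoint="ckp/best",        # illustrative path
    cfg_eval_infer={"n_trajs": -1},    # plain dict; -1 falls back to data_test.num_samples
    n_rollout_steps=20,
)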
6 changes: 4 additions & 2 deletions lagrangebench/runner.py
@@ -2,7 +2,7 @@
import os.path as osp
from argparse import Namespace
from datetime import datetime
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, Optional, Tuple, Type, Union

import haiku as hk
import jax
@@ -23,7 +23,9 @@
from lagrangebench.utils import NodeType


-def train_or_infer(cfg: DictConfig):
+def train_or_infer(cfg: Union[Dict, DictConfig]):
+    if isinstance(cfg, Dict):
+        cfg = OmegaConf.create(cfg)

# sanity check on the passed configs
check_cfg(cfg)

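The entry point follows the same pattern: train_or_infer accepts a nested plain dict and converts it with OmegaConf.create before check_cfg runs. A sketch under the assumption that the config uses the usual main/dataset groups; the key names and dataset path are illustrative, and a real run needs the full config.

cfg = {
    "main": {"seed": 7},
    "dataset": {"src": "datasets/2D_TGV_2500_10kevery100"},  # illustrative path
}
train_or_infer(cfg)  # dict -> DictConfig via OmegaConf.create, then validated by check_cfg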
36 changes: 16 additions & 20 deletions lagrangebench/train/trainer.py
@@ -1,9 +1,9 @@
"""Training utils and functions."""

import os
-from dataclasses import dataclass
+from collections import namedtuple
from functools import partial
-from typing import Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple, Union

import haiku as hk
import jax
@@ -32,18 +32,6 @@
from .strats import push_forward_build, push_forward_sample_steps


-@dataclass(frozen=True)
-class LossConfig:
-    """Weights for the different targets in the loss function."""
-
-    pos: float = 0.0
-    vel: float = 0.0
-    acc: float = 1.0
-
-    def __getitem__(self, key):
-        return getattr(self, key)


@partial(jax.jit, static_argnames=["model_fn", "loss_weight"])
def _mse(
params: hk.Params,
@@ -63,7 +51,8 @@ def _mse(
# loss components
losses = []
for t in pred:
-losses.append((loss_weight[t] * (pred[t] - target[t]) ** 2).sum(axis=-1))
+w = getattr(loss_weight, t)
+losses.append((w * (pred[t] - target[t]) ** 2).sum(axis=-1))
total_loss = jnp.array(losses).sum(0)
total_loss = jnp.where(non_kinematic_mask, total_loss, 0)
total_loss = total_loss.sum() / num_non_kinematic
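Since the weight container is now a namedtuple rather than a mapping, the per-target weight is read with getattr instead of __getitem__. A small numeric sketch of the weighted sum; the weights follow the old defaults (pos=0.0, vel=0.0, acc=1.0) from the removed dataclass, and the array shapes are illustrative.

import jax.numpy as jnp
from collections import namedtuple

loss_weight = namedtuple("loss_weight", ["pos", "vel", "acc"])(pos=0.0, vel=0.0, acc=1.0)
pred   = {"acc": jnp.array([[0.1, 0.2]]), "vel": jnp.array([[1.0, 1.0]])}
target = {"acc": jnp.array([[0.0, 0.0]]), "vel": jnp.array([[0.5, 0.5]])}

losses = []
for t in pred:
    w = getattr(loss_weight, t)  # 1.0 for "acc", 0.0 for "vel"
    losses.append((w * (pred[t] - target[t]) ** 2).sum(axis=-1))
print(jnp.array(losses).sum(0))  # [0.05] -- only the acc term contributes here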
@@ -107,9 +96,9 @@ def __init__(
case,
data_train: H5Dataset,
data_valid: H5Dataset,
-cfg_train: DictConfig = defaults.train,
-cfg_eval: DictConfig = defaults.eval,
-cfg_logging: DictConfig = defaults.logging,
+cfg_train: Union[Dict, DictConfig] = defaults.train,
+cfg_eval: Union[Dict, DictConfig] = defaults.eval,
+cfg_logging: Union[Dict, DictConfig] = defaults.logging,
input_seq_length: int = defaults.model.input_seq_length,
seed: int = defaults.main.seed,
**kwargs,
@@ -152,6 +141,13 @@ def __init__(
f"({cfg_eval.train.n_trajs} > {data_valid.num_samples})"
)

+if isinstance(cfg_train, Dict):
+    cfg_train = OmegaConf.create(cfg_train)

+if isinstance(cfg_eval, Dict):
+    cfg_eval = OmegaConf.create(cfg_eval)

+if isinstance(cfg_logging, Dict):
+    cfg_logging = OmegaConf.create(cfg_logging)

self.model = model
self.case = case
self.input_seq_length = input_seq_length
@@ -168,8 +164,8 @@ def __init__(
self.wandb_config["eval"]["train"]["n_trajs"] = self.cfg_eval.train.n_trajs

# make immutable for jitting
-# TODO look for simpler alternatives to LossConfig
-self.loss_weight = LossConfig(**dict(self.cfg_train.loss_weight))
+loss_weight = self.cfg_train.loss_weight
+self.loss_weight = namedtuple("loss_weight", loss_weight)(**loss_weight)

self.base_key, seed_worker, generator = set_seed(seed)

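The namedtuple plays the same role the frozen dataclass did: an immutable, hashable weight container that can pass through the jitted _mse as a static argument (loss_weight is listed in static_argnames above). A sketch of the construction from a plain mapping, mirroring the line added here; the concrete weights are just the old defaults.

from collections import namedtuple

loss_weight_cfg = {"pos": 0.0, "vel": 0.0, "acc": 1.0}  # e.g. the cfg.train.loss_weight entries
loss_weight = namedtuple("loss_weight", loss_weight_cfg)(**loss_weight_cfg)

loss_weight.acc    # 1.0, attribute access as used in _mse
hash(loss_weight)  # hashable, unlike a plain dict, so it is safe as a jit static argument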
