Merge pull request #31 from tumaer/cfg_dataset
Cfg dataset; better assertions and comments
arturtoshev authored Jul 1, 2024
2 parents 6428cff + 9f1f397 commit 80d549a
Showing 20 changed files with 68 additions and 38 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -71,6 +71,9 @@ To run JAX on GPU, follow [Installing JAX](https://jax.readthedocs.io/en/latest/
pip install -U "jax[cuda12]==0.4.29"
```

+> Note: as of 27.06.2024, to make our GNN models **deterministic** on GPUs, you need to set `os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"`. However, all current models rely on `scatter_sum`, and in deterministic mode this operation seems to be slower than a plain Python for-loop; see [#17844](https://github.com/google/jax/issues/17844) and [#10674](https://github.com/google/jax/discussions/10674).
+
### MacOS
Currently, only the CPU installation works. You will need to change a few small things to get it going:
- Clone installation: in `pyproject.toml` change the torch version from `2.1.0+cpu` to `2.1.0`. Then, remove the `poetry.lock` file and run `poetry install --only main`.
@@ -121,7 +124,7 @@ We provide three notebooks that show LagrangeBench functionalities, namely:
- [`gns_data.ipynb`](notebooks/gns_data.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/gns_data.ipynb), showing how to train models within LagrangeBench on the datasets from the paper [Learning to Simulate Complex Physics with Graph Networks](https://arxiv.org/abs/2002.09405).

## Datasets
-The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https://zenodo.org/doi/10.5281/zenodo.10021925). If a dataset is not found in `dataset_path`, the data is automatically downloaded. Alternatively, to manually download the datasets use the `download_data.sh` shell script, either with a specific dataset name or "all". Namely
+The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https://zenodo.org/doi/10.5281/zenodo.10021925). If a dataset is not found in `dataset.src`, the data is automatically downloaded. Alternatively, to manually download the datasets use the `download_data.sh` shell script, either with a specific dataset name or "all". Namely
- __Taylor Green Vortex 2D__: `bash download_data.sh tgv_2d datasets/`
- __Reverse Poiseuille Flow 2D__: `bash download_data.sh rpf_2d datasets/`
- __Lid Driven Cavity 2D__: `bash download_data.sh ldc_2d datasets/`
4 changes: 2 additions & 2 deletions configs/WaterDrop_2d/gns.yaml
@@ -1,7 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS

-main:
-  dataset_path: /tmp/datasets/WaterDrop
+dataset:
+  src: /tmp/datasets/WaterDrop

model:
name: gns
3 changes: 2 additions & 1 deletion configs/dam_2d/base.yaml
@@ -1,6 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS

-dataset_path: datasets/2D_DAM_5740_20kevery100
+dataset:
+  src: datasets/2D_DAM_5740_20kevery100

logging:
wandb_project: dam_2d
3 changes: 2 additions & 1 deletion configs/ldc_2d/base.yaml
@@ -1,6 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS

-dataset_path: datasets/2D_LDC_2708_10kevery100
+dataset:
+  src: datasets/2D_LDC_2708_10kevery100

logging:
wandb_project: ldc_2d
3 changes: 2 additions & 1 deletion configs/ldc_3d/base.yaml
@@ -1,6 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS

-dataset_path: datasets/3D_LDC_8160_10kevery100
+dataset:
+  src: datasets/3D_LDC_8160_10kevery100

logging:
wandb_project: ldc_3d
3 changes: 2 additions & 1 deletion configs/rpf_2d/base.yaml
@@ -1,6 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS

-dataset_path: datasets/2D_RPF_3200_20kevery100
+dataset:
+  src: datasets/2D_RPF_3200_20kevery100

logging:
wandb_project: rpf_2d
3 changes: 2 additions & 1 deletion configs/rpf_3d/base.yaml
@@ -1,6 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS

-dataset_path: datasets/3D_RPF_8000_10kevery100
+dataset:
+  src: datasets/3D_RPF_8000_10kevery100

logging:
wandb_project: rpf_3d
3 changes: 2 additions & 1 deletion configs/tgv_2d/base.yaml
@@ -1,6 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS

-dataset_path: datasets/2D_TGV_2500_10kevery100
+dataset:
+  src: datasets/2D_TGV_2500_10kevery100

logging:
wandb_project: tgv_2d
3 changes: 2 additions & 1 deletion configs/tgv_3d/base.yaml
@@ -1,6 +1,7 @@
extends: LAGRANGEBENCH_DEFAULTS

-dataset_path: datasets/3D_TGV_8000_10kevery100
+dataset:
+  src: datasets/3D_TGV_8000_10kevery100

logging:
wandb_project: tgv_3d
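
The config hunks above all make the same move: the flat `dataset_path` key becomes `src`, nested under a new `dataset` block. For configs kept outside this repository, a minimal sketch of that migration on a plain dictionary could look like the following. The helper name `migrate_dataset_cfg` is hypothetical and not part of LagrangeBench, and the sketch assumes `dataset_path` sits at the top level of the config.

```python
from copy import deepcopy


def migrate_dataset_cfg(cfg: dict) -> dict:
    """Return a copy of cfg with a legacy top-level `dataset_path` moved to `dataset.src`.

    Hypothetical helper for illustration only; it is not part of LagrangeBench.
    """
    new_cfg = deepcopy(cfg)
    if "dataset_path" in new_cfg:
        legacy_path = new_cfg.pop("dataset_path")
        dataset = new_cfg.setdefault("dataset", {})
        # If `dataset.src` is already present, keep it and drop the legacy value.
        dataset.setdefault("src", legacy_path)
    return new_cfg


old_cfg = {
    "extends": "LAGRANGEBENCH_DEFAULTS",
    "dataset_path": "datasets/2D_TGV_2500_10kevery100",
    "logging": {"wandb_project": "tgv_2d"},
}
new_cfg = migrate_dataset_cfg(old_cfg)
print(new_cfg["dataset"])
```

The input dictionary is left untouched, so the old and new layouts can be compared side by side before a config file is rewritten.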
12 changes: 4 additions & 8 deletions docs/index.rst
@@ -3,18 +3,14 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-.. ![rpf2d.gif](https://s11.gifyu.com/images/Sce92.gif)
-.. ![rpf3d.gif](https://s11.gifyu.com/images/Sce3X.gif)
-.. <img src="https://s11.gifyu.com/images/Sce92.gif" width="40" height="40" />
LagrangeBench
=============

-.. image:: https://s11.gifyu.com/images/Sce92.gif
-   :alt: Funny GIF
+.. image:: https://drive.google.com/thumbnail?id=1rP0pf1KL8iGbly0tA0qthUE_tMDv_9Jp&sz=w1000
+   :alt: rpf2d.gif

-.. image:: https://s11.gifyu.com/images/Sce3X.gif
-   :alt: Funny GIF2
+.. image:: https://drive.google.com/thumbnail?id=1BMGkHj9EYMGUOdsE5QwiJWCTvDNqveHc&sz=w1000
+   :alt: rpf3d.gif


What is ``LagrangeBench``?
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -4,8 +4,8 @@ cloudpickle
dm_haiku>=0.0.10
e3nn_jax==0.20.3
h5py
-jax[cpu]==0.4.29
jax-sph>=0.0.3
+jax[cpu]==0.4.29
jmp>=0.0.4
jraph>=0.0.6.dev0
matscipy>=0.8.0
8 changes: 7 additions & 1 deletion lagrangebench/data/data.py
@@ -53,7 +53,8 @@ def __init__(
Args:
split: "train", "valid", or "test"
-dataset_path: Path to the dataset
+dataset_path: Path to the dataset. Download will start automatically if
+    dataset_path does not exist.
name: Name of the dataset. If None, it is inferred from the path.
input_seq_length: Length of the input sequence. The number of historic
velocities is input_seq_length - 1. And during training, the returned
@@ -92,6 +93,11 @@ def __init__(

self.external_force_fn = force_module.force_fn
else:
+if self.name in ["dam2d", "rpf2d", "rpf3d"]:
+    raise FileNotFoundError(
+        f"External force function not found in {dataset_path}. "
+        "Download the latest LagrangeBench dataset from Zenodo."
+    )
self.external_force_fn = None

# load dataset metadata
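
The new guard in `data.py` turns a silently missing force function into an early error for the datasets that are driven by an external force. A standalone sketch of the same decision logic is shown below; `resolve_external_force_fn` and its `force_fn` argument are hypothetical names used only for illustration, while the dataset names and the error message are copied from the hunk above.

```python
from typing import Callable, Optional

# Datasets that require an external force, as listed in the diff above.
FORCED_DATASETS = ("dam2d", "rpf2d", "rpf3d")


def resolve_external_force_fn(
    name: str, dataset_path: str, force_fn: Optional[Callable] = None
) -> Optional[Callable]:
    """Return the force function, or fail early if a forced dataset lacks one.

    Hypothetical helper; `force_fn` stands in for whatever was loaded from the
    dataset directory (None when nothing was found there).
    """
    if force_fn is not None:
        return force_fn
    if name in FORCED_DATASETS:
        raise FileNotFoundError(
            f"External force function not found in {dataset_path}. "
            "Download the latest LagrangeBench dataset from Zenodo."
        )
    # Datasets outside the list simply run without an external force.
    return None
```

Failing at dataset construction, rather than later during a rollout, keeps the error close to its cause: an outdated local copy of the dataset.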
12 changes: 9 additions & 3 deletions lagrangebench/defaults.py
@@ -16,8 +16,6 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
cfg.load_ckp = None
# One of "train", "infer" or "all" (= both)
cfg.mode = "all"
-# path to data directory
-cfg.dataset_path = None
# random seed
cfg.seed = 0
# data type for preprocessing. One of "float32" or "float64"
@@ -28,6 +26,14 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
# Should be specified before importing the library.
cfg.xla_mem_fraction = None

+### dataset
+cfg.dataset = OmegaConf.create({})
+
+# path to data directory
+cfg.dataset.src = None
+# dataset name
+cfg.dataset.name = None
+
### model
cfg.model = OmegaConf.create({})

@@ -178,7 +184,7 @@ def check_cfg(cfg: DictConfig):

assert cfg.mode in ["train", "infer", "all"]
assert cfg.dtype in ["float32", "float64"]
-assert cfg.dataset_path is not None, "dataset_path must be specified."
+assert cfg.dataset.src is not None, "dataset.src must be specified."

assert cfg.model.input_seq_length >= 2, "At least two positions for one past vel."
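
To see how the new defaults and the updated assertion interact, here is a rough sketch that mirrors them with plain dictionaries instead of OmegaConf. Both function names are hypothetical stand-ins for the `dataset` part of `set_defaults` and `check_cfg`; only the keys `dataset.src` and `dataset.name` and the assertion message come from the diff.

```python
from typing import Optional


def set_dataset_defaults(cfg: Optional[dict] = None) -> dict:
    """Fill in the `dataset` block with the defaults introduced in this commit."""
    cfg = dict(cfg or {})
    dataset = {"src": None, "name": None}  # both default to None, as in defaults.py
    dataset.update(cfg.get("dataset") or {})
    cfg["dataset"] = dataset
    return cfg


def check_dataset_cfg(cfg: dict) -> None:
    """Mirror the updated assertion: a dataset source is mandatory."""
    assert cfg["dataset"]["src"] is not None, "dataset.src must be specified."


cfg = set_dataset_defaults({"dataset": {"src": "datasets/2D_TGV_2500_10kevery100"}})
check_dataset_cfg(cfg)  # passes; `name` stays None and is inferred from the path later
```

Leaving `name` as `None` is valid because, per the `H5Dataset` docstring, the name is then inferred from the path.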

9 changes: 6 additions & 3 deletions lagrangebench/runner.py
@@ -1,6 +1,5 @@
import os
import os.path as osp
-from argparse import Namespace
from datetime import datetime
from typing import Callable, Dict, Optional, Tuple, Type, Union

@@ -144,8 +143,9 @@ def train_or_infer(cfg: Union[Dict, DictConfig]):
return 0


-def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, Namespace]:
-dataset_path = cfg.dataset_path
+def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, H5Dataset]:
+dataset_path = cfg.dataset.src
+dataset_name = cfg.dataset.name
ckp_dir = cfg.logging.ckp_dir
rollout_dir = cfg.eval.rollout_dir
input_seq_length = cfg.model.input_seq_length
@@ -164,20 +164,23 @@ def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, Namespace]:
data_train = H5Dataset(
"train",
dataset_path=dataset_path,
+name=dataset_name,
input_seq_length=input_seq_length,
extra_seq_length=cfg.train.pushforward.unrolls[-1],
nl_backend=nl_backend,
)
data_valid = H5Dataset(
"valid",
dataset_path=dataset_path,
+name=dataset_name,
input_seq_length=input_seq_length,
extra_seq_length=n_rollout_steps,
nl_backend=nl_backend,
)
data_test = H5Dataset(
"test",
dataset_path=dataset_path,
+name=dataset_name,
input_seq_length=input_seq_length,
extra_seq_length=n_rollout_steps,
nl_backend=nl_backend,
4 changes: 3 additions & 1 deletion lagrangebench/train/trainer.py
@@ -253,7 +253,9 @@ def train(
update_fn = partial(_update, loss_fn=loss_fn, opt_update=self.opt_update)

# init values
-pos_input_and_target, particle_type = next(iter(loader_train))
+raw_batch = next(iter(loader_train))
+raw_batch = jax.tree_map(lambda x: jnp.array(x), raw_batch) # numpy to jax
+pos_input_and_target, particle_type = raw_batch
raw_sample = (pos_input_and_target[0], particle_type[0])
key, features, _, neighbors = case.allocate(self.base_key, raw_sample)
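
The trainer change converts the first batch from NumPy to JAX arrays before allocating the neighbor list. A small self-contained sketch of that conversion is given below. The batch shapes are made up for illustration, and the sketch uses `jax.tree_util.tree_map` with `jnp.asarray` rather than the `jax.tree_map(lambda x: jnp.array(x), ...)` call in the diff, since newer JAX releases deprecate the top-level `jax.tree_map` alias.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for one batch from the data loader: (positions, particle types).
# The shapes are hypothetical: 2 samples, 5 particles, 4 time steps, 3 dimensions.
raw_batch = (
    np.zeros((2, 5, 4, 3), dtype=np.float32),
    np.zeros((2, 5), dtype=np.int32),
)

# Convert every leaf of the (possibly nested) batch from NumPy to a JAX array.
raw_batch = jax.tree_util.tree_map(jnp.asarray, raw_batch)

pos_input_and_target, particle_type = raw_batch
# Take the first sample of the batch, as the trainer does before `case.allocate`.
raw_sample = (pos_input_and_target[0], particle_type[0])
```

Mapping over the whole tree keeps the conversion independent of how many arrays a batch contains.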

3 changes: 3 additions & 0 deletions main.py
@@ -63,6 +63,9 @@ def load_embedded_configs(config_path: str, cli_args: DictConfig) -> DictConfig:
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cli_args.xla_mem_fraction)

+# The following line makes the code deterministic on GPUs, but also extremely slow.
+# os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

cfg = load_embedded_configs(config_path, cli_args)

print("#" * 79, "\nStarting a LagrangeBench run with the following configs:")
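
The commented-out line in `main.py` assigns `XLA_FLAGS` directly, which would overwrite any flags already set in the environment. A minimal sketch of a more conservative variant that appends the determinism flag instead is shown below. The helper `enable_deterministic_gpu_ops` is a hypothetical name, not something the repository provides, and the sketch assumes it runs before `jax` is imported so that the flag is in place when XLA starts up.

```python
import os

DETERMINISTIC_FLAG = "--xla_gpu_deterministic_ops=true"


def enable_deterministic_gpu_ops(env=os.environ) -> str:
    """Append the determinism flag to XLA_FLAGS, keeping flags that are already set.

    Hypothetical helper for illustration; `env` is any mutable mapping and
    defaults to the process environment.
    """
    flags = env.get("XLA_FLAGS", "").split()
    if DETERMINISTIC_FLAG not in flags:
        flags.append(DETERMINISTIC_FLAG)
    env["XLA_FLAGS"] = " ".join(flags)
    return env["XLA_FLAGS"]


# Call this before the first `import jax` to be safe.
enable_deterministic_gpu_ops()
```

As the README note in this commit warns, deterministic mode can be very slow for models that rely on `scatter_sum`, so this is better suited to debugging and reproducibility checks than to regular training runs.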
14 changes: 7 additions & 7 deletions notebooks/datasets.ipynb
@@ -65,22 +65,22 @@
"\n",
"- Reverse Poiseuille Flow\n",
"\n",
-" ![rpf2d.gif](https://s11.gifyu.com/images/Sce92.gif)\n",
-" ![rpf3d.gif](https://s11.gifyu.com/images/Sce3X.gif)\n",
+" <img src=\"https://drive.google.com/thumbnail?id=1rP0pf1KL8iGbly0tA0qthUE_tMDv_9Jp&sz=w1000\" alt=\"rpf2d.gif\">\n",
+" <img src=\"https://drive.google.com/thumbnail?id=1BMGkHj9EYMGUOdsE5QwiJWCTvDNqveHc&sz=w1000\" alt=\"rpf3d.gif\">\n",
"\n",
"- Taylor Green Vortex\n",
"\n",
-" ![tgv2d.gif](https://s11.gifyu.com/images/Sce9b.gif)\n",
-" ![tgv3d.gif](https://s11.gifyu.com/images/Sce9z.gif)\n",
+" <img src=\"https://drive.google.com/thumbnail?id=1VmEhgtVJBhSGxSrIQUcU8VANFB3ztT5C&sz=w1000\" alt=\"tgv2d.gif\">\n",
+" <img src=\"https://drive.google.com/thumbnail?id=1JnbfgqQGs8WIkvDyEPstkpsB7vNEETyr&sz=w1000\" alt=\"tgv3d.gif\">\n",
"\n",
"- Lid-Driven Cavity\n",
"\n",
-" ![ldc2d.gif](https://s11.gifyu.com/images/Sce9S.gif)\n",
-" ![ldc3d.gif](https://s11.gifyu.com/images/Sce3e.gif)\n",
+" <img src=\"https://drive.google.com/thumbnail?id=1Me-8A4wCN_NCoP6w1nteu5TDTROCbWdk&sz=w1000\" alt=\"ldc2d.gif\">\n",
+" <img src=\"https://drive.google.com/thumbnail?id=1302IqVQOxkewHuNxZywtvC8ApClBKYc7&sz=w1000\" alt=\"ldc3d.gif\">\n",
"\n",
"- Dam Break\n",
"\n",
-" ![dam2d.gif](https://s11.gifyu.com/images/SceKB.gif)\n"
+" <img src=\"https://drive.google.com/thumbnail?id=1ccDuHQsJYM-rwCgzhPrv-usoLM9K5NfI&sz=w1000\" alt=\"dam2d.gif\">"
]
},
{
2 changes: 1 addition & 1 deletion requirements_cuda.txt
@@ -4,8 +4,8 @@ cloudpickle
dm_haiku>=0.0.10
e3nn_jax==0.20.3
h5py
-jax[cuda12]==0.4.29
jax-sph>=0.0.3
+jax[cuda12]==0.4.29
jmp>=0.0.4
jraph>=0.0.6.dev0
matscipy>=0.8.0
6 changes: 4 additions & 2 deletions tests/rollout_test.py
@@ -27,7 +27,9 @@ class TestInferBuilder(unittest.TestCase):
def setUp(self):
self.cfg = OmegaConf.create(
{
-"dataset_path": "tests/3D_LJ_3_1214every1", # Lennard-Jones dataset
+"dataset": {
+    "src": "tests/3D_LJ_3_1214every1", # Lennard-Jones dataset
+},
"model": {
"input_seq_length": 3, # two past velocities
"isotropic_norm": False,
@@ -42,7 +44,7 @@ def setUp(self):

data_valid = H5Dataset(
split="valid",
-dataset_path=self.cfg.dataset_path,
+dataset_path=self.cfg.dataset.src,
name="lj3d",
input_seq_length=self.cfg.model.input_seq_length,
extra_seq_length=self.cfg.eval.n_rollout_steps,
4 changes: 3 additions & 1 deletion tests/runner_test.py
@@ -15,7 +15,9 @@ def setUp(self):
self.cfg = OmegaConf.create(
{
"mode": "all",
-"dataset_path": "tests/3D_LJ_3_1214every1",
+"dataset": {
+    "src": "tests/3D_LJ_3_1214every1",
+},
"model": {
"name": "linear",
"input_seq_length": 3,