
Commit

update readme and requirements
arturtoshev committed Feb 24, 2024
1 parent f1747f6 commit e00711d
Showing 3 changed files with 19 additions and 12 deletions.
27 changes: 15 additions & 12 deletions README.md
@@ -74,7 +74,7 @@ pip install --upgrade jax[cuda12_pip]==0.4.20 -f https://storage.googleapis.com/
### MacOS
Currently, only the CPU installation works. You will need to change a few small things to get it going:
- Clone installation: in `pyproject.toml` change the torch version from `2.1.0+cpu` to `2.1.0`. Then, remove the `poetry.lock` file and run `poetry install --only main`.
- Configs: You will need to set `f32: True` and `num_workers: 0` in the `configs/` files.
- Configs: You will need to set `dtype=float32` and `train.num_workers=0`.
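
A minimal sketch of passing these two settings as command-line overrides, assuming the same `key=value` syntax used in the run examples below (they should equally be settable directly in the YAML files under [`configs/`](configs/)):
```
python main.py config=configs/rpf_2d/gns.yaml dtype=float32 train.num_workers=0
```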

Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-metal/) supports JAX in general, a padding-related feature that `jax-md` relies on seems to be missing; see [this issue](https://github.com/google/jax/issues/16366#issuecomment-1591085071).

@@ -85,37 +85,37 @@ A general tutorial is provided in the example notebook "Training GNS on the 2D T
### Running in a local clone (`main.py`)
Alternatively, experiments can be set up with `main.py`, based on extensive YAML config files and CLI arguments (check [`configs/`](configs/)). By default, the argument priority is: 1) passed CLI arguments, 2) YAML config, and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).
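
As an illustration of this priority order, a CLI override wins over the value stored in the YAML file; the `train.batch_size` key below is only a hypothetical example of a nested config entry:
```
# the CLI value takes precedence over whatever the YAML file specifies
python main.py config=configs/rpf_2d/gns.yaml train.batch_size=2
```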

When loading a saved model with `--load_ckp` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`experiments/`](experiments/) directory and the [`run.py`](experiments/run.py) file.
When loading a saved model with `load_ckp`, the checkpoint config is automatically loaded and training is restarted. For more details, check the [`runner.py`](lagrangebench/runner.py) file.

**Train**

For example, to start a _GNS_ run from scratch on the RPF 2D dataset use
```
python main.py --config configs/rpf_2d/gns.yaml
python main.py config=configs/rpf_2d/gns.yaml
```
Some model presets can be found in `./configs/`.

If `--mode=all`, then training (`--mode=train`) and subsequent inference (`--mode=infer`) on the test split will be run in one go.
If `mode=all` is provided, then training (`mode=train`) and subsequent inference (`mode=infer`) on the test split will be run in one go.
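
For example, to train and then immediately evaluate on the test split in a single call (a sketch combining the options shown above):
```
python main.py config=configs/rpf_2d/gns.yaml mode=all
```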


**Restart training**

To restart training from the last checkpoint in `--load_ckp` use
To restart training from the last checkpoint in `load_ckp` use
```
python main.py --load_ckp ckp/gns_rpf2d_yyyymmdd-hhmmss
python main.py load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss
```

**Inference**

To evaluate a trained model from `--load_ckp` on the test split (`--test`) use
To evaluate a trained model from `load_ckp` on the test split (`test=True`) use
```
python main.py --load_ckp ckp/gns_rpf2d_yyyymmdd-hhmmss/best --rollout_dir rollout/gns_rpf2d_yyyymmdd-hhmmss/best --mode infer --test
python main.py load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss/best rollout_dir=rollout/gns_rpf2d_yyyymmdd-hhmmss/best mode=infer test=True
```

If the default `--out_type_infer=pkl` is active, then the generated trajectories and a `metricsYYYY_MM_DD_HH_MM_SS.pkl` file will be written to the `--rollout_dir`. The metrics file contains all `--metrics_infer` properties for each generated rollout.
If the default `eval.infer.out_type=pkl` is active, then the generated trajectories and a `metricsYYYY_MM_DD_HH_MM_SS.pkl` file will be written to `eval.rollout_dir`. The metrics file contains all `eval.infer.metrics` properties for each generated rollout.
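
A minimal sketch of inspecting such a metrics file afterwards, assuming it is a plain pickle and that the rollout directory follows the naming used above (the exact structure of the stored object is not specified here):
```
import pickle
from pathlib import Path

# hypothetical rollout directory; replace with your actual eval.rollout_dir
rollout_dir = Path("rollout/gns_rpf2d_yyyymmdd-hhmmss/best")

# pick the most recent metrics file written by the inference run
metrics_file = sorted(rollout_dir.glob("metrics*.pkl"))[-1]
with open(metrics_file, "rb") as f:
    metrics = pickle.load(f)

# inspect what was stored, e.g. the metrics per generated rollout
print(type(metrics))
if isinstance(metrics, dict):
    print(list(metrics.keys()))
```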

## Datasets
The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https://zenodo.org/doi/10.5281/zenodo.10021925). When creating a new dataset instance, the data is automatically downloaded. Alternatively, to manually download them use the `download_data.sh` shell script, either with a specific dataset name or "all". Namely
The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https://zenodo.org/doi/10.5281/zenodo.10021925). If a dataset is not found in `dataset_path`, the data is automatically downloaded. Alternatively, to manually download the datasets, use the `download_data.sh` shell script, either with a specific dataset name or "all". Namely:
- __Taylor Green Vortex 2D__: `bash download_data.sh tgv_2d datasets/`
- __Reverse Poiseuille Flow 2D__: `bash download_data.sh rpf_2d datasets/`
- __Lid Driven Cavity 2D__: `bash download_data.sh ldc_2d datasets/`
@@ -144,7 +144,8 @@ We provide three notebooks that show LagrangeBench functionalities, namely:
┃ ┗ 📜utils.py
┣ 📂evaluate # Evaluation and rollout generation tools
┃ ┣ 📜metrics.py
┃ ┗ 📜rollout.py
┃ ┣ 📜rollout.py
┃ ┗ 📜utils.py
┣ 📂models # Baseline models
┃ ┣ 📜base.py # BaseModel class
┃ ┣ 📜egnn.py
@@ -157,6 +158,7 @@ We provide three notebooks that show LagrangeBench functionalities, namely:
┃ ┣ 📜strats.py # Training tricks
┃ ┗ 📜trainer.py # Trainer method
┣ 📜defaults.py # Default values
┣ 📜runner.py # Runner wrapping training and inference
┗ 📜utils.py
```

@@ -195,7 +197,7 @@ pytest

### Clone vs Library
LagrangeBench can be installed by cloning the repository or as a standalone library. Cloning offers more flexibility, but it comes with a disadvantage: some things have to be implemented twice. If you change any of the following things, make sure to also update its counterpart:
- General setup in `experiments/` and `notebooks/tutorial.ipynb`
- General setup in `lagrangebench/runner.py` and `notebooks/tutorial.ipynb`
- Configs in `configs/` and `lagrangebench/defaults.py`
- Zenodo URLs in `download_data.sh` and `lagrangebench/data/data.py`
- Dependencies in `pyproject.toml`, `requirements_cuda.txt`, and `docs/requirements.txt`
@@ -232,3 +234,4 @@ The associated datasets can be cited as:
The following further publications are based on the LagrangeBench codebase:

1. [Learning Lagrangian Fluid Mechanics with E(3)-Equivariant Graph Neural Networks (GSI 2023)](https://arxiv.org/abs/2305.15603), A. P. Toshev, G. Galletti, J. Brandstetter, S. Adami, N. A. Adams
2. [Neural SPH: Improved Neural Modeling of Lagrangian Fluid Dynamics](https://arxiv.org/abs/2402.06275), A. P. Toshev, J. A. Erbesdobler, N. A. Adams, J. Brandstetter
2 changes: 2 additions & 0 deletions docs/requirements.txt
@@ -9,12 +9,14 @@ jax_md>=0.2.8
jmp>=0.0.4
jraph>=0.0.6.dev0
matscipy>=0.8.0
omegaconf>=2.3.0
optax>=0.1.7
ott-jax>=0.4.2
pyvista
PyYAML
sphinx==7.2.6
sphinx-rtd-theme==1.3.0
toml>=0.10.2
torch==2.1.0+cpu
wandb
wget
2 changes: 2 additions & 0 deletions requirements_cuda.txt
@@ -11,10 +11,12 @@ jax_md>=0.2.8
jmp>=0.0.4
jraph>=0.0.6.dev0
matscipy>=0.8.0
omegaconf>=2.3.0
optax>=0.1.7
ott-jax>=0.4.2
pyvista
PyYAML
toml>=0.10.2
torch==2.1.0+cpu
wandb
wget
