Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unit tests #21

Merged
merged 17 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Publish to PyPI

on:
release:
types: [created]

jobs:
build-n-publish:
name: Build and publish to PyPI
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install Poetry
run: |
python -m pip install --upgrade pip
pip install poetry
poetry config virtualenvs.in-project true
- name: Install dependencies
run: |
poetry install
- name: Build and publish
env:
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }}
run: |
poetry version $(git describe --tags --abbrev=0)
poetry add $(cat requirements.txt)
poetry build
poetry publish
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: Ruff
on: [push, pull_request]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
36 changes: 36 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Tests

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

permissions:
contents: read

jobs:
tests:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Poetry
run: |
python -m pip install --upgrade pip
pip install poetry
poetry config virtualenvs.in-project true
- name: Install dependencies
run: |
poetry install
- name: Run pytest
run: |
.venv/bin/pytest
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ venv*/
rollouts
profile
dist
.coverage

# Sphinx documentation
docs/_build/
16 changes: 4 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,9 @@ repos:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: requirements-txt-fixer
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: [ --profile, black ]
- repo: https://github.com/ambv/black
rev: 23.3.0
hooks:
- id: black
args: ['--config=./pyproject.toml']
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.265'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.1.8'
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
Comment on lines +28 to +29
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not automatically hard fix things as a default

10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ pip install --upgrade jax[cuda12_pip]==0.4.20 -f https://storage.googleapis.com/
```

### MacOS
Currently, only the CPU installation works. You will need to change a few small things to get it going:
Currently, only the CPU installation works. You will need to change a few small things to get it going:
- Clone installation: in `pyproject.toml` change the torch version from `2.1.0+cpu` to `2.1.0`. Then, remove the `poetry.lock` file and run `poetry install --only main`.
- Configs: You will need to set `f64: False` and `num_workers: 0` in the `configs/` files.

Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-metal/) supports jax in general, there seems to be a missing feature used by `jax-md` related to padding -> see [this issue](https://github.com/google/jax/issues/16366#issuecomment-1591085071).

## Usage
### Standalone benchmark library
A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating it's performance.
A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating its performance.

### Running in a local clone (`main.py`)
Alternatively, experiments can also be set up with `main.py`, based around extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as: 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).
Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as: 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).

When loading a saved model with `--model_dir` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`experiments/`](experiments/) directory and the [`run.py`](experiments/run.py) file.

Expand Down Expand Up @@ -94,8 +94,8 @@ The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https


### Notebooks
Whe provide three notebooks that show LagrangeBench functionalities, namely:
- [`tutorial.ipynb`](notebooks/tutorial.ipynb) with a general overview of LagrangeBench library, with trainin and evaluation of a simple GNS model,
We provide three notebooks that show LagrangeBench functionalities, namely:
- [`tutorial.ipynb`](notebooks/tutorial.ipynb) with a general overview of LagrangeBench library, with training and evaluation of a simple GNS model,
- [`datasets.ipynb`](notebooks/datasets.ipynb) with more details and visualizations on the datasets, and
- [`gns_data.ipynb`](notebooks/gns_data.ipynb) showing how to train models within LagrangeBench on the datasets from the paper [Learning to Simulate Complex Physics with Graph Networks](https://arxiv.org/abs/2002.09405).

Expand Down
2 changes: 2 additions & 0 deletions configs/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,5 @@ metrics_infer:
metrics_stride_infer: 1
out_type_infer: pkl
eval_n_trajs_infer: -1
# batch size for validation/testing
batch_size_infer: 2
4 changes: 2 additions & 2 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

cloudpickle
dm_haiku>=0.0.10
e3nn_jax>=0.20.0
e3nn_jax==0.20.3
h5py
jax[cpu]==0.4.20
jax_md>=0.2.8
Expand All @@ -15,6 +15,6 @@ pyvista
PyYAML
sphinx==7.2.6
sphinx-rtd-theme==1.3.0
torch>=2.1.0+cpu
torch==2.1.0+cpu
wandb
wget
6 changes: 4 additions & 2 deletions experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import jax.numpy as jnp
import jmp
import numpy as np
import wandb
import yaml

import wandb
from experiments.utils import setup_data, setup_model
from lagrangebench import Trainer, infer
from lagrangebench.case_setup import case_builder
Expand Down Expand Up @@ -123,6 +123,7 @@ def train_or_infer(args: Namespace):
eval_steps=args.config.eval_steps,
metrics_stride=args.config.metrics_stride,
num_workers=args.config.num_workers,
batch_size_infer=args.config.batch_size_infer,
)
_, _, _ = trainer(
step_max=args.config.step_max,
Expand Down Expand Up @@ -150,7 +151,7 @@ def train_or_infer(args: Namespace):
metrics = infer(
model,
case,
data_test,
data_test if args.config.test else data_valid,
load_checkpoint=args.config.model_dir,
metrics=args.config.metrics_infer,
rollout_dir=args.config.rollout_dir,
Expand All @@ -160,6 +161,7 @@ def train_or_infer(args: Namespace):
n_extrap_steps=args.config.n_extrap_steps,
seed=args.config.seed,
metrics_stride=args.config.metrics_stride_infer,
batch_size=args.config.batch_size_infer,
)

split = "test" if args.config.test else "valid"
Expand Down
10 changes: 4 additions & 6 deletions lagrangebench/case_setup/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
from jax import jit, lax, random, vmap
from jax import Array, jit, lax, vmap
from jax_md import space
from jax_md.dataclasses import dataclass, static_field
from jax_md.partition import NeighborList, NeighborListFormat
Expand All @@ -15,16 +15,14 @@
from .features import FeatureDict, TargetDict, physical_feature_builder
from .partition import neighbor_list

TrainCaseOut = Tuple[random.KeyArray, FeatureDict, TargetDict, NeighborList]
TrainCaseOut = Tuple[Array, FeatureDict, TargetDict, NeighborList]
EvalCaseOut = Tuple[FeatureDict, NeighborList]
SampleIn = Tuple[jnp.ndarray, jnp.ndarray]

AllocateFn = Callable[[random.KeyArray, SampleIn, float, int], TrainCaseOut]
AllocateFn = Callable[[Array, SampleIn, float, int], TrainCaseOut]
AllocateEvalFn = Callable[[SampleIn], EvalCaseOut]

PreprocessFn = Callable[
[random.KeyArray, SampleIn, float, NeighborList, int], TrainCaseOut
]
PreprocessFn = Callable[[Array, SampleIn, float, NeighborList, int], TrainCaseOut]
PreprocessEvalFn = Callable[[SampleIn, NeighborList], EvalCaseOut]

IntegrateFn = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
Expand Down
2 changes: 1 addition & 1 deletion lagrangebench/case_setup/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def scan_body(carry, input):
if not is_sparse(format):
capacity_limit = N - 1 if mask_self else N
elif format is NeighborListFormat.Sparse:
capacity_limit = N * (N - 1) if mask_self else N ** 2
capacity_limit = N * (N - 1) if mask_self else N**2
else:
capacity_limit = N * (N - 1) // 2
if max_occupancy > capacity_limit:
Expand Down
2 changes: 1 addition & 1 deletion lagrangebench/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def get_window(self, idx: int):
def __getitem__(self, idx: int):
"""
Get a sequence of positions (of size windows) from the dataset at index idx.

Returns:
Array of shape (num_particles_max, input_seq_length + 1, dim). Along axis=1
the position sequence (length input_seq_length) and the last position to
Expand Down
1 change: 1 addition & 0 deletions lagrangebench/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class defaults:
out_type: str = "none" # type of output. None means no rollout is stored
n_extrap_steps: int = 0 # number of extrapolation steps
metrics_stride: int = 10 # stride for e_kin and sinkhorn
batch_size_infer: int = 2 # batch size for validation/testing

# logging
log_steps: int = 1000 # number of steps between logs
Expand Down
2 changes: 1 addition & 1 deletion lagrangebench/evaluate/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
metadata: Metadata of the dataset.
loss_ranges: List of horizon lengths to compute the loss for.
input_seq_length: Length of the input sequence.
stride: Rollout subsample frequency for Sinkhorn.
stride: Rollout subsample frequency for e_kin and sinkhorn.
ot_backend: Backend for sinkhorn computation. "ott" or "pot".
"""
if active_metrics is None:
Expand Down
Loading