Skip to content

Commit

Permalink
Merge pull request #23 from tumaer/force_files
Browse files Browse the repository at this point in the history
Force files and some small fixes
  • Loading branch information
arturtoshev authored Jan 11, 2024
2 parents 1979e90 + bd433c2 commit 28c8f9e
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 58 deletions.
12 changes: 12 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
coverage:
range: 50..70 # red color under 50%, yellow at 50%..70%, green over 70%
precision: 1
status:
project:
default:
target: 60% # coverage success only above X%
threshold: 5% # allow the coverage to drop by X% and being a success
patch:
default:
target: 50%
threshold: 5%
2 changes: 1 addition & 1 deletion .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: Ruff
on: [push, pull_request]
on: [pull_request]
jobs:
ruff:
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
Expand Down
29 changes: 1 addition & 28 deletions experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,34 +55,7 @@ def setup_data(args: Namespace) -> Tuple[H5Dataset, H5Dataset, Namespace]:
f"exceeds eval_n_trajs ({args.config.eval_n_trajs})"
)

# TODO: move this to a more suitable place
if "RPF" in args.info.dataset_name.upper():
args.info.has_external_force = True
if data_train.metadata["dim"] == 2:

def external_force_fn(position):
return jnp.where(
position[1] > 1.0,
jnp.array([-1.0, 0.0]),
jnp.array([1.0, 0.0]),
)

elif data_train.metadata["dim"] == 3:

def external_force_fn(position):
return jnp.where(
position[1] > 1.0,
jnp.array([-1.0, 0.0, 0.0]),
jnp.array([1.0, 0.0, 0.0]),
)

else:
args.info.has_external_force = False
external_force_fn = None

data_train.external_force_fn = external_force_fn
data_valid.external_force_fn = external_force_fn
data_test.external_force_fn = external_force_fn
args.info.has_external_force = bool(data_train.external_force_fn is not None)

return data_train, data_valid, data_test, args

Expand Down
9 changes: 9 additions & 0 deletions lagrangebench/case_setup/case.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Case setup functions."""

import warnings
from typing import Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
Expand Down Expand Up @@ -101,6 +102,14 @@ def case_builder(

displacement_fn_set = vmap(displacement_fn, in_axes=(0, 0))

if neighbor_list_multiplier < 1.25:
warnings.warn(
f"neighbor_list_multiplier={neighbor_list_multiplier} < 1.25 is very low. "
"Be especially cautious if you batch training and/or inference as "
"reallocation might be necessary based on different overflow conditions. "
"See https://github.com/tumaer/lagrangebench/pull/20#discussion_r1443811262"
)

neighbor_fn = neighbor_list(
displacement_fn,
jnp.array(box),
Expand Down
46 changes: 19 additions & 27 deletions lagrangebench/data/data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Dataset modules for loading HDF5 simulation trajectories."""

import bisect
import importlib
import json
import os
import os.path as osp
import re
import zipfile
from typing import Callable, Optional
from typing import Optional

import h5py
import jax.numpy as jnp
Expand All @@ -17,13 +18,13 @@
from lagrangebench.utils import NodeType

URLS = {
"tgv2d": "https://zenodo.org/records/10021926/files/2D_TGV_2500_10kevery100.zip",
"rpf2d": "https://zenodo.org/records/10021926/files/2D_RPF_3200_20kevery100.zip",
"ldc2d": "https://zenodo.org/records/10021926/files/2D_LDC_2708_10kevery100.zip",
"dam2d": "https://zenodo.org/records/10021926/files/2D_DAM_5740_20kevery100.zip",
"tgv3d": "https://zenodo.org/records/10021926/files/3D_TGV_8000_10kevery100.zip",
"rpf3d": "https://zenodo.org/records/10021926/files/3D_RPF_8000_10kevery100.zip",
"ldc3d": "https://zenodo.org/records/10021926/files/3D_LDC_8160_10kevery100.zip",
"tgv2d": "https://zenodo.org/records/10491868/files/2D_TGV_2500_10kevery100.zip",
"rpf2d": "https://zenodo.org/records/10491868/files/2D_RPF_3200_20kevery100.zip",
"ldc2d": "https://zenodo.org/records/10491868/files/2D_LDC_2708_10kevery100.zip",
"dam2d": "https://zenodo.org/records/10491868/files/2D_DAM_5740_20kevery100.zip",
"tgv3d": "https://zenodo.org/records/10491868/files/3D_TGV_8000_10kevery100.zip",
"rpf3d": "https://zenodo.org/records/10491868/files/3D_RPF_8000_10kevery100.zip",
"ldc3d": "https://zenodo.org/records/10491868/files/3D_LDC_8160_10kevery100.zip",
}


Expand All @@ -45,7 +46,6 @@ def __init__(
input_seq_length: int = 6,
extra_seq_length: int = 0,
nl_backend: str = "jaxmd_vmap",
external_force_fn: Optional[Callable] = None,
):
"""Initialize the dataset. If the dataset is not present, it is downloaded.
Expand All @@ -61,7 +61,6 @@ def __init__(
unroll steps. During validation/testing, this specifies the largest
N-step MSE loss we are interested in, e.g. for best model checkpointing.
nl_backend: Which backend to use for the neighbor list
external_force_fn: Function that returns the position-wise external force
"""

if dataset_path.endswith("/"): # remove trailing slash in dataset path
Expand All @@ -83,7 +82,16 @@ def __init__(
self.input_seq_length = input_seq_length
self.nl_backend = nl_backend

self.external_force_fn = external_force_fn
force_fn_path = osp.join(dataset_path, "force.py")
if osp.exists(force_fn_path):
# load force_fn if `force.py` exists in dataset_path
spec = importlib.util.spec_from_file_location("force_module", force_fn_path)
force_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(force_module)

self.external_force_fn = force_module.force_fn
else:
self.external_force_fn = None

# load dataset metadata
with open(osp.join(dataset_path, "metadata.json"), "r") as f:
Expand Down Expand Up @@ -328,21 +336,13 @@ def __init__(
extra_seq_length: int = 0,
nl_backend: str = "jaxmd_vmap",
):
def external_force_fn(position):
return jnp.where(
position[1] > 1.0,
jnp.array([-1.0, 0.0]),
jnp.array([1.0, 0.0]),
)

super().__init__(
split,
dataset_path,
name="rpf2d",
input_seq_length=input_seq_length,
extra_seq_length=extra_seq_length,
nl_backend=nl_backend,
external_force_fn=external_force_fn,
)


Expand All @@ -357,21 +357,13 @@ def __init__(
extra_seq_length: int = 0,
nl_backend: str = "jaxmd_vmap",
):
def external_force_fn(position):
return jnp.where(
position[1] > 1.0,
jnp.array([-1.0, 0.0, 0.0]),
jnp.array([1.0, 0.0, 0.0]),
)

super().__init__(
split,
dataset_path,
name="rpf3d",
input_seq_length=input_seq_length,
extra_seq_length=extra_seq_length,
nl_backend=nl_backend,
external_force_fn=external_force_fn,
)


Expand Down
4 changes: 4 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pprint
from argparse import Namespace

import yaml
Expand All @@ -18,6 +19,9 @@
# priority to command line arguments
args.update(cli_args)
args = Namespace(config=Namespace(**args), info=Namespace())
print("#" * 79, "\nStarting a LagrangeBench run with the following configs:")
pprint.pprint(vars(args.config))
print("#" * 79)

# specify cuda device
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ select = [

[tool.pytest.ini_options]
testpaths = "tests/"
addopts = "--cov=lagrangebench"
addopts = "--cov=lagrangebench --cov-fail-under=50"
filterwarnings = [
# ignore all deprecation warnings except from lagrangebench
"ignore::DeprecationWarning:^(?!.*lagrangebench).*"
Expand Down

0 comments on commit 28c8f9e

Please sign in to comment.