From 2f4f4f61de48c88a8f179909453b76ca9922beb9 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Wed, 10 Jan 2024 23:05:35 +0100 Subject: [PATCH 1/8] move force_fn to a file in dataset_path --- experiments/utils.py | 29 +---------------------------- lagrangebench/data/data.py | 32 ++++++++++++-------------------- 2 files changed, 13 insertions(+), 48 deletions(-) diff --git a/experiments/utils.py b/experiments/utils.py index 63b14b7..8168178 100644 --- a/experiments/utils.py +++ b/experiments/utils.py @@ -55,34 +55,7 @@ def setup_data(args: Namespace) -> Tuple[H5Dataset, H5Dataset, Namespace]: f"exceeds eval_n_trajs ({args.config.eval_n_trajs})" ) - # TODO: move this to a more suitable place - if "RPF" in args.info.dataset_name.upper(): - args.info.has_external_force = True - if data_train.metadata["dim"] == 2: - - def external_force_fn(position): - return jnp.where( - position[1] > 1.0, - jnp.array([-1.0, 0.0]), - jnp.array([1.0, 0.0]), - ) - - elif data_train.metadata["dim"] == 3: - - def external_force_fn(position): - return jnp.where( - position[1] > 1.0, - jnp.array([-1.0, 0.0, 0.0]), - jnp.array([1.0, 0.0, 0.0]), - ) - - else: - args.info.has_external_force = False - external_force_fn = None - - data_train.external_force_fn = external_force_fn - data_valid.external_force_fn = external_force_fn - data_test.external_force_fn = external_force_fn + args.info.has_external_force = bool(data_train.external_force_fn is not None) return data_train, data_valid, data_test, args diff --git a/lagrangebench/data/data.py b/lagrangebench/data/data.py index 1c976bd..61de9f5 100644 --- a/lagrangebench/data/data.py +++ b/lagrangebench/data/data.py @@ -1,12 +1,13 @@ """Dataset modules for loading HDF5 simulation trajectories.""" import bisect +import importlib import json import os import os.path as osp import re import zipfile -from typing import Callable, Optional +from typing import Optional import h5py import jax.numpy as jnp @@ -45,7 +46,6 @@ def __init__( input_seq_length: int = 6, extra_seq_length: int = 0, nl_backend: str = "jaxmd_vmap", - external_force_fn: Optional[Callable] = None, ): """Initialize the dataset. If the dataset is not present, it is downloaded. @@ -61,7 +61,6 @@ def __init__( unroll steps. During validation/testing, this specifies the largest N-step MSE loss we are interested in, e.g. for best model checkpointing. nl_backend: Which backend to use for the neighbor list - external_force_fn: Function that returns the position-wise external force """ if dataset_path.endswith("/"): # remove trailing slash in dataset path @@ -83,7 +82,16 @@ def __init__( self.input_seq_length = input_seq_length self.nl_backend = nl_backend - self.external_force_fn = external_force_fn + force_fn_path = osp.join(dataset_path, "force.py") + if osp.exists(force_fn_path): + # load force_fn if `force.py` exists in dataset_path + spec = importlib.util.spec_from_file_location("force_module", force_fn_path) + force_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(force_module) + + self.external_force_fn = force_module.force_fn + else: + self.external_force_fn = None # load dataset metadata with open(osp.join(dataset_path, "metadata.json"), "r") as f: @@ -328,13 +336,6 @@ def __init__( extra_seq_length: int = 0, nl_backend: str = "jaxmd_vmap", ): - def external_force_fn(position): - return jnp.where( - position[1] > 1.0, - jnp.array([-1.0, 0.0]), - jnp.array([1.0, 0.0]), - ) - super().__init__( split, dataset_path, @@ -342,7 +343,6 @@ def external_force_fn(position): input_seq_length=input_seq_length, extra_seq_length=extra_seq_length, nl_backend=nl_backend, - external_force_fn=external_force_fn, ) @@ -357,13 +357,6 @@ def __init__( extra_seq_length: int = 0, nl_backend: str = "jaxmd_vmap", ): - def external_force_fn(position): - return jnp.where( - position[1] > 1.0, - jnp.array([-1.0, 0.0, 0.0]), - jnp.array([1.0, 0.0, 0.0]), - ) - super().__init__( split, dataset_path, @@ -371,7 +364,6 @@ def external_force_fn(position): input_seq_length=input_seq_length, extra_seq_length=extra_seq_length, nl_backend=nl_backend, - external_force_fn=external_force_fn, ) From 0a4a6fce5e12f4078a450e9bca4edadf7b3ec506 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Wed, 10 Jan 2024 23:17:43 +0100 Subject: [PATCH 2/8] add warning if neighbor_list_multiplier < 1.25 --- lagrangebench/case_setup/case.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lagrangebench/case_setup/case.py b/lagrangebench/case_setup/case.py index 0925d2d..21ee2ec 100644 --- a/lagrangebench/case_setup/case.py +++ b/lagrangebench/case_setup/case.py @@ -1,5 +1,6 @@ """Case setup functions.""" +import warnings from typing import Callable, Dict, Optional, Tuple, Union import jax.numpy as jnp @@ -101,6 +102,14 @@ def case_builder( displacement_fn_set = vmap(displacement_fn, in_axes=(0, 0)) + if neighbor_list_multiplier < 1.25: + warnings.warn( + f"neighbor_list_multiplier={neighbor_list_multiplier} < 1.25 is very low. " + "Be especially cautious if you batch training and/or inference as " + "reallocation might be necessary based on different overflow conditions. " + "See https://github.com/tumaer/lagrangebench/pull/20#discussion_r1443811262" + ) + neighbor_fn = neighbor_list( displacement_fn, jnp.array(box), From f5a2eeb8142f9727b2c81a665d8984c1163adee2 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Wed, 10 Jan 2024 23:20:21 +0100 Subject: [PATCH 3/8] run pytest on python 3.9, 3.10, 3.11 --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c2b1a24..47a8fc3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v4 From 6bd016a706c60efd6d880e9bc89d2c3256281db5 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Thu, 11 Jan 2024 08:01:42 +0000 Subject: [PATCH 4/8] add --cov-fail-under=50 --- .github/workflows/ruff.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 563b87d..9558ab8 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -1,5 +1,5 @@ name: Ruff -on: [push, pull_request] +on: [pull_request] jobs: ruff: runs-on: ubuntu-latest diff --git a/pyproject.toml b/pyproject.toml index e2d9371..43f3a6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ select = [ [tool.pytest.ini_options] testpaths = "tests/" -addopts = "--cov=lagrangebench" +addopts = "--cov=lagrangebench --cov-fail-under=50" filterwarnings = [ # ignore all deprecation warnings except from lagrangebench "ignore::DeprecationWarning:^(?!.*lagrangebench).*" From cccfd07d1de28a00b8cdde3996d5fb2a96c8642b Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Thu, 11 Jan 2024 09:30:12 +0000 Subject: [PATCH 5/8] add .codecov.yaml --- .codecov.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..c5eb4db --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,11 @@ +coverage: + range: "50...100" + status: + project: + default: + target: 60% + threshold: 5% + patch: + default: + target: 50% + threshold: 5% \ No newline at end of file From a1981406633726231f565d7836a43a741e8d7086 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Thu, 11 Jan 2024 09:43:59 +0000 Subject: [PATCH 6/8] updates to .codecov.yml --- .codecov.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index c5eb4db..c9ce4b4 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,10 +1,11 @@ coverage: - range: "50...100" + range: 50..70 # red color under 50%, yellow at 50%..70%, green over 70% + precision: 1 status: project: default: - target: 60% - threshold: 5% + target: 60% # coverage success only above X% + threshold: 5% # allow the coverage to drop by X% and being a success patch: default: target: 50% From 0e4908e842d62ff10bf96dc09542efee72781091 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Thu, 11 Jan 2024 10:49:17 +0000 Subject: [PATCH 7/8] add printing configs in main.py --- main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/main.py b/main.py index 9d679d8..bc25a09 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import os +import pprint from argparse import Namespace import yaml @@ -18,6 +19,9 @@ # priority to command line arguments args.update(cli_args) args = Namespace(config=Namespace(**args), info=Namespace()) + print("#" * 79, "\nStarting a LagrangeBench run with the following configs:") + pprint.pprint(vars(args.config)) + print("#" * 79) # specify cuda device os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow From bd433c255d92f4db8f977c12f79a44af78ea2f1f Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Thu, 11 Jan 2024 17:21:54 +0000 Subject: [PATCH 8/8] update zenodo URLs to v2 --- lagrangebench/data/data.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lagrangebench/data/data.py b/lagrangebench/data/data.py index 61de9f5..0febebe 100644 --- a/lagrangebench/data/data.py +++ b/lagrangebench/data/data.py @@ -18,13 +18,13 @@ from lagrangebench.utils import NodeType URLS = { - "tgv2d": "https://zenodo.org/records/10021926/files/2D_TGV_2500_10kevery100.zip", - "rpf2d": "https://zenodo.org/records/10021926/files/2D_RPF_3200_20kevery100.zip", - "ldc2d": "https://zenodo.org/records/10021926/files/2D_LDC_2708_10kevery100.zip", - "dam2d": "https://zenodo.org/records/10021926/files/2D_DAM_5740_20kevery100.zip", - "tgv3d": "https://zenodo.org/records/10021926/files/3D_TGV_8000_10kevery100.zip", - "rpf3d": "https://zenodo.org/records/10021926/files/3D_RPF_8000_10kevery100.zip", - "ldc3d": "https://zenodo.org/records/10021926/files/3D_LDC_8160_10kevery100.zip", + "tgv2d": "https://zenodo.org/records/10491868/files/2D_TGV_2500_10kevery100.zip", + "rpf2d": "https://zenodo.org/records/10491868/files/2D_RPF_3200_20kevery100.zip", + "ldc2d": "https://zenodo.org/records/10491868/files/2D_LDC_2708_10kevery100.zip", + "dam2d": "https://zenodo.org/records/10491868/files/2D_DAM_5740_20kevery100.zip", + "tgv3d": "https://zenodo.org/records/10491868/files/3D_TGV_8000_10kevery100.zip", + "rpf3d": "https://zenodo.org/records/10491868/files/3D_RPF_8000_10kevery100.zip", + "ldc3d": "https://zenodo.org/records/10491868/files/3D_LDC_8160_10kevery100.zip", }