Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neighbors #13

Merged
merged 9 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@

</div>

JAX-SPH [(Toshev et al., 2024)](https://arxiv.org/abs/2403.04750) is a modular JAX-based weakly compressible SPH framework, which implements the following SPH routines:
- Standard SPH [(Adami et al., 2012)](https://www.sciencedirect.com/science/article/pii/S002199911200229X)
- Transport velocity SPH [(Adami et al., 2013)](https://www.sciencedirect.com/science/article/pii/S002199911300096X)
- Riemann SPH [(Zhang et al., 2017)](https://www.sciencedirect.com/science/article/abs/pii/S0021999117300438)

![HT_T.gif](https://s9.gifyu.com/images/SUwUD.gif)

## Table of Contents

1. [**Installation**](#installation)
1. [**Getting Started**](#getting-started)
1. [**Setting up a case**](#setting-up-a-case)
1. [**Contributing**](#contributing)
1. [**Citation**](#citation)
1. [**Acknowledgements**](#acknowledgements)

## Installation

### Standalone library
Expand Down Expand Up @@ -84,16 +88,18 @@ python main.py config=cases/ht.yaml
```

### Notebooks
We provide four notebooks demonstrating how to use JAX-SPH:
We provide various notebooks demonstrating how to use JAX-SPH:
- [`tutorial.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/tutorial.ipynb), with a general overview of JAX-SPH and an example how to run the channel flow with hot bottom wall.
- [`iclr24_grads.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_grads.ipynb), with a validation of the gradients through the solver.
- [`iclr24_inverse.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_inverse.ipynb), solving the inverse problem of finding the initial state of a 100-step-long SPH simulation.
- [`iclr24_sitl.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_sitl.ipynb), including training and testing a Solver-in-the-Loop model using the [LagrangeBench](https://github.com/tumaer/lagrangebench) library.
- [`iclr24_grads.ipynb`](notebooks/iclr24_grads.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_grads.ipynb), with a validation of the gradients through the solver.
- [`iclr24_inverse.ipynb`](notebooks/iclr24_inverse.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_inverse.ipynb), solving the inverse problem of finding the initial state of a 100-step-long SPH simulation.
- [`iclr24_sitl.ipynb`](notebooks/iclr24_sitl.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_sitl.ipynb), including training and testing a Solver-in-the-Loop model using the [LagrangeBench](https://github.com/tumaer/lagrangebench) library.
- [`neighbors.ipynb`](notebooks/neighbors.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/neighbors.ipynb), explaining the difference between the three neighbor search implementations and comparing their performance.
- [`kernel_plots.ipynb`](notebooks/kernel_plots.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/kernel_plots.ipynb), visualizing the SPH kernels.

## Setting up a case
## Setting up a Case
To set up a case, just add a `my_case.py` and a `my_case.yaml` file to the `cases/` directory. Every *.py case should inherit from `SimulationSetup` in `jax_sph/case_setup.py` or another case, and every *.yaml config file should either contain a complete set of parameters (see `jax_sph/defaults.py`) or extend `JAX_SPH_DEFAULTS`. Running a case in relaxation mode `case.mode=rlx` overwrites certain parts of the selected case. Passed CLI arguments overwrite any argument.

## Development and Contribution
## Contributing
If you wish to contribute, please run
```bash
pre-commit install
Expand Down
30 changes: 26 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,36 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.

Welcome to JAX-SPH's documentation!
===================================
JAX-SPH
========

.. image:: https://s9.gifyu.com/images/SUwUD.gif
:alt: GIF


What is ``JAX-SPH``?
--------------------

JAX-SPH `(Toshev et al., 2024) <https://arxiv.org/abs/2403.04750>`_ is a Smoothed Particle Hydrodynamics (SPH) code written in `JAX <https://jax.readthedocs.io/>`_. JAX-SPH is designed to be simple, fast, and compatible with deep learning workflows. We currently support the following SPH routines:

* Standard SPH `(Adami et al., 2012) <https://www.sciencedirect.com/science/article/pii/S002199911200229X>`_
* Transport velocity SPH `(Adami et al., 2013) <https://www.sciencedirect.com/science/article/pii/S002199911300096X>`_
* Riemann SPH `(Zhang et al., 2017) <https://www.sciencedirect.com/science/article/abs/pii/S0021999117300438>`_

Check out our `GitHub repository <https://github.com/tumaer/jax-sph>`_ for more information including installation instructions and tutorial notebooks.

.. toctree::
:maxdepth: 1
:caption: Getting Started

pages/tutorials
pages/defaults

.. toctree::
:maxdepth: 2
:caption: Contents:
:caption: API

pages/case_setup
pages/solver
pages/simulate
pages/utils
pages/utils
48 changes: 48 additions & 0 deletions docs/pages/defaults.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
Defaults
===================================

The defaults are defined through a function ``jax_sph.defaults.set_defaults()``, which
takes a potentially empty ``omegaconf.DictConfig`` object and creates or overwrites the
default values. One can also directly call ``from jax_sph.defaults import defaults``,
with ``defaults=set_defaults()``, to get the default DictConfig, which we unpack below.

.. exec_code::
:hide_code:
:linenos_output:
:language_output: python
:caption: JAX-SPH default values


with open("jax_sph/defaults.py", "r") as file:
defaults_full = file.read()

# parse defaults: remove imports, only keep the set_defaults function

defaults_full = defaults_full.split("\n")

# remove imports
defaults_full = [line for line in defaults_full if not line.startswith("import")]
defaults_full = [line for line in defaults_full if len(line.replace(" ", "")) > 0]

# remove other functions
keep = False
defaults = []
for i, line in enumerate(defaults_full):
if line.startswith("def"):
if "set_defaults" in line:
keep = True
else:
keep = False

if keep:
defaults.append(line)

# remove function declaration and return
defaults = defaults[2:-2]

# remove indent
defaults = [line[4:] for line in defaults]


print("\n".join(defaults))

8 changes: 8 additions & 0 deletions docs/pages/tutorials.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Tutorials
=========

Currently, there are two places to look for tutorials:

* The README of our `GitHub repository <https://github.com/tumaer/jax-sph>`_.
* The `notebooks <https://github.com/tumaer/jax-sph/tree/main/notebooks>`_ in the same
repository.
2 changes: 1 addition & 1 deletion jax_sph/case_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax_md import space

from jax_sph.eos import RIEMANNEoS, TaitEoS
from jax_sph.io_state import read_h5
from jax_sph.jax_md import space
from jax_sph.utils import (
Tag,
get_noise_masked,
Expand Down
64 changes: 32 additions & 32 deletions jax_sph/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
### global and hardware-related configs

# .yaml case configuration file
cfg.config = None # previously: case
cfg.config = None
# Seed for random number generator
cfg.seed = 123
# Whether to disable jitting compilation
cfg.no_jit = False
# Which GPU to use. -1 for CPU
cfg.gpu = 0
# Data type. One of "float32" or "float64"
cfg.dtype = "float64" # previously: no_f64
cfg.dtype = "float64"
# XLA memory fraction to be preallocated. The JAX default is 0.75.
# Should be specified before importing the library.
cfg.xla_mem_fraction = 0.75
Expand All @@ -30,30 +30,30 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
# Simulation mode. One of "sim" (run simulation) or "rlx" (run relaxation)
cfg.case.mode = "sim"
# Dimension of the simulation. One of 2 or 3
cfg.case.dim = 3 # previously: dim
cfg.case.dim = 3
# Average distance between particles [0.001, 0.1]
cfg.case.dx = 0.05 # previously: dx
cfg.case.dx = 0.05
# Initial state h5 path. Overrides `r0_type`. Can be useful to restart a simulation.
cfg.case.state0_path = None # previously: state0-path
cfg.case.state0_path = None
# Which properties to adopt from state0_path. Include all to restart a simulation.
cfg.case.state0_keys = ["r"]
# Position initialization type. One of "cartesian" or "relaxed". Cartesian can have
# `r0_noise_factor` and relaxed requires a state to be present in `data_relaxed`.
cfg.case.r0_type = "cartesian" # previously: r0-type
cfg.case.r0_type = "cartesian"
# How much Gaussian noise to add to r0. ( _ * dx)
cfg.case.r0_noise_factor = 0.0 # previously: r0-noise-factor
cfg.case.r0_noise_factor = 0.0
# Magnitude of external force field
cfg.case.g_ext_magnitude = 0.0 # previously: g-ext-magnitude
cfg.case.g_ext_magnitude = 0.0
# Reference dynamic viscosity. Inversely proportional to Re.
cfg.case.viscosity = 0.01 # previously: viscosity
cfg.case.viscosity = 0.01
# Estimate max flow velocity to calculate artificial speed of sound.
cfg.case.u_ref = 1.0 # previously: u_ref
cfg.case.u_ref = 1.0
# Reference speed of sound factor w.r.t. u_ref.
cfg.case.c_ref_factor = 10.0 # previously: p-bg-factor
cfg.case.c_ref_factor = 10.0
# Reference density
cfg.case.rho_ref = 1.0
# Reference temperature
cfg.case.T_ref = 1.0 # previously: T-ref
cfg.case.T_ref = 1.0
# Reference thermal conductivity
cfg.case.kappa_ref = 0.0
# Reference heat capacity at constant pressure
Expand All @@ -65,29 +65,29 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
cfg.solver = OmegaConf.create({})

# Solver name. One of "SPH" (standard SPH) or "RIE" (Riemann SPH)
cfg.solver.name = "SPH" # previously: solver
cfg.solver.name = "SPH"
# Transport velocity inclusion factor [0,...,1]
cfg.solver.tvf = 0.0 # previously: tvf
cfg.solver.tvf = 0.0
# CFL condition factor
cfg.solver.cfl = 0.25 # previously: cfl
cfg.solver.cfl = 0.25
# Density evolution vs density summation
cfg.solver.density_evolution = False # previously: density-evolution
cfg.solver.density_evolution = False
# Density renormalization when density evolution
cfg.solver.density_renormalize = False # previously: density-renormalize
cfg.solver.density_renormalize = False
# Integration time step. If None, it is calculated from the CFL condition.
cfg.solver.dt = None # previously: dt
cfg.solver.dt = None
# Physical time length of simulation
cfg.solver.t_end = 0.2 # previously: t-end
cfg.solver.t_end = 0.2
# Parameter alpha of artificial viscosity term
cfg.solver.artificial_alpha = 0.0 # previously: artificial-alpha
cfg.solver.artificial_alpha = 0.0
# Whether to turn on free-slip boundary condition
cfg.solver.free_slip = False # previously: free-slip
cfg.solver.free_slip = False
# Riemann dissipation limiter parameter, -1 = off
cfg.solver.eta_limiter = 3 # previously: eta-limiter
cfg.solver.eta_limiter = 3
# Thermal conductivity (non-dimensional)
cfg.solver.kappa = 0 # previously: kappa
cfg.solver.kappa = 0
# Whether to apply the heat conduction term
cfg.solver.heat_conduction = False # previously: heat-conduction
cfg.solver.heat_conduction = False
# Whether to apply boundaty conditions
cfg.solver.is_bc_trick = False # new

Expand All @@ -102,37 +102,37 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
# "WC6K" (Wendland C4 kernel)
# "GK" (gaussian kernel)
# "SGK" (super gaussian kernel)
cfg.kernel.name = "QSK" # previously: kernel
cfg.kernel.name = "QSK"
# Smoothing length factor
cfg.kernel.h_factor = 1.0 # new. Should default to 1.3 WC2K and 1.0 QSK

### equation of state
cfg.eos = OmegaConf.create({})

# EoS name. One of "Tait" or "RIEMANN"
cfg.eos.name = "Tait" # previously: eos
cfg.eos.name = "Tait"
# power in the Tait equation of state
cfg.eos.gamma = 1.0
# background pressure factor w.r.t. p_ref
cfg.eos.p_bg_factor = 0.0 # previously: p-bg-factor
cfg.eos.p_bg_factor = 0.0

### neighbor list
cfg.nl = OmegaConf.create({})

# Neighbor list backend. One of "jaxmd_vmap", "jaxmd_scan", "matscipy"
cfg.nl.backend = "jaxmd_vmap" # previously: nl-backend
cfg.nl.backend = "jaxmd_vmap"
# Number of partitions for neighbor list. Applies to jaxmd_scan only.
cfg.nl.num_partitions = 1 # previously: num-partitions
cfg.nl.num_partitions = 1

### output writing
cfg.io = OmegaConf.create({})

# In which format to write states. A subset of ["h5", "vtk"]
cfg.io.write_type = [] # previously: write-h5, write-vtk
cfg.io.write_type = []
# Every `write_every` step will be saved
cfg.io.write_every = 1 # previously: write-every
cfg.io.write_every = 1
# Where to write and read data
cfg.io.data_path = "./" # previously: data-path
cfg.io.data_path = "./"
# What to print to stdout. As list of possible properties.
cfg.io.print_props = ["Ekin", "u_max"]

Expand Down
Loading
Loading