diff --git a/README.md b/README.md index 1e42385..b47464e 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ Later, you just need to `source venv/bin/activate` to activate the environment. If you want to use a CUDA GPU, you first need a running Nvidia driver. And then just follow the instructions [here](https://jax.readthedocs.io/en/latest/installation.html). The whole process could look like this: ```bash source .venv/bin/activate -pip install -U "jax[cuda12]" +pip install -U "jax[cuda12]==0.4.29" ``` ## Getting Started diff --git a/cases/cw.py b/cases/cw.py index ae9ea86..5190470 100644 --- a/cases/cw.py +++ b/cases/cw.py @@ -5,7 +5,13 @@ from omegaconf import DictConfig from jax_sph.case_setup import SimulationSetup -from jax_sph.utils import Tag, pos_box_2d, pos_init_cartesian_2d +from jax_sph.utils import ( + Tag, + pos_box_2d, + pos_box_3d, + pos_init_cartesian_2d, + pos_init_cartesian_3d, +) class CW(SimulationSetup): @@ -14,40 +20,67 @@ class CW(SimulationSetup): def __init__(self, cfg: DictConfig): super().__init__(cfg) + # define offset vector + self.offset_vec = np.ones(cfg.case.dim) * cfg.solver.n_walls * cfg.case.dx + # relaxation configurations if cfg.case.mode == "rlx" or cfg.case.r0_type == "relaxed": raise NotImplementedError("Relaxation not implemented for CW.") - def _box_size2D(self): - return np.array([self.special.L_wall, self.special.H_wall]) + 6 * self.case.dx + def _box_size2D(self, n_walls): + sp = self.special + return np.array([sp.L_wall, sp.H_wall]) + 2 * n_walls * self.case.dx - def _box_size3D(self): - dx6 = 6 * self.case.dx - return np.array([self.special.L_wall + dx6, self.special.H_wall + dx6, 0.5]) + def _box_size3D(self, n_walls): + sp = self.special + dx2n = 2 * n_walls * self.case.dx + return np.array([sp.L_wall + dx2n, sp.H_wall + dx2n, 1.0 + dx2n]) - def _init_pos2D(self, box_size, dx): - dx3 = 3 * self.case.dx - walls = pos_box_2d(self.special.L_wall, self.special.H_wall, dx) + def _init_walls_2d(self, dx, n_walls): + sp = self.special + rw = pos_box_2d(np.array([sp.L_wall, sp.H_wall]), dx, n_walls) + return rw - r_fluid = pos_init_cartesian_2d(np.array([self.special.L, self.special.H]), dx) - r_fluid += dx3 + np.array(self.special.cube_offset) - res = np.concatenate([walls, r_fluid]) - return res + def _init_walls_3d(self, dx, n_walls): + sp = self.special + rw = pos_box_3d(np.array([sp.L_wall, sp.H_wall, 1.0]), dx, n_walls, False) + return rw + + def _init_pos2D(self, box_size, dx, n_walls): + dxn = n_walls * self.case.dx + + # initialize walls + r_w = self._init_walls_2d(dx, n_walls) + + # initialize fluid phase + r_f = pos_init_cartesian_2d(np.array([self.special.L, self.special.H]), dx) + r_f += dxn + np.array(self.special.cube_offset) + + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) + + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + return r, tag + + def _init_pos3D(self, box_size, dx, n_walls): + dxn = n_walls * self.case.dx + + # initialize walls + r_w = self._init_walls_3d(dx, n_walls) - def _tag2D(self, r): - dx3 = 3 * self.case.dx - mask_left = jnp.where(r[:, 0] < dx3, True, False) - mask_bottom = jnp.where(r[:, 1] < dx3, True, False) - mask_right = jnp.where(r[:, 0] > self.special.L_wall + dx3, True, False) - mask_top = jnp.where(r[:, 1] > self.special.H_wall + dx3, True, False) - mask_wall = mask_left + mask_bottom + mask_right + mask_top + # initialize fluid phase + r_f = pos_init_cartesian_3d(np.array([self.special.L, self.special.H, 0.3]), dx) + r_f += dxn + np.array(self.special.cube_offset) - tag = jnp.full(len(r), Tag.FLUID, dtype=int) - tag = jnp.where(mask_wall, Tag.SOLID_WALL, tag) - return tag + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) - def _tag3D(self, r): - return self._tag2D(r) + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + return r, tag def _init_velocity2D(self, r): res = jnp.array(self.special.u_init) diff --git a/cases/db.py b/cases/db.py index 908e738..b9e4b9b 100644 --- a/cases/db.py +++ b/cases/db.py @@ -5,7 +5,13 @@ from omegaconf import DictConfig from jax_sph.case_setup import SimulationSetup -from jax_sph.utils import Tag, pos_box_2d, pos_init_cartesian_2d +from jax_sph.utils import ( + Tag, + pos_box_2d, + pos_box_3d, + pos_init_cartesian_2d, + pos_init_cartesian_3d, +) class DB(SimulationSetup): @@ -24,6 +30,9 @@ def __init__(self, cfg: DictConfig): # | --------------------------| # < L_wall > + # define offset vector + self.offset_vec = self._offset_vec() + # relaxation configurations if self.case.mode == "rlx": self.special.L_wall = self.special.L @@ -33,55 +42,83 @@ def __init__(self, cfg: DictConfig): if self.case.r0_type == "relaxed": self._load_only_fluid = True - def _box_size2D(self): + def _box_size2D(self, n_walls): dx, bo = self.case.dx, self.special.box_offset return np.array( - [self.special.L_wall + 6 * dx + bo, self.special.H_wall + 6 * dx + bo] + [ + self.special.L_wall + 2 * n_walls * dx + bo, + self.special.H_wall + 2 * n_walls * dx + bo, + ] ) - def _box_size3D(self): + def _box_size3D(self, n_walls): dx, bo = self.case.dx, self.box_offset sp = self.special - return np.array([sp.L_wall + 6 * dx + bo, sp.H_wall + 6 * dx + bo, sp.W]) + return np.array( + [sp.L_wall + 2 * n_walls * dx + bo, sp.H_wall + 2 * n_walls * dx + bo, sp.W] + ) + + def _init_walls_2d(self, dx, n_walls): + sp = self.special + rw = pos_box_2d(np.array([sp.L_wall, sp.H_wall]), dx, n_walls) + return rw + + def _init_walls_3d(self, dx, n_walls): + sp = self.special + rw = pos_box_3d(np.array([sp.L_wall, sp.H_wall, 1.0]), dx, n_walls) + return rw - def _init_pos2D(self, box_size, dx): + def _init_pos2D(self, box_size, dx, n_walls): sp = self.special + + # initialize fluid phase if self.case.r0_type == "cartesian": - r_fluid = 3 * dx + pos_init_cartesian_2d(np.array([sp.L, sp.H]), dx) + r_f = n_walls * dx + pos_init_cartesian_2d(np.array([sp.L, sp.H]), dx) else: - r_fluid = self._get_relaxed_r0(None, dx) + r_f = self._get_relaxed_r0(None, dx) - walls = pos_box_2d(sp.L_wall, sp.H_wall, dx) - res = np.concatenate([walls, r_fluid]) - return res + # initialize walls + r_w = self._init_walls_2d(dx, n_walls) - def _init_pos3D(self, box_size, dx): - # cartesian coordinates in z - Lz = box_size[2] - zs = np.arange(0, Lz, dx) + 0.5 * dx + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) - # extend 2D points to 3D - xy = self._init_pos2D(box_size, dx) - xy_ext = np.hstack([xy, np.ones((len(xy), 1))]) + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + return r, tag - r_xyz = np.vstack([xy_ext * [1, 1, z] for z in zs]) - return r_xyz + def _init_pos3D(self, box_size, dx, n_walls): + # TODO: not validated yet + sp = self.special - def _tag2D(self, r): - dx3 = 3 * self.case.dx - mask_left = jnp.where(r[:, 0] < dx3, True, False) - mask_bottom = jnp.where(r[:, 1] < dx3, True, False) - mask_right = jnp.where(r[:, 0] > self.special.L_wall + dx3, True, False) - mask_top = jnp.where(r[:, 1] > self.special.H_wall + dx3, True, False) + # initialize fluid phase + if self.case.r0_type == "cartesian": + r_f = np.array([1.0, 1.0, 0.0]) * n_walls * dx + pos_init_cartesian_3d( + np.array([sp.L, sp.H, 1.0]), dx + ) + else: + r_f = self._get_relaxed_r0(None, dx) - mask_wall = mask_left + mask_bottom + mask_right + mask_top + # initialize walls + r_w = self._init_walls_3d(dx, n_walls) - tag = jnp.full(len(r), Tag.FLUID, dtype=int) - tag = jnp.where(mask_wall, Tag.SOLID_WALL, tag) - return tag + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) - def _tag3D(self, r): - return self._tag2D(r) + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + + return r, tag + + def _offset_vec(self): + dim = self.cfg.case.dim + if dim == 2: + res = np.ones(dim) * self.cfg.solver.n_walls * self.cfg.case.dx + elif dim == 3: + res = np.array([1.0, 1.0, 0.0]) * self.cfg.solver.n_walls * self.cfg.case.dx + return res def _init_velocity2D(self, r): return jnp.zeros_like(r) diff --git a/cases/ht.py b/cases/ht.py index 73262e6..ea17db7 100644 --- a/cases/ht.py +++ b/cases/ht.py @@ -6,7 +6,7 @@ from omegaconf import DictConfig from jax_sph.case_setup import SimulationSetup -from jax_sph.utils import Tag +from jax_sph.utils import Tag, pos_init_cartesian_2d, pos_init_cartesian_3d class HT(SimulationSetup): @@ -15,6 +15,9 @@ class HT(SimulationSetup): def __init__(self, cfg: DictConfig): super().__init__(cfg) + # define offset vector + self.offset_vec = self._offset_vec() + # relaxation configurations if self.case.mode == "rlx": self._set_default_rlx() @@ -24,33 +27,113 @@ def __init__(self, cfg: DictConfig): self._init_pos2D = self._get_relaxed_r0 self._init_pos3D = self._get_relaxed_r0 - def _box_size2D(self): - dx = self.case.dx - return np.array([1, 0.2 + 6 * dx]) + def _box_size2D(self, n_walls): + dx2n = self.case.dx * n_walls * 2 + sp = self.special + return np.array([sp.L, sp.H + dx2n]) + + def _box_size3D(self, n_walls): + dx2n = self.case.dx * n_walls * 2 + sp = self.special + return np.array([sp.L, sp.H + dx2n, 0.5]) + + def _init_walls_2d(self, dx, n_walls): + sp = self.special + + # thickness of wall particles + dxn = dx * n_walls + + # horizontal and vertical blocks + horiz = pos_init_cartesian_2d(np.array([sp.L, dxn]), dx) + + # wall: bottom, top + wall_b = horiz.copy() + wall_t = horiz.copy() + np.array([0.0, sp.H + dxn]) + + rw = np.concatenate([wall_b, wall_t]) + return rw + + def _init_walls_3d(self, dx, n_walls): + sp = self.special - def _box_size3D(self): - dx = self.case.dx - return np.array([1, 0.2 + 6 * dx, 0.5]) + # thickness of wall particles + dxn = dx * n_walls - def _tag2D(self, r): - dx3 = 3 * self.case.dx - _box_size = self._box_size2D() - tag = jnp.full(len(r), Tag.FLUID, dtype=int) + # horizontal and vertical blocks + horiz = pos_init_cartesian_3d(np.array([sp.L, dxn, 0.5]), dx) - mask_no_slip_wall = (r[:, 1] < dx3) + ( - r[:, 1] > (_box_size[1] - 6 * self.case.dx) + dx3 + # wall: bottom, top + wall_b = horiz.copy() + wall_t = horiz.copy() + np.array([0.0, sp.H + dxn, 0.0]) + + rw = np.concatenate([wall_b, wall_t]) + return rw + + def _init_pos2D(self, box_size, dx, n_walls): + sp = self.special + + # initialize fluid phase + r_f = np.array([0.0, 1.0]) * n_walls * dx + pos_init_cartesian_2d( + np.array([sp.L, sp.H]), dx ) + + # initialize walls + r_w = self._init_walls_2d(dx, n_walls) + + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) + + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + + # set thermal tags + _box_size = self._box_size2D(n_walls) mask_hot_wall = ( - (r[:, 1] < dx3) + (r[:, 1] < dx * n_walls) * (r[:, 0] < (_box_size[0] / 2) + self.special.hot_wall_half_width) * (r[:, 0] > (_box_size[0] / 2) - self.special.hot_wall_half_width) ) - tag = jnp.where(mask_no_slip_wall, Tag.SOLID_WALL, tag) tag = jnp.where(mask_hot_wall, Tag.DIRICHLET_WALL, tag) - return tag - def _tag3D(self, r): - return self._tag2D(r) + return r, tag + + def _init_pos3D(self, box_size, dx, n_walls): + sp = self.special + + # initialize fluid phase + r_f = np.array([0.0, 1.0, 0.0]) * n_walls * dx + pos_init_cartesian_3d( + np.array([sp.L, sp.H, 0.5]), dx + ) + + # initialize walls + r_w = self._init_walls_3d(dx, n_walls) + + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) + + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + + # set thermal tags + _box_size = self._box_size3D(n_walls) + mask_hot_wall = ( + (r[:, 1] < dx * n_walls) + * (r[:, 0] < (_box_size[0] / 2) + self.special.hot_wall_half_width) + * (r[:, 0] > (_box_size[0] / 2) - self.special.hot_wall_half_width) + ) + tag = jnp.where(mask_hot_wall, Tag.DIRICHLET_WALL, tag) + + return r, tag + + def _offset_vec(self): + dim = self.cfg.case.dim + if dim == 2: + res = np.array([0.0, 1.0]) * self.cfg.solver.n_walls * self.cfg.case.dx + elif dim == 3: + res = np.array([0.0, 1.0, 0.0]) * self.cfg.solver.n_walls * self.cfg.case.dx + return res def _init_velocity2D(self, r): return jnp.zeros_like(r) @@ -59,20 +142,22 @@ def _init_velocity3D(self, r): return jnp.zeros_like(r) def _external_acceleration_fn(self, r): - dx3 = 3 * self.case.dx + n_walls = self.cfg.solver.n_walls + dxn = n_walls * self.case.dx res = jnp.zeros_like(r) x_force = jnp.ones((len(r))) - box_size = self._box_size2D() - fluid_mask = (r[:, 1] < box_size[1] - dx3) * (r[:, 1] > dx3) + box_size = self._box_size2D(n_walls) + fluid_mask = (r[:, 1] < box_size[1] - dxn) * (r[:, 1] > dxn) x_force = jnp.where(fluid_mask, x_force, 0) res = res.at[:, 0].set(x_force) return res * self.case.g_ext_magnitude def _boundary_conditions_fn(self, state): + n_walls = self.cfg.solver.n_walls mask_fluid = state["tag"] == Tag.FLUID # set incoming fluid temperature to reference_temperature - mask_inflow = mask_fluid * (state["r"][:, 0] < 3 * self.case.dx) + mask_inflow = mask_fluid * (state["r"][:, 0] < n_walls * self.case.dx) state["T"] = jnp.where(mask_inflow, self.case.T_ref, state["T"]) state["dTdt"] = jnp.where(mask_inflow, 0.0, state["dTdt"]) @@ -96,7 +181,7 @@ def _boundary_conditions_fn(self, state): # set outlet temperature gradients to zero to avoid interaction with inflow # bounds[0][1] is the x-coordinate of the outlet mask_outflow = mask_fluid * ( - state["r"][:, 0] > self.case.bounds[0][1] - 3 * self.case.dx + state["r"][:, 0] > self.case.bounds[0][1] - n_walls * self.case.dx ) state["dTdt"] = jnp.where(mask_outflow, 0.0, state["dTdt"]) diff --git a/cases/ht.yaml b/cases/ht.yaml index 808b6bf..eba2503 100644 --- a/cases/ht.yaml +++ b/cases/ht.yaml @@ -11,6 +11,8 @@ case: special: hot_wall_temperature: 1.23 # nondimensionalized corresponding to 100 hot_wall_half_width: 0.25 + L: 1.0 # water column length + H: 0.2 # water column height solver: diff --git a/cases/ldc.py b/cases/ldc.py index 220d235..478f44b 100644 --- a/cases/ldc.py +++ b/cases/ldc.py @@ -6,7 +6,13 @@ from omegaconf import DictConfig from jax_sph.case_setup import SimulationSetup -from jax_sph.utils import Tag +from jax_sph.utils import ( + Tag, + pos_box_2d, + pos_box_3d, + pos_init_cartesian_2d, + pos_init_cartesian_3d, +) class LDC(SimulationSetup): @@ -21,6 +27,9 @@ def __init__(self, cfg: DictConfig): elif self.case.dim == 3: self.u_lid = jnp.array([self.special.u_x_lid, 0.0, 0.0]) + # define offset vector + self.offset_vec = self._offset_vec() + # relaxation configurations if self.case.mode == "rlx": self._set_default_rlx() @@ -37,19 +46,68 @@ def _box_size3D(self): dx6 = 6 * self.case.dx return np.array([1 + dx6, 1 + dx6, 0.5]) - def _tag2D(self, r): - box_size = self._box_size2D() if self.case.dim == 2 else self._box_size3D() + def _box_size2D(self, n_walls): + return np.ones((2,)) + 2 * n_walls * self.case.dx + + def _box_size3D(self, n_walls): + dx2n = 2 * n_walls * self.case.dx + return np.array([1 + dx2n, 1 + dx2n, 0.5]) + + def _init_walls_2d(self, dx, n_walls): + rw = pos_box_2d(np.ones(2), dx, n_walls) + return rw + + def _init_walls_3d(self, dx, n_walls): + rw = pos_box_3d(np.array([1.0, 1.0, 0.5]), dx, n_walls) + return rw + + def _init_pos2D(self, box_size, dx, n_walls): + # initialize fluid phase + r_f = n_walls * dx + pos_init_cartesian_2d(np.ones(2), dx) + + # initialize walls + r_w = self._init_walls_2d(dx, n_walls) + + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) - mask_lid = r[:, 1] > (box_size[1] - 3 * self.case.dx) - r_centered_abs = jnp.abs(r - r.mean(axis=0)) - mask_water = jnp.where(r_centered_abs.max(axis=1) < 0.5, True, False) - tag = jnp.full(len(r), Tag.SOLID_WALL, dtype=int) - tag = jnp.where(mask_water, Tag.FLUID, tag) + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + + # set velocity wall tag + box_size = self._box_size2D(n_walls) + mask_lid = r[:, 1] > (box_size[1] - n_walls * self.case.dx) tag = jnp.where(mask_lid, Tag.MOVING_WALL, tag) - return tag + return r, tag + + def _init_pos3D(self, box_size, dx, n_walls): + # initialize fluid phase + r_f = n_walls * dx + pos_init_cartesian_3d(np.array([1.0, 1.0, 0.5]), dx) + + # initialize walls + r_w = self._init_walls_3d(dx, n_walls) - def _tag3D(self, r): - return self._tag2D(r) + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) + + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + + # set velocity wall tag + box_size = self._box_size3D(n_walls) + mask_lid = r[:, 1] > (box_size[1] - n_walls * self.case.dx) + tag = jnp.where(mask_lid, Tag.MOVING_WALL, tag) + return r, tag + + def _offset_vec(self): + dim = self.cfg.case.dim + if dim == 2: + res = jnp.ones(dim) * self.cfg.solver.n_walls * self.case.dx + elif dim == 3: + res = jnp.array([1.0, 1.0, 0.0]) * self.cfg.solver.n_walls * self.case.dx + return res def _init_velocity2D(self, r): u = jnp.zeros_like(r) diff --git a/cases/pf.py b/cases/pf.py index 07e347d..1f2a9d3 100644 --- a/cases/pf.py +++ b/cases/pf.py @@ -5,7 +5,7 @@ from omegaconf import DictConfig from jax_sph.case_setup import SimulationSetup -from jax_sph.utils import Tag +from jax_sph.utils import Tag, pos_init_cartesian_2d, pos_init_cartesian_3d class PF(SimulationSetup): @@ -17,6 +17,9 @@ class PF(SimulationSetup): def __init__(self, cfg: DictConfig): super().__init__(cfg) + # define offset vector + self.offset_vec = self._offset_vec() + # relaxation configurations if self.case.mode == "rlx": self._set_default_rlx() @@ -26,23 +29,95 @@ def __init__(self, cfg: DictConfig): self._init_pos2D = self._get_relaxed_r0 self._init_pos3D = self._get_relaxed_r0 - def _box_size2D(self): - return np.array([0.4, 1 + 6 * self.case.dx]) + def _box_size2D(self, n_walls): + dx2n = self.case.dx * n_walls * 2 + sp = self.special + return np.array([sp.L, sp.H + dx2n]) + + def _box_size3D(self, n_walls): + dx2n = self.case.dx * n_walls * 2 + sp = self.special + return np.array([sp.L, sp.H + dx2n, 0.4]) + + def _init_walls_2d(self, dx, n_walls): + sp = self.special + + # thickness of wall particles + dxn = dx * n_walls + + # horizontal and vertical blocks + horiz = pos_init_cartesian_2d(np.array([sp.L, dxn]), dx) + + # wall: bottom, top + wall_b = horiz.copy() + wall_t = horiz.copy() + np.array([0.0, sp.H + dxn]) + + rw = np.concatenate([wall_b, wall_t]) + return rw + + def _init_walls_3d(self, dx, n_walls): + sp = self.special + + # thickness of wall particles + dxn = dx * n_walls + + # horizontal and vertical blocks + horiz = pos_init_cartesian_3d(np.array([sp.L, dxn, 0.4]), dx) + + # wall: bottom, top + wall_b = horiz.copy() + wall_t = horiz.copy() + np.array([0.0, sp.H + dxn, 0.0]) + + rw = np.concatenate([wall_b, wall_t]) + return rw + + def _init_pos2D(self, box_size, dx, n_walls): + sp = self.special + + # initialize fluid phase + r_f = np.array([0.0, 1.0]) * n_walls * dx + pos_init_cartesian_2d( + np.array([sp.L, sp.H]), dx + ) + + # initialize walls + r_w = self._init_walls_2d(dx, n_walls) + + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) + + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) + + return r, tag + + def _init_pos3D(self, box_size, dx, n_walls): + sp = self.special + + # initialize fluid phase + r_f = np.array([0.0, 1.0, 0.0]) * n_walls * dx + pos_init_cartesian_3d( + np.array([sp.L, sp.H, 0.4]), dx + ) + + # initialize walls + r_w = self._init_walls_3d(dx, n_walls) - def _box_size3D(self): - return np.array([0.4, 1 + 6 * self.case.dx, 0.4]) + # set tags + tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int) + tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int) - def _tag2D(self, r): - dx3 = 3 * self.case.dx - box_size = self._box_size2D() - tag = jnp.full(len(r), Tag.FLUID, dtype=int) + r = np.concatenate([r_w, r_f]) + tag = np.concatenate([tag_w, tag_f]) - mask_wall = (r[:, 1] < dx3) + (r[:, 1] > box_size[1] - dx3) - tag = jnp.where(mask_wall, Tag.SOLID_WALL, tag) - return tag + return r, tag - def _tag3D(self, r): - return self._tag2D(r) + def _offset_vec(self): + dim = self.cfg.case.dim + if dim == 2: + res = np.array([0.0, 1.0]) * self.cfg.solver.n_walls * self.cfg.case.dx + elif dim == 3: + res = np.array([0.0, 1.0, 0.0]) * self.cfg.solver.n_walls * self.cfg.case.dx + return res def _init_velocity2D(self, r): return jnp.zeros_like(r) @@ -51,11 +126,12 @@ def _init_velocity3D(self, r): return jnp.zeros_like(r) def _external_acceleration_fn(self, r): - dx3 = 3 * self.case.dx + n_walls = self.cfg.solver.n_walls + dxn = n_walls * self.case.dx res = jnp.zeros_like(r) x_force = jnp.ones((len(r))) - box_size = self._box_size2D() - fluid_mask = (r[:, 1] < box_size[1] - dx3) * (r[:, 1] > dx3) + box_size = self._box_size2D(n_walls) + fluid_mask = (r[:, 1] < box_size[1] - dxn) * (r[:, 1] > dxn) x_force = jnp.where(fluid_mask, x_force, 0) res = res.at[:, 0].set(x_force) return res * self.case.g_ext_magnitude diff --git a/cases/pf.yaml b/cases/pf.yaml index cca1cb9..0bd5557 100644 --- a/cases/pf.yaml +++ b/cases/pf.yaml @@ -9,6 +9,9 @@ case: viscosity: 100.0 u_ref: 1.25 g_ext_magnitude: 1000.0 + special: + L: 0.4 # water column length + H: 1.0 # water column height solver: dt: 0.0000005 diff --git a/cases/rpf.py b/cases/rpf.py index 5218cf0..0e08b3e 100644 --- a/cases/rpf.py +++ b/cases/rpf.py @@ -5,7 +5,6 @@ from omegaconf import DictConfig from jax_sph.case_setup import SimulationSetup -from jax_sph.utils import Tag class RPF(SimulationSetup): @@ -23,18 +22,17 @@ def __init__(self, cfg: DictConfig): self._init_pos2D = self._get_relaxed_r0 self._init_pos3D = self._get_relaxed_r0 - def _box_size2D(self): + def _box_size2D(self, n_walls): return np.array([1.0, 2.0]) - def _box_size3D(self): + def _box_size3D(self, n_walls): return np.array([1.0, 2.0, 0.5]) - def _tag2D(self, r): - tag = jnp.full(len(r), Tag.FLUID, dtype=int) - return tag + def _init_walls_2d(self): + pass - def _tag3D(self, r): - return self._tag2D(r) + def _init_walls_3d(self): + pass def _init_velocity2D(self, r): u = jnp.zeros_like(r) diff --git a/cases/tgv.py b/cases/tgv.py index 6430191..80f008c 100644 --- a/cases/tgv.py +++ b/cases/tgv.py @@ -6,7 +6,6 @@ from omegaconf import DictConfig from jax_sph.case_setup import SimulationSetup -from jax_sph.utils import Tag class TGV(SimulationSetup): @@ -24,18 +23,17 @@ def __init__(self, cfg: DictConfig): self._init_pos2D = self._get_relaxed_r0 self._init_pos3D = self._get_relaxed_r0 - def _box_size2D(self): + def _box_size2D(self, n_walls): return np.array([1.0, 1.0]) - def _box_size3D(self): + def _box_size3D(self, n_walls): return 2 * np.pi * np.array([1.0, 1.0, 1.0]) - def _tag2D(self, r): - tag = jnp.full(len(r), Tag.FLUID, dtype=int) - return tag + def _init_walls_2d(self): + pass - def _tag3D(self, r): - return self._tag2D(r) + def _init_walls_3d(self): + pass def _init_velocity2D(self, r): x, y = r diff --git a/cases/ut.py b/cases/ut.py index 9a4d42f..c5f9c21 100644 --- a/cases/ut.py +++ b/cases/ut.py @@ -5,10 +5,10 @@ from omegaconf import DictConfig from jax_sph.case_setup import SimulationSetup -from jax_sph.utils import pos_init_cartesian_2d, pos_init_cartesian_3d +from jax_sph.utils import Tag, pos_init_cartesian_2d, pos_init_cartesian_3d -class UTSetup(SimulationSetup): +class UT(SimulationSetup): """Unit Test: cube of water in periodic boundary box""" def __init__(self, cfg: DictConfig): @@ -20,25 +20,29 @@ def __init__(self, cfg: DictConfig): if self.case.mode == "rlx" or self.case.r0_type == "relaxed": raise NotImplementedError("Relaxation not implemented for CW") - def _box_size2D(self): + def _box_size2D(self, n_walls): return np.array([self.special.L_wall, self.special.H_wall]) - def _box_size3D(self): + def _box_size3D(self, n_walls): return np.array([self.special.L_wall, self.special.L_wall, self.special.L_wall]) - def _init_pos2D(self, box_size, dx): - cube = np.array([self.special.L, self.special.H]) - return self.cube_offset + pos_init_cartesian_2d(cube, dx) + def _init_walls_2d(self): + pass - def _init_pos3D(self, box_size, dx): - cube = np.array([self.special.L, self.special.L, self.special.H]) - return self.cube_offset + pos_init_cartesian_3d(cube, dx) + def _init_walls_3d(self): + pass - def _tag2D(self, r): - return jnp.zeros(len(r), dtype=int) + def _init_pos2D(self, box_size, dx, n_walls): + cube = np.array([self.special.L, self.special.H]) + r = self.cube_offset + pos_init_cartesian_2d(cube, dx) + tag = jnp.full(len(r), Tag.FLUID, dtype=int) + return r, tag - def _tag3D(self, r): - return self._tag2D(r) + def _init_pos3D(self, box_size, dx, n_walls): + cube = np.array([self.special.L, self.special.L, self.special.H]) + r = self.cube_offset + pos_init_cartesian_3d(cube, dx) + tag = jnp.full(len(r), Tag.FLUID, dtype=int) + return r, tag def _init_velocity2D(self, r): return jnp.zeros_like(r) diff --git a/docs/requirements.txt b/docs/requirements.txt index 186f921..26f1e2c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,5 @@ h5py --e git+https://github.com/jax-md/jax-md.git@c451353f6ddcab031f660befda256d8a4f657855#egg=jax-md -jax[cpu]==0.4.28 +jax[cpu]==0.4.29 omegaconf pandas pyvista diff --git a/jax_sph/case_setup.py b/jax_sph/case_setup.py index d40cac2..61383e6 100644 --- a/jax_sph/case_setup.py +++ b/jax_sph/case_setup.py @@ -15,6 +15,8 @@ from jax_sph.jax_md import space from jax_sph.utils import ( Tag, + compute_nws_jax_wrapper, + compute_nws_scipy, get_noise_masked, pos_init_cartesian_2d, pos_init_cartesian_3d, @@ -59,6 +61,7 @@ def initialize(self): - state (dict): dictionary containing all field values - g_ext_fn (Callable): external force function - bc_fn (Callable): boundary conditions function (e.g. velocity at walls) + - nw_fn (Callable): jit-able wall normal funct. when moving walls, else None - eos (Callable): equation of state function - key (PRNGKey): random key for sampling - displacement_fn (Callable): displacement function for edge features @@ -122,13 +125,11 @@ def initialize(self): # initialize box and positions of particles if dim == 2: - box_size = self._box_size2D() - r = self._init_pos2D(box_size, dx) - tag = self._tag2D(r) + box_size = self._box_size2D(cfg.solver.n_walls) + r, tag = self._init_pos2D(box_size, dx, cfg.solver.n_walls) elif dim == 3: - box_size = self._box_size3D() - r = self._init_pos3D(box_size, dx) - tag = self._tag3D(r) + box_size = self._box_size3D(cfg.solver.n_walls) + r, tag = self._init_pos3D(box_size, dx, cfg.solver.n_walls) displacement_fn, shift_fn = space.periodic(side=box_size) num_particles = len(r) @@ -152,6 +153,10 @@ def initialize(self): rho, mass, eta, temperature, kappa, Cp = self._set_field_properties( num_particles, mass_ref, cfg.case ) + # whether to compute wall normals + is_nw = cfg.solver.free_slip or cfg.solver.name == "RIE" + # calculate wall normals if necessary + nw = self._compute_wall_normals("scipy")(r, tag) if is_nw else jnp.zeros_like(r) # initialize the state dictionary state = { @@ -170,6 +175,7 @@ def initialize(self): "T": temperature, "kappa": kappa, "Cp": Cp, + "nw": nw, } # overwrite the state dictionary with the provided one @@ -196,12 +202,25 @@ def initialize(self): g_ext_fn = self._external_acceleration_fn bc_fn = self._boundary_conditions_fn + # whether to recompute the wall normals at every integration step + is_nw_recompute = (tag == Tag.MOVING_WALL).any() and is_nw + if is_nw_recompute: + assert cfg.nl.backend != "matscipy", NotImplementedError( + "Wall normals not yet implemented for matscipy neighbor list when " + "working with moving boundaries. \nIf you work with moving boundaries, " + "don't use one of: `nl.backend=matscipy` or `solver.free_slip=True` or " + "`solver.name=RIE`." + ) + kwargs = {"disp_fn": displacement_fn, "box_size": box_size, "state0": state} + nw_fn = self._compute_wall_normals("jax", **kwargs) if is_nw_recompute else None + return ( cfg, box_size, state, g_ext_fn, bc_fn, + nw_fn, eos, key, displacement_fn, @@ -209,25 +228,31 @@ def initialize(self): ) @abstractmethod - def _box_size2D(self, cfg): + def _box_size2D(self, n_walls): pass @abstractmethod - def _box_size3D(self, cfg): + def _box_size3D(self, n_walls): pass - def _init_pos2D(self, box_size, dx): - return pos_init_cartesian_2d(box_size, dx) + def _init_pos2D(self, box_size, dx, n_walls): + r = pos_init_cartesian_2d(box_size, dx) + tag = jnp.full(len(r), Tag.FLUID, dtype=int) + return r, tag - def _init_pos3D(self, box_size, dx): - return pos_init_cartesian_3d(box_size, dx) + def _init_pos3D(self, box_size, dx, n_walls): + r = pos_init_cartesian_3d(box_size, dx) + tag = jnp.full(len(r), Tag.FLUID, dtype=int) + return r, tag @abstractmethod - def _tag2D(self, r): + def _init_walls_2d(self): + """Create all solid walls of a 2D case.""" pass @abstractmethod - def _tag3D(self, r): + def _init_walls_3d(self): + """Create all solid walls of a 3D case.""" pass @abstractmethod @@ -266,7 +291,7 @@ def _get_relaxed_r0(self, box_size, dx): if self._load_only_fluid: return state["r"][state["tag"] == Tag.FLUID] else: - return state["r"] + return state["r"], state["tag"] def _set_field_properties(self, num_particles, mass_ref, case): rho = jnp.ones(num_particles) * case.rho_ref @@ -287,8 +312,42 @@ def _set_default_rlx(self): self._box_size3D_rlx = self._box_size3D self._init_pos2D_rlx = self._init_pos2D self._init_pos3D_rlx = self._init_pos3D - self._tag2D_rlx = self._tag2D - self._tag3D_rlx = self._tag3D + + def _compute_wall_normals(self, backend="scipy", **kwargs): + if self.cfg.case.dim == 2: + wall_part_fn = self._init_walls_2d + elif self.cfg.case.dim == 3: + wall_part_fn = self._init_walls_3d + else: + raise NotImplementedError("1D wall BCs not yet implemented") + + if backend == "scipy": + # If one makes `tag` static (-> `self.tag`), this function can be jitted. + # But it is significatly slower than `backend="jax"` due to `pure_callback`. + def body(r, tag): + return compute_nws_scipy( + r, + tag, + self.cfg.case.dx, + self.cfg.solver.n_walls, + self.offset_vec, + wall_part_fn, + ) + elif backend == "jax": + # This implementation is used in the integrator when having moving walls. + body = compute_nws_jax_wrapper( + state0=kwargs["state0"], + dx=self.cfg.case.dx, + n_walls=self.cfg.solver.n_walls, + offset_vec=self.offset_vec, + box_size=kwargs["box_size"], + pbc=self.cfg.case.pbc, + cfg_nl=self.cfg.nl, + displacement_fn=kwargs["disp_fn"], + wall_part_fn=wall_part_fn, + ) + + return body def set_relaxation(Case, cfg): @@ -318,8 +377,6 @@ def __init__(self, cfg): self._init_pos3D = self._init_pos3D_rlx self._box_size2D = self._box_size2D_rlx self._box_size3D = self._box_size3D_rlx - self._tag2D = self._tag2D_rlx - self._tag3D = self._tag3D_rlx def _init_velocity2D(self, r): return jnp.zeros_like(r) diff --git a/jax_sph/defaults.py b/jax_sph/defaults.py index dac4b09..4574c2b 100644 --- a/jax_sph/defaults.py +++ b/jax_sph/defaults.py @@ -86,6 +86,8 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig: cfg.solver.eta_limiter = 3 # Thermal conductivity (non-dimensional) cfg.solver.kappa = 0 + # Number of wall boundary particle layers + cfg.solver.n_walls = 3 # Whether to apply the heat conduction term cfg.solver.heat_conduction = False # Whether to apply boundaty conditions diff --git a/jax_sph/integrator.py b/jax_sph/integrator.py index b82159b..98e6430 100644 --- a/jax_sph/integrator.py +++ b/jax_sph/integrator.py @@ -1,12 +1,13 @@ """Integrator schemes.""" - from typing import Callable, Dict from jax_sph.utils import Tag -def si_euler(tvf: float, model: Callable, shift_fn: Callable, bc_fn: Callable): +def si_euler( + tvf: float, model: Callable, shift_fn: Callable, bc_fn: Callable, nw_fn: Callable +): """Semi-implicit Euler integrator including Transport Velocity. The integrator advances the state of the system following the steps: @@ -28,6 +29,10 @@ def advance(dt: float, state: Dict, neighbors): # 2. Integrate position with velocity v state["r"] = shift_fn(state["r"], 1.0 * dt * state["v"]) + # recompute wall normals if needed + if nw_fn is not None: + state["nw"] = nw_fn(state["r"]) + # 3. Update neighbor list # The displacment and shift function from JAX MD are used for computing the diff --git a/jax_sph/simulate.py b/jax_sph/simulate.py index 0895d6c..5395666 100644 --- a/jax_sph/simulate.py +++ b/jax_sph/simulate.py @@ -38,6 +38,7 @@ def simulate(cfg: DictConfig): state, g_ext_fn, bc_fn, + nw_fn, eos_fn, key, displacement_fn, @@ -84,7 +85,7 @@ def simulate(cfg: DictConfig): neighbors = neighbor_fn.allocate(state["r"], num_particles=num_particles) # Instantiate advance function for our use case - advance = si_euler(cfg.solver.tvf, forward, shift_fn, bc_fn) + advance = si_euler(cfg.solver.tvf, forward, shift_fn, bc_fn, nw_fn) advance = advance if cfg.no_jit else jit(advance) diff --git a/jax_sph/solver.py b/jax_sph/solver.py index 02c84bd..760d593 100644 --- a/jax_sph/solver.py +++ b/jax_sph/solver.py @@ -311,7 +311,7 @@ def gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos): particle hydrodynamics", Adami, Hu, Adams, 2012 """ - def gwbc_fn(temperature, rho, tag, u, v, p, g_ext, i_s, j_s, w_dist, dr_i_j, N): + def gwbc_fn(temperature, rho, tag, u, v, p, g_ext, i_s, j_s, w_dist, dr_i_j, nw, N): mask_bc = jnp.isin(tag, wall_tags) def no_slip_bc_fn(x): @@ -325,16 +325,16 @@ def no_slip_bc_fn(x): x = jnp.where(mask_bc[:, None], 2 * x - x_wall, x) return x - def free_slip_bc_fn(x): + def free_slip_bc_fn(x, wall_inner_normals): # # normal vectors pointing from fluid to wall # (1) implement via summing over fluid particles - wall_inner = ops.segment_sum(dr_i_j * mask_j_s_fluid[:, None], i_s, N) - # (2) implement using color gradient. Requires 2*rc thick wall - # wall_inner = - ops.segment_sum(dr_i_j*mask_j_s_wall[:, None], i_s, N) + # wall_inner = ops.segment_sum(dr_i_j * mask_j_s_fluid[:, None], i_s, N) + # # (2) implement using color gradient. Requires 2*rc thick wall + # # wall_inner = - ops.segment_sum(dr_i_j*mask_j_s_wall[:, None], i_s, N) - normalization = jnp.sqrt((wall_inner**2).sum(axis=1, keepdims=True)) - wall_inner_normals = wall_inner / (normalization + EPS) - wall_inner_normals = jnp.where(mask_bc[:, None], wall_inner_normals, 0.0) + # normalization = jnp.sqrt((wall_inner**2).sum(axis=1, keepdims=True)) + # wall_inner_normals = wall_inner / (normalization + EPS) + # wall_inner_normals = jnp.where(mask_bc[:, None], wall_inner_normals, 0.0) # for boundary particles, sum over fluid velocities x_wall_unnorm = ops.segment_sum(w_j_s_fluid[:, None] * x[j_s], i_s, N) @@ -357,8 +357,8 @@ def free_slip_bc_fn(x): if is_free_slip: # free-slip boundary - ignore viscous interactions with wall - u = free_slip_bc_fn(u) - v = free_slip_bc_fn(v) + u = free_slip_bc_fn(u, -nw) + v = free_slip_bc_fn(v, -nw) else: # no-slip boundary condition u = no_slip_bc_fn(u) @@ -558,7 +558,7 @@ def forward(state, neighbors): r, tag, mass, eta = state["r"], state["tag"], state["mass"], state["eta"] u, v, dudt, dvdt = state["u"], state["v"], state["dudt"], state["dvdt"] rho, drhodt, p = state["rho"], state["drhodt"], state["p"] - kappa, Cp = state["kappa"], state["Cp"] + nw, kappa, Cp = state["nw"], state["kappa"], state["Cp"] temperature, dTdt = state["T"], state["dTdt"] N = len(r) @@ -582,23 +582,11 @@ def forward(state, neighbors): wall_mask = jnp.where(jnp.isin(tag, wall_tags), 1.0, 0.0) fluid_mask = jnp.where(tag == Tag.FLUID, 1.0, 0.0) - # calculate normal vector of wall boundaries - temp = vmap(self._wall_phi_vec)( - rho[j_s], mass[j_s], dr_i_j, dist, wall_mask[j_s], wall_mask[i_s] - ) - phi = ops.segment_sum(temp, i_s, N) - - # compute normal vector for boundary particles eq. (15), Zhang (2017) - n_w = ( - phi - / (jnp.linalg.norm(phi, ord=2, axis=1) + EPS)[:, None] - * wall_mask[:, None] - ) - n_w = jnp.where(jnp.absolute(n_w) < EPS, 0.0, n_w) - ##### Riemann velocity BCs if self.is_bc_trick and (self.solver == "RIE"): u_tilde = self._riemann_velocities(u, w_dist, fluid_mask, i_s, j_s, N) + else: + u_tilde = u ##### Density summation or evolution @@ -624,7 +612,7 @@ def forward(state, neighbors): dr_i_j, dist, wall_mask[j_s], - n_w[j_s], + nw[j_s], g_ext[i_s], u_tilde[j_s], ) @@ -646,7 +634,19 @@ def forward(state, neighbors): if self.is_bc_trick and (self.solver == "SPH"): p, rho, u, v, temperature = self._gwbc_fn( - temperature, rho, tag, u, v, p, g_ext, i_s, j_s, w_dist, dr_i_j, N + temperature, + rho, + tag, + u, + v, + p, + g_ext, + i_s, + j_s, + w_dist, + dr_i_j, + nw, + N, ) elif self.is_bc_trick and (self.solver == "RIE"): mask = self._free_weight(fluid_mask[i_s], tag[i_s]) @@ -714,7 +714,7 @@ def forward(state, neighbors): eta[j_s], wall_mask[j_s], mask, - n_w[j_s], + nw[j_s], g_ext[i_s], u_tilde[j_s], ) @@ -754,6 +754,7 @@ def forward(state, neighbors): "T": temperature, "kappa": kappa, "Cp": Cp, + "nw": nw, } return state diff --git a/jax_sph/utils.py b/jax_sph/utils.py index c855b87..d0023b5 100644 --- a/jax_sph/utils.py +++ b/jax_sph/utils.py @@ -1,7 +1,7 @@ """General jax-sph utils.""" import enum -from typing import Dict +from typing import Callable, Dict import jax import jax.numpy as jnp @@ -9,9 +9,11 @@ from jax import ops, vmap from numpy import array from omegaconf import DictConfig +from scipy.spatial import KDTree from jax_sph.io_state import read_h5 from jax_sph.jax_md import partition, space +from jax_sph.jax_md.partition import Dense from jax_sph.kernel import QuinticKernel EPS = jnp.finfo(float).eps @@ -51,27 +53,71 @@ def pos_init_cartesian_3d(box_size: array, dx: float): return r -def pos_box_2d(L: float, H: float, dx: float, num_wall_layers: int = 3): +def pos_box_2d(fluid_box: array, dx: float, n_walls: int = 3): """Create an empty box of particles in 2D. - The box is of size (L + num_wall_layers * dx) x (H + num_wall_layers * dx). - The inner part of the box starts at (num_wall_layers * dx, num_wall_layers * dx). + fluid_box is an array of the form: [L, H] + The box is of size (L + n_walls * dx) x (H + n_walls * dx). + The inner part of the box starts at (n_walls * dx, n_walls * dx). """ - dx3 = num_wall_layers * dx + # thickness of wall particles + dxn = n_walls * dx + # horizontal and vertical blocks - vertical = pos_init_cartesian_2d(np.array([dx3, H + 2 * dx3]), dx) - horiz = pos_init_cartesian_2d(np.array([L, dx3]), dx) + vertical = pos_init_cartesian_2d(np.array([dxn, fluid_box[1] + 2 * dxn]), dx) + horiz = pos_init_cartesian_2d(np.array([fluid_box[0], dxn]), dx) # wall: left, bottom, right, top wall_l = vertical.copy() - wall_b = horiz.copy() + np.array([dx3, 0.0]) - wall_r = vertical.copy() + np.array([L + dx3, 0.0]) - wall_t = horiz.copy() + np.array([dx3, H + dx3]) + wall_b = horiz.copy() + np.array([dxn, 0.0]) + wall_r = vertical.copy() + np.array([fluid_box[0] + dxn, 0.0]) + wall_t = horiz.copy() + np.array([dxn, fluid_box[1] + dxn]) res = jnp.concatenate([wall_l, wall_b, wall_r, wall_t]) return res +def pos_box_3d(fluid_box: array, dx: float, n_walls: int = 3, z_periodic: bool = True): + """Create an z-periodic empty box of particles in 3D. + + fluid_box is an array of the form: [L, H, D] + The box is of size (L + n_walls * dx) x (H + n_walls * dx) x D. + The inner part of the box starts at (n_walls * dx, n_walls * dx). + z_periodic states whether the box is periodic in z-direction. + """ + # thickness of wall particles + dxn = n_walls * dx + + # horizontal and vertical blocks + vertical = pos_init_cartesian_3d( + np.array([dxn, fluid_box[1] + 2 * dxn, fluid_box[2]]), dx + ) + horiz = pos_init_cartesian_3d(np.array([fluid_box[0], dxn, fluid_box[2]]), dx) + + # wall: left, bottom, right, top + wall_l = vertical.copy() + wall_b = horiz.copy() + np.array([dxn, 0.0, 0.0]) + wall_r = vertical.copy() + np.array([fluid_box[0] + dxn, 0.0, 0.0]) + wall_t = horiz.copy() + np.array([dxn, fluid_box[1] + dxn, 0.0]) + + res = jnp.concatenate([wall_l, wall_b, wall_r, wall_t]) + + # add walls in z-direction + if not z_periodic: + res += np.array([0.0, 0.0, dxn]) + # front block + front = pos_init_cartesian_3d( + np.array([fluid_box[0] + 2 * dxn, fluid_box[1] + 2 * dxn, dxn]), dx + ) + + # wall: front, end + wall_f = front.copy() + wall_e = front.copy() + np.array([0.0, 0.0, fluid_box[2] + dxn]) + res = jnp.concatenate([res, wall_f, wall_e]) + + return res + + def get_noise_masked(shape: tuple, mask: array, key: jax.random.PRNGKey, std: float): """Generate Gaussian noise with `std` where `mask` is True.""" noise = std * jax.random.normal(key, shape) @@ -103,7 +149,7 @@ def get_array_stats(state: Dict, var: str = "u", operation="max"): if jnp.size(state[var].shape) > 1: val_array = jnp.sqrt(jnp.square(state[var]).sum(axis=1)) else: - val_array = state[var] # TODO: check difference to jnp.absolute(state[var]) + val_array = state[var] return func(val_array) @@ -120,6 +166,115 @@ def get_stats(state: Dict, props: list, dx: float): return res +def compute_nws_scipy(r, tag, dx, n_walls, offset_vec, wall_part_fn): + """Computes the normal vectors of all wall boundaries. Jit-able pure_callback.""" + + dx_fac = 5 + + # operate only on wall particles, i.e. remove fluid + r_walls = r[np.isin(tag, wall_tags)] + + # align fluid to [0, 0] + r_aligned = r_walls - offset_vec + + # define fine layer of wall BC partilces and position them accordingly + layer = wall_part_fn(dx / dx_fac, 1) - offset_vec / n_walls / dx_fac + + # match thin layer to particles + tree = KDTree(layer) + dist, match_idx = tree.query(r_aligned, k=1) + dr = layer[match_idx] - r_aligned + nw_walls = dr / (dist[:, None] + EPS) + nw_walls = jnp.asarray(nw_walls, dtype=r.dtype) + + # compute normal vectors + nw = jnp.zeros_like(r) + nw = nw.at[np.isin(tag, wall_tags)].set(nw_walls) + + return nw + + +def compute_nws_jax_wrapper( + state0: Dict, + dx: float, + n_walls: int, + offset_vec: jax.Array, + box_size: jax.Array, + pbc: jax.Array, + cfg_nl: DictConfig, + displacement_fn: Callable, + wall_part_fn: Callable, +): + """Compute wall normal vectors from wall to fluid. Jit-able JAX implementation. + + For the particles from `r_walls`, find the closest particle from `layer` + and compute the normal vector from each `r_walls` particle. + """ + r = state0["r"] + tag = state0["tag"] + + # operate only on wall particles, i.e. remove fluid + r_walls = r[np.isin(tag, wall_tags)] - offset_vec + + # discretize wall with one layer of 5x smaller particles + dx_fac = 5 + offset = offset_vec / n_walls / dx_fac + layer = wall_part_fn(dx / dx_fac, 1) - offset + + # construct a neighbor list over both point clouds + r_full = jnp.concatenate([r_walls, layer], axis=0) + + neighbor_fn = partition.neighbor_list( + displacement_fn, + box_size, + r_cutoff=dx * n_walls * 2.0**0.5 * 1.01, + backend=cfg_nl.backend, + capacity_multiplier=1.25, + mask_self=False, + format=Dense, + num_particles_max=r_full.shape[0], + num_partitions=cfg_nl.num_partitions, + pbc=np.array(pbc), + ) + num_particles = len(r_full) + neighbors = neighbor_fn.allocate(r_full, num_particles=num_particles) + + # jit-able function + def body(r: jax.Array): + r_walls = r[np.isin(tag, wall_tags)] - offset_vec + r_full = jnp.concatenate([r_walls, layer], axis=0) + + nbrs = neighbors.update(r_full, num_particles=num_particles) + + # get the relevant entries from the dense neighbor list + idx = nbrs.idx # dense list: [[0, 1, 5], [0, 1, 3], [2, 3, 6], ...] + idx = idx[: len(r_walls)] # only the wall particle neighbors + mask_to_layer = idx > len(r_walls) # mask toward `layer` particles + idx = jnp.where(mask_to_layer, idx, len(r_full)) # get rid of unwanted edges + + # compute distances `r_wall` and `layer` particles and set others to infinity + r_i_s = r_full[idx] + dr_i_j = vmap(vmap(displacement_fn, in_axes=(0, None)))(r_i_s, r_walls) + dist = space.distance(dr_i_j) + mask_real = idx != len(r_full) # identify padding entries + dist = jnp.where(mask_real, dist, jnp.inf) + + # find closest `layer` particle for each `r_wall` particle and normalize + # displacement vector between the two to use it as the normal vector + idx_closest = jnp.argmin(dist, axis=1) + nw_walls = dr_i_j[jnp.arange(len(r_walls)), idx_closest] + nw_walls /= (dist[jnp.arange(len(r_walls)), idx_closest] + EPS)[:, None] + nw_walls = jnp.asarray(nw_walls, dtype=r.dtype) + + # update normals only of wall particles + nw = jnp.zeros_like(r) + nw = nw.at[np.isin(tag, wall_tags)].set(nw_walls) + + return nw + + return body + + class Logger: """Logger for printing stats to stdout.""" diff --git a/notebooks/tutorial.ipynb b/notebooks/tutorial.ipynb index 87153dd..ab0b8a7 100644 --- a/notebooks/tutorial.ipynb +++ b/notebooks/tutorial.ipynb @@ -33,15 +33,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "import time\n", @@ -60,7 +52,7 @@ "from jax_sph.io_state import io_setup, read_h5, write_state\n", "from jax_sph.jax_md.partition import Sparse\n", "from jax_sph.solver import WCSPH\n", - "from jax_sph.utils import Logger, Tag\n", + "from jax_sph.utils import Logger, Tag, pos_init_cartesian_2d\n", "from jax_sph.visualize import plt_ekin\n" ] }, @@ -107,6 +99,9 @@ " def __init__(self, cfg: DictConfig):\n", " super().__init__(cfg)\n", "\n", + " # define offset vector\n", + " self.offset_vec = self._offset_vec()\n", + "\n", " # relaxation configurations\n", " if self.case.mode == \"rlx\":\n", " self._set_default_rlx()\n", @@ -116,55 +111,98 @@ " self._init_pos2D = self._get_relaxed_r0\n", " self._init_pos3D = self._get_relaxed_r0\n", "\n", - " def _box_size2D(self):\n", - " dx = self.case.dx\n", - " return np.array([1, 0.2 + 6 * dx])\n", + " def _box_size2D(self, n_walls):\n", + " dx2n = self.case.dx * n_walls * 2\n", + " sp = self.special\n", + " return np.array([sp.L, sp.H + dx2n])\n", + "\n", + " def _box_size3D(self, n_walls):\n", + " dx2n = self.case.dx * n_walls * 2\n", + " sp = self.special\n", + " return np.array([sp.L, sp.H + dx2n, 0.5])\n", + "\n", + " def _init_walls_2d(self, dx, n_walls):\n", + " sp = self.special\n", "\n", - " def _box_size3D(self):\n", - " dx = self.case.dx\n", - " return np.array([1, 0.2 + 6 * dx, 0.5])\n", + " # thickness of wall particles\n", + " dxn = dx * n_walls\n", "\n", - " def _tag2D(self, r):\n", - " dx3 = 3 * self.case.dx\n", - " _box_size = self._box_size2D()\n", - " tag = jnp.full(len(r), Tag.FLUID, dtype=int)\n", + " # horizontal and vertical blocks\n", + " horiz = pos_init_cartesian_2d(np.array([sp.L, dxn]), dx)\n", "\n", - " mask_no_slip_wall = (r[:, 1] < dx3) + (\n", - " r[:, 1] > (_box_size[1] - 6 * self.case.dx) + dx3\n", + " # wall: bottom, top\n", + " wall_b = horiz.copy()\n", + " wall_t = horiz.copy() + np.array([0.0, sp.H + dxn])\n", + "\n", + " rw = np.concatenate([wall_b, wall_t])\n", + " return rw\n", + "\n", + " def _init_walls_3d(self, dx, n_walls):\n", + " pass\n", + "\n", + " def _init_pos2D(self, box_size, dx, n_walls):\n", + " sp = self.special\n", + "\n", + " # initialize fluid phase\n", + " r_f = np.array([0.0, 1.0]) * n_walls * dx + pos_init_cartesian_2d(\n", + " np.array([sp.L, sp.H]), dx\n", " )\n", + "\n", + " # initialize walls\n", + " r_w = self._init_walls_2d(dx, n_walls)\n", + "\n", + " # set tags\n", + " tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int)\n", + " tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int)\n", + "\n", + " r = np.concatenate([r_w, r_f])\n", + " tag = np.concatenate([tag_w, tag_f])\n", + "\n", + " # set thermal tags\n", + " _box_size = self._box_size2D(n_walls)\n", " mask_hot_wall = (\n", - " (r[:, 1] < dx3)\n", + " (r[:, 1] < dx * n_walls)\n", " * (r[:, 0] < (_box_size[0] / 2) + self.special.hot_wall_half_width)\n", " * (r[:, 0] > (_box_size[0] / 2) - self.special.hot_wall_half_width)\n", " )\n", - " tag = jnp.where(mask_no_slip_wall, Tag.SOLID_WALL, tag)\n", " tag = jnp.where(mask_hot_wall, Tag.DIRICHLET_WALL, tag)\n", - " return tag\n", "\n", - " def _tag3D(self, r):\n", - " return self._tag2D(r)\n", + " return r, tag\n", + "\n", + " def _init_pos3D(self, box_size, dx, n_walls):\n", + " pass\n", + "\n", + " def _offset_vec(self):\n", + " dim = self.cfg.case.dim\n", + " if dim == 2:\n", + " res = np.array([0.0, 1.0]) * self.cfg.solver.n_walls * self.cfg.case.dx\n", + " elif dim == 3:\n", + " res = np.array([0.0, 1.0, 0.0]) * self.cfg.solver.n_walls * self.cfg.case.dx\n", + " return res\n", "\n", " def _init_velocity2D(self, r):\n", " return jnp.zeros_like(r)\n", "\n", " def _init_velocity3D(self, r):\n", - " return jnp.zeros_like(r)\n", + " pass\n", "\n", " def _external_acceleration_fn(self, r):\n", - " dx3 = 3 * self.case.dx\n", + " n_walls = self.cfg.solver.n_walls\n", + " dxn = n_walls * self.case.dx\n", " res = jnp.zeros_like(r)\n", " x_force = jnp.ones((len(r)))\n", - " box_size = self._box_size2D()\n", - " fluid_mask = (r[:, 1] < box_size[1] - dx3) * (r[:, 1] > dx3)\n", + " box_size = self._box_size2D(n_walls)\n", + " fluid_mask = (r[:, 1] < box_size[1] - dxn) * (r[:, 1] > dxn)\n", " x_force = jnp.where(fluid_mask, x_force, 0)\n", " res = res.at[:, 0].set(x_force)\n", " return res * self.case.g_ext_magnitude\n", "\n", " def _boundary_conditions_fn(self, state):\n", + " n_walls = self.cfg.solver.n_walls\n", " mask_fluid = state[\"tag\"] == Tag.FLUID\n", "\n", " # set incoming fluid temperature to reference_temperature\n", - " mask_inflow = mask_fluid * (state[\"r\"][:, 0] < 3 * self.case.dx)\n", + " mask_inflow = mask_fluid * (state[\"r\"][:, 0] < n_walls * self.case.dx)\n", " state[\"T\"] = jnp.where(mask_inflow, self.case.T_ref, state[\"T\"])\n", " state[\"dTdt\"] = jnp.where(mask_inflow, 0.0, state[\"dTdt\"])\n", "\n", @@ -188,7 +226,7 @@ " # set outlet temperature gradients to zero to avoid interaction with inflow\n", " # bounds[0][1] is the x-coordinate of the outlet\n", " mask_outflow = mask_fluid * (\n", - " state[\"r\"][:, 0] > self.case.bounds[0][1] - 3 * self.case.dx\n", + " state[\"r\"][:, 0] > self.case.bounds[0][1] - n_walls * self.case.dx\n", " )\n", " state[\"dTdt\"] = jnp.where(mask_outflow, 0.0, state[\"dTdt\"])\n", "\n", @@ -215,7 +253,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "And now add the other needed entries." + "And now add the other needed entries. For a full list of all parameters, see [`defaults.py`](../jax_sph/defaults.py) or the [list of defaults](https://jax-sph.readthedocs.io/en/latest/pages/defaults.html) in our documentation." ] }, { @@ -232,10 +270,13 @@ "cfg.case.g_ext_magnitude = 2.3 # external force scale\n", "cfg.case.kappa_ref = 7.313 # thermal conductivity at 50°C\n", "cfg.case.Cp_ref = 305.27 # heat capacity at 50°C\n", + "### case specific arguments in `special`\n", "# nondimensionalized temperature corresponding to 100°C -> 1.23\n", "cfg.case.special.hot_wall_temperature = 1.23\n", "# define the location of the hot section\n", "cfg.case.special.hot_wall_half_width = 0.25\n", + "cfg.case.special.L = 1.0 # channel length\n", + "cfg.case.special.H = 0.2 # channel width\n", "\n", "### numerical setup arguments\n", "cfg.solver.t_end = 1.5 # simulation end time\n", @@ -270,7 +311,7 @@ "config: null\n", "seed: 123\n", "no_jit: false\n", - "gpu: -1\n", + "gpu: 0\n", "dtype: float64\n", "xla_mem_fraction: 0.75\n", "case:\n", @@ -294,6 +335,8 @@ " special:\n", " hot_wall_temperature: 1.23\n", " hot_wall_half_width: 0.25\n", + " L: 1.0\n", + " H: 0.2\n", "solver:\n", " name: SPH\n", " tvf: 0.0\n", @@ -306,6 +349,7 @@ " free_slip: false\n", " eta_limiter: 3\n", " kappa: 0\n", + " n_walls: 3\n", " heat_conduction: true\n", " is_bc_trick: true\n", "kernel:\n", @@ -363,11 +407,12 @@ "\n", "Here, we unpack `jax_sph.simulate.simulate` and give some more details.\n", "\n", - "The `case.initialize` method not only defines the initial simulation state with positions, velocities, etc., but also initializes the following:\n", - "* external force function `g_ext_fn` from Python case file\n", - "* boundary conditions `bc_fn` from Python case file\n", - "* equation of state `eos_fn` which defines pressure and density as functions of each other\n", - "* displacement and shift function respecting potential periodic box boundaries" + "The `case.initialize` method not only defines the initial simulation state with positions, velocities, etc., but also initializes the following functions:\n", + "* External force function `g_ext_fn` from Python case file.\n", + "* Boundary conditions `bc_fn` from Python case file.\n", + "* Wall normal vector computation `nw_fn` in case of moving walls.\n", + "* Equation of state `eos_fn` which defines pressure and density as functions of each other.\n", + "* Displacement and shift function respecting potential periodic box boundaries." ] }, { @@ -400,6 +445,7 @@ " state,\n", " g_ext_fn,\n", " bc_fn,\n", + " nw_fn,\n", " eos_fn,\n", " key,\n", " displacement_fn,\n", @@ -500,7 +546,7 @@ "outputs": [], "source": [ "# Instantiate advance function for our use case\n", - "advance = si_euler(cfg.solver.tvf, forward, shift_fn, bc_fn)\n", + "advance = si_euler(cfg.solver.tvf, forward, shift_fn, bc_fn, nw_fn)\n", "advance = jit(advance)" ] }, @@ -547,40 +593,40 @@ "output_type": "stream", "text": [ "0000/3300, t=0.0005, Ekin=0.00000, u_max=0.00000\n", - "0100/3300, t=0.0459, Ekin=0.00162, u_max=0.31360\n", - "0200/3300, t=0.0914, Ekin=0.00312, u_max=0.25906\n", - "0300/3300, t=0.1368, Ekin=0.00605, u_max=0.32460\n", - "0400/3300, t=0.1823, Ekin=0.00959, u_max=0.40505\n", - "0500/3300, t=0.2277, Ekin=0.01343, u_max=0.48048\n", - "0600/3300, t=0.2732, Ekin=0.01748, u_max=0.55084\n", - "0700/3300, t=0.3186, Ekin=0.02150, u_max=0.61357\n", - "0800/3300, t=0.3641, Ekin=0.02547, u_max=0.67135\n", - "0900/3300, t=0.4095, Ekin=0.02930, u_max=0.72199\n", - "1000/3300, t=0.4550, Ekin=0.03297, u_max=0.76502\n", - "1100/3300, t=0.5005, Ekin=0.03641, u_max=0.80508\n", - "1200/3300, t=0.5459, Ekin=0.03965, u_max=0.83981\n", - "1300/3300, t=0.5914, Ekin=0.04263, u_max=0.87329\n", - "1400/3300, t=0.6368, Ekin=0.04546, u_max=0.90626\n", - "1500/3300, t=0.6823, Ekin=0.04807, u_max=0.93373\n", - "1600/3300, t=0.7277, Ekin=0.05041, u_max=0.95674\n", - "1700/3300, t=0.7732, Ekin=0.05262, u_max=0.97548\n", - "1800/3300, t=0.8186, Ekin=0.05458, u_max=0.99214\n", - "1900/3300, t=0.8641, Ekin=0.05639, u_max=1.00931\n", - "2000/3300, t=0.9095, Ekin=0.05800, u_max=1.02836\n", - "2100/3300, t=0.9550, Ekin=0.05955, u_max=1.04335\n", - "2200/3300, t=1.0005, Ekin=0.06088, u_max=1.05429\n", - "2300/3300, t=1.0459, Ekin=0.06214, u_max=1.06249\n", - "2400/3300, t=1.0914, Ekin=0.06317, u_max=1.06985\n", - "2500/3300, t=1.1368, Ekin=0.06418, u_max=1.08156\n", - "2600/3300, t=1.1823, Ekin=0.06511, u_max=1.09160\n", - "2700/3300, t=1.2277, Ekin=0.06593, u_max=1.09884\n", - "2800/3300, t=1.2732, Ekin=0.06667, u_max=1.10251\n", - "2900/3300, t=1.3186, Ekin=0.06727, u_max=1.10529\n", - "3000/3300, t=1.3641, Ekin=0.06786, u_max=1.11242\n", - "3100/3300, t=1.4095, Ekin=0.06844, u_max=1.11944\n", - "3200/3300, t=1.4550, Ekin=0.06891, u_max=1.12419\n", - "3300/3300, t=1.5005, Ekin=0.06936, u_max=1.12494\n", - "time: 43.21 s\n" + "0100/3300, t=0.0459, Ekin=0.00155, u_max=0.28668\n", + "0200/3300, t=0.0914, Ekin=0.00312, u_max=0.27229\n", + "0300/3300, t=0.1368, Ekin=0.00604, u_max=0.32408\n", + "0400/3300, t=0.1823, Ekin=0.00958, u_max=0.40263\n", + "0500/3300, t=0.2277, Ekin=0.01342, u_max=0.48013\n", + "0600/3300, t=0.2732, Ekin=0.01747, u_max=0.55052\n", + "0700/3300, t=0.3186, Ekin=0.02150, u_max=0.61441\n", + "0800/3300, t=0.3641, Ekin=0.02546, u_max=0.67120\n", + "0900/3300, t=0.4095, Ekin=0.02929, u_max=0.72111\n", + "1000/3300, t=0.4550, Ekin=0.03296, u_max=0.76472\n", + "1100/3300, t=0.5005, Ekin=0.03640, u_max=0.80485\n", + "1200/3300, t=0.5459, Ekin=0.03964, u_max=0.83979\n", + "1300/3300, t=0.5914, Ekin=0.04262, u_max=0.87452\n", + "1400/3300, t=0.6368, Ekin=0.04545, u_max=0.90765\n", + "1500/3300, t=0.6823, Ekin=0.04806, u_max=0.93497\n", + "1600/3300, t=0.7277, Ekin=0.05040, u_max=0.95750\n", + "1700/3300, t=0.7732, Ekin=0.05261, u_max=0.97556\n", + "1800/3300, t=0.8186, Ekin=0.05457, u_max=0.99333\n", + "1900/3300, t=0.8641, Ekin=0.05637, u_max=1.00955\n", + "2000/3300, t=0.9095, Ekin=0.05799, u_max=1.02807\n", + "2100/3300, t=0.9550, Ekin=0.05954, u_max=1.04281\n", + "2200/3300, t=1.0005, Ekin=0.06088, u_max=1.05362\n", + "2300/3300, t=1.0459, Ekin=0.06213, u_max=1.06218\n", + "2400/3300, t=1.0914, Ekin=0.06316, u_max=1.06955\n", + "2500/3300, t=1.1368, Ekin=0.06417, u_max=1.08134\n", + "2600/3300, t=1.1823, Ekin=0.06511, u_max=1.09179\n", + "2700/3300, t=1.2277, Ekin=0.06592, u_max=1.09803\n", + "2800/3300, t=1.2732, Ekin=0.06667, u_max=1.10173\n", + "2900/3300, t=1.3186, Ekin=0.06727, u_max=1.10500\n", + "3000/3300, t=1.3641, Ekin=0.06785, u_max=1.11255\n", + "3100/3300, t=1.4095, Ekin=0.06844, u_max=1.11976\n", + "3200/3300, t=1.4550, Ekin=0.06891, u_max=1.12356\n", + "3300/3300, t=1.5005, Ekin=0.06935, u_max=1.12462\n", + "time: 15.92 s\n" ] } ], @@ -739,7 +785,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ diff --git a/poetry.lock b/poetry.lock index 9a8956a..91e0bb5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -843,19 +843,19 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pa [[package]] name = "jax" -version = "0.4.28" +version = "0.4.29" description = "Differentiate, compile, and transform Numpy code." optional = false python-versions = ">=3.9" files = [ - {file = "jax-0.4.28-py3-none-any.whl", hash = "sha256:6a181e6b5a5b1140e19cdd2d5c4aa779e4cb4ec627757b918be322d8e81035ba"}, - {file = "jax-0.4.28.tar.gz", hash = "sha256:dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9"}, + {file = "jax-0.4.29-py3-none-any.whl", hash = "sha256:cfdc594d133d7dfba2ec19bc10742ffaa0e8827ead29be00d3ec4215a3f7892e"}, + {file = "jax-0.4.29.tar.gz", hash = "sha256:12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186"}, ] [package.dependencies] importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} -jaxlib = {version = "0.4.28", optional = true, markers = "extra == \"cpu\""} -ml-dtypes = ">=0.2.0" +jaxlib = {version = "0.4.29", optional = true, markers = "extra == \"cpu\""} +ml-dtypes = ">=0.4.0" numpy = [ {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, {version = ">=1.22", markers = "python_version < \"3.11\""}, @@ -864,53 +864,52 @@ opt-einsum = "*" scipy = ">=1.9" [package.extras] -australis = ["protobuf (>=3.13,<4)"] -ci = ["jaxlib (==0.4.27)"] -cpu = ["jaxlib (==0.4.28)"] -cuda = ["jaxlib (==0.4.28+cuda12.cudnn89)"] -cuda12 = ["jax-cuda12-plugin (==0.4.28)", "jaxlib (==0.4.28)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] -cuda12-cudnn89 = ["jaxlib (==0.4.28+cuda12.cudnn89)"] -cuda12-local = ["jaxlib (==0.4.28+cuda12.cudnn89)"] -cuda12-pip = ["jaxlib (==0.4.28+cuda12.cudnn89)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +ci = ["jaxlib (==0.4.28)"] +cpu = ["jaxlib (==0.4.29)"] +cuda = ["jaxlib (==0.4.29+cuda12.cudnn91)"] +cuda12 = ["jax-cuda12-plugin (==0.4.29)", "jaxlib (==0.4.29)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=9.0,<10.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +cuda12-cudnn91 = ["jaxlib (==0.4.29+cuda12.cudnn91)"] +cuda12-local = ["jaxlib (==0.4.29+cuda12.cudnn91)"] +cuda12-pip = ["jaxlib (==0.4.29+cuda12.cudnn91)", "nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=9.0,<10.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] minimum-jaxlib = ["jaxlib (==0.4.27)"] -tpu = ["jaxlib (==0.4.28)", "libtpu-nightly (==0.1.dev20240508)", "requests"] +tpu = ["jaxlib (==0.4.29)", "libtpu-nightly (==0.1.dev20240609)", "requests"] [[package]] name = "jaxlib" -version = "0.4.28" +version = "0.4.29" description = "XLA library for JAX" optional = false python-versions = ">=3.9" files = [ - {file = "jaxlib-0.4.28-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:a421d237f8c25d2850166d334603c673ddb9b6c26f52bc496704b8782297bd66"}, - {file = "jaxlib-0.4.28-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f038e68bd10d1a3554722b0bbe36e6a448384437a75aa9d283f696f0ed9f8c09"}, - {file = "jaxlib-0.4.28-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:fabe77c174e9e196e9373097cefbb67e00c7e5f9d864583a7cfcf9dabd2429b6"}, - {file = "jaxlib-0.4.28-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:e3bcdc6f8e60f8554f415c14d930134e602e3ca33c38e546274fd545f875769b"}, - {file = "jaxlib-0.4.28-cp310-cp310-win_amd64.whl", hash = "sha256:a8b31c0e5eea36b7915696b9be40ea8646edc395a3e5437bf7ef26b7239a567a"}, - {file = "jaxlib-0.4.28-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:2ff8290edc7b92c7eae52517f65492633e267b2e9067bad3e4c323d213e77cf5"}, - {file = "jaxlib-0.4.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:793857faf37f371cafe752fea5fc811f435e43b8fb4b502058444a7f5eccf829"}, - {file = "jaxlib-0.4.28-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b41a6b0d506c09f86a18ecc05bd376f072b548af89c333107e49bb0c09c1a3f8"}, - {file = "jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:45ce0f3c840cff8236cff26c37f26c9ff078695f93e0c162c320c281f5041275"}, - {file = "jaxlib-0.4.28-cp311-cp311-win_amd64.whl", hash = "sha256:d4d762c3971d74e610a0e85a7ee063cea81a004b365b2a7dc65133f08b04fac5"}, - {file = "jaxlib-0.4.28-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:d6c09a545329722461af056e735146d2c8c74c22ac7426a845eb69f326b4f7a0"}, - {file = "jaxlib-0.4.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8dd8bffe3853702f63cd924da0ee25734a4d19cd5c926be033d772ba7d1c175d"}, - {file = "jaxlib-0.4.28-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:de2e8521eb51e16e85093a42cb51a781773fa1040dcf9245d7ea160a14ee5a5b"}, - {file = "jaxlib-0.4.28-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:46a1aa857f4feee8a43fcba95c0e0ab62d40c26cc9730b6c69655908ba359f8d"}, - {file = "jaxlib-0.4.28-cp312-cp312-win_amd64.whl", hash = "sha256:eee428eac31697a070d655f1f24f6ab39ced76750d93b1de862377a52dcc2401"}, - {file = "jaxlib-0.4.28-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:4f98cc837b2b6c6dcfe0ab7ff9eb109314920946119aa3af9faa139718ff2787"}, - {file = "jaxlib-0.4.28-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b01562ec8ad75719b7d0389752489e97eb6b4dcb4c8c113be491634d5282ad3c"}, - {file = "jaxlib-0.4.28-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:aa77a9360a395ba9faf6932df637686fb0c14ddcf4fdc1d2febe04bc88a580a6"}, - {file = "jaxlib-0.4.28-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4a56ebf05b4a4c1791699d874e072f3f808f0986b4010b14fb549a69c90ca9dc"}, - {file = "jaxlib-0.4.28-cp39-cp39-win_amd64.whl", hash = "sha256:459a4ddcc3e120904b9f13a245430d7801d707bca48925981cbdc59628057dc8"}, + {file = "jaxlib-0.4.29-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:60ec8aa2ba133a0615b0fce8e084c90c179c019793551641dd3da6526d036953"}, + {file = "jaxlib-0.4.29-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adb37f9c01a0fbdf97ab4afc7b60939c1694a5c056d8224c3d292cc253a3dc55"}, + {file = "jaxlib-0.4.29-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:8b1062b804d95ddb8dbb039c48316cbaed1d6866f973ef39e03f39e452ac8a1c"}, + {file = "jaxlib-0.4.29-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:ae21b84dd08c015bf2bab9ba97fa6a1da30f9a51c35902e23c7ffe959ad2e86c"}, + {file = "jaxlib-0.4.29-cp310-cp310-win_amd64.whl", hash = "sha256:a4993ab2f91c8ee213cacc4ed341539a7980b1b231e9dac69d68b508118dc19d"}, + {file = "jaxlib-0.4.29-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1da0716c45c5b0e177d334938a09953915f8da2080fffee9366ad8a9a988f484"}, + {file = "jaxlib-0.4.29-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5da76e760be790896f7149eccafff64b800f129a282de7f7a1edc138d56ac997"}, + {file = "jaxlib-0.4.29-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:aec76c416657e25884ee1364e98e40fcedbd2235c79691026d9babf05d850ede"}, + {file = "jaxlib-0.4.29-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:5a313e94c3ae87f147d561bc61e923d75e18e2448ac8fbf150d526ecd24404b0"}, + {file = "jaxlib-0.4.29-cp311-cp311-win_amd64.whl", hash = "sha256:7bfcc35a2991c2973489333f5c07dbb1f5d0ec78ef889097534c2c5f0d0149a7"}, + {file = "jaxlib-0.4.29-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:7d7eabdb21814386cfa32e09ddaee76e374126c3709989b6de5f1e49a04f5f36"}, + {file = "jaxlib-0.4.29-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ac11efc5eb7d0d25dc853efd68c18b204d071e78f762583dc9ebed84db272bf2"}, + {file = "jaxlib-0.4.29-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:08cb5b24f481f62ff432b0bbedc7e35c5d561dc42c1c8138bbf8514ea91ab17e"}, + {file = "jaxlib-0.4.29-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:a7c884c5651e5d1cc0fc57f2cf0abee4223b1f38c3e8cbd00fb142e6551cfe47"}, + {file = "jaxlib-0.4.29-cp312-cp312-win_amd64.whl", hash = "sha256:91058f1606312c42621b0a9979d0f14c0db9da6341ffd714ac5eb5e3be79c59a"}, + {file = "jaxlib-0.4.29-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:e59afd8026f43688fd0bf4f3dbcf913157c7f144d000850aaa7a88b1228a87ab"}, + {file = "jaxlib-0.4.29-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2801327384be3edab5f3adb38262206e488135d7c8e27d928ac3c6ffb71f7718"}, + {file = "jaxlib-0.4.29-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:dfce109290dbbd27176750b931bedbabb56a4f5e955f83eead2e3cdefe5108b8"}, + {file = "jaxlib-0.4.29-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:84be918201a7f06b73074ed154fc3344b02aa8736597b39397e7093807dcfc3c"}, + {file = "jaxlib-0.4.29-cp39-cp39-win_amd64.whl", hash = "sha256:9b0efd3ba45a7ee03fb91a4118099ea97aa02a96e50f4d91cc910554e226c5b9"}, ] [package.dependencies] -ml-dtypes = ">=0.2.0" +ml-dtypes = ">=0.4.0" numpy = ">=1.22" scipy = ">=1.9" [package.extras] -cuda12-pip = ["nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=8.9.2.26,<9.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] +cuda12-pip = ["nvidia-cublas-cu12 (>=12.1.3.1)", "nvidia-cuda-cupti-cu12 (>=12.1.105)", "nvidia-cuda-nvcc-cu12 (>=12.1.105)", "nvidia-cuda-runtime-cu12 (>=12.1.105)", "nvidia-cudnn-cu12 (>=9.0,<10.0)", "nvidia-cufft-cu12 (>=11.0.2.54)", "nvidia-cusolver-cu12 (>=11.4.5.107)", "nvidia-cusparse-cu12 (>=12.1.0.106)", "nvidia-nccl-cu12 (>=2.18.1)", "nvidia-nvjitlink-cu12 (>=12.1.105)"] [[package]] name = "jaxopt" @@ -1520,13 +1519,13 @@ test = ["chex", "coverage[toml]", "lineax", "matplotlib", "networkx (>=2.5)", "p [[package]] name = "packaging" -version = "24.0" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, - {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -1788,13 +1787,13 @@ virtualenv = ">=20.10.0" [[package]] name = "prompt-toolkit" -version = "3.0.46" +version = "3.0.47" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" files = [ - {file = "prompt_toolkit-3.0.46-py3-none-any.whl", hash = "sha256:45abe60a8300f3c618b23c16c4bb98c6fc80af8ce8b17c7ae92db48db3ee63c1"}, - {file = "prompt_toolkit-3.0.46.tar.gz", hash = "sha256:869c50d682152336e23c4db7f74667639b5047494202ffe7670817053fd57795"}, + {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"}, + {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"}, ] [package.dependencies] @@ -2682,4 +2681,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.11" -content-hash = "9dd3f880130bab1f475f71d18e7215d861517140fb64a69ac4cd3fa76d63a129" +content-hash = "3f89c0c18fea0692378f5b81a4d8c12fa9bdc0b964e4aca6c0c7c4e2dc2e45ab" diff --git a/pyproject.toml b/pyproject.toml index da886e5..b12d55a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,13 +15,13 @@ python = ">=3.9,<=3.11" h5py = ">=3.9.0" pandas = ">=2.1.4" # for validation pyvista = ">=0.42.2" # for visualization -jax = {version = "0.4.28", extras = ["cpu"]} -jaxlib = "0.4.28" -omegaconf = "^2.3.0" +jax = {version = "0.4.29", extras = ["cpu"]} +jaxlib = "0.4.29" +omegaconf = ">=2.3.0" matscipy = ">=0.8.0" dataclasses = "0.6" # for jax-md -jraph = "^0.0.6.dev0" # for jax-md -absl-py = "^2.1.0" # for jax-md +jraph = ">=0.0.6.dev0" # for jax-md +absl-py = ">=2.1.0" # for jax-md [tool.poetry.group.dev.dependencies] pre-commit = ">=3.3.1" @@ -37,10 +37,9 @@ ipykernel = ">=6.25.1" sphinx = "7.2.6" sphinx-exec-code = "0.12" sphinx-rtd-theme = "1.3.0" -toml = "^0.10.2" +toml = ">=0.10.2" [tool.ruff] -ignore = ["F821", "E402"] exclude = [ ".git", ".venv*", @@ -52,6 +51,7 @@ show-fixes = true line-length = 88 [tool.ruff.lint] +ignore = ["F821", "E402"] select = [ "E", # pycodestyle "F", # Pyflakes diff --git a/requirements.txt b/requirements.txt index 2dd9bfd..ad56d35 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ h5py --e git+https://github.com/jax-md/jax-md.git@c451353f6ddcab031f660befda256d8a4f657855#egg=jax-md -jax[cpu]==0.4.28 +jax[cpu]==0.4.29 omegaconf pandas pyvista