Skip to content

Commit

Permalink
Merge pull request #16 from tumaer/normal-vectors
Browse files Browse the repository at this point in the history
Normal vectors
  • Loading branch information
arturtoshev authored Jun 12, 2024
2 parents ae27d22 + 40ad7af commit 8952d7e
Show file tree
Hide file tree
Showing 22 changed files with 891 additions and 333 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Later, you just need to `source venv/bin/activate` to activate the environment.
If you want to use a CUDA GPU, you first need a running Nvidia driver. And then just follow the instructions [here](https://jax.readthedocs.io/en/latest/installation.html). The whole process could look like this:
```bash
source .venv/bin/activate
pip install -U "jax[cuda12]"
pip install -U "jax[cuda12]==0.4.29"
```

## Getting Started
Expand Down
83 changes: 58 additions & 25 deletions cases/cw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from omegaconf import DictConfig

from jax_sph.case_setup import SimulationSetup
from jax_sph.utils import Tag, pos_box_2d, pos_init_cartesian_2d
from jax_sph.utils import (
Tag,
pos_box_2d,
pos_box_3d,
pos_init_cartesian_2d,
pos_init_cartesian_3d,
)


class CW(SimulationSetup):
Expand All @@ -14,40 +20,67 @@ class CW(SimulationSetup):
def __init__(self, cfg: DictConfig):
super().__init__(cfg)

# define offset vector
self.offset_vec = np.ones(cfg.case.dim) * cfg.solver.n_walls * cfg.case.dx

# relaxation configurations
if cfg.case.mode == "rlx" or cfg.case.r0_type == "relaxed":
raise NotImplementedError("Relaxation not implemented for CW.")

def _box_size2D(self):
return np.array([self.special.L_wall, self.special.H_wall]) + 6 * self.case.dx
def _box_size2D(self, n_walls):
sp = self.special
return np.array([sp.L_wall, sp.H_wall]) + 2 * n_walls * self.case.dx

def _box_size3D(self):
dx6 = 6 * self.case.dx
return np.array([self.special.L_wall + dx6, self.special.H_wall + dx6, 0.5])
def _box_size3D(self, n_walls):
sp = self.special
dx2n = 2 * n_walls * self.case.dx
return np.array([sp.L_wall + dx2n, sp.H_wall + dx2n, 1.0 + dx2n])

def _init_pos2D(self, box_size, dx):
dx3 = 3 * self.case.dx
walls = pos_box_2d(self.special.L_wall, self.special.H_wall, dx)
def _init_walls_2d(self, dx, n_walls):
sp = self.special
rw = pos_box_2d(np.array([sp.L_wall, sp.H_wall]), dx, n_walls)
return rw

r_fluid = pos_init_cartesian_2d(np.array([self.special.L, self.special.H]), dx)
r_fluid += dx3 + np.array(self.special.cube_offset)
res = np.concatenate([walls, r_fluid])
return res
def _init_walls_3d(self, dx, n_walls):
sp = self.special
rw = pos_box_3d(np.array([sp.L_wall, sp.H_wall, 1.0]), dx, n_walls, False)
return rw

def _init_pos2D(self, box_size, dx, n_walls):
dxn = n_walls * self.case.dx

# initialize walls
r_w = self._init_walls_2d(dx, n_walls)

# initialize fluid phase
r_f = pos_init_cartesian_2d(np.array([self.special.L, self.special.H]), dx)
r_f += dxn + np.array(self.special.cube_offset)

# set tags
tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int)
tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int)

r = np.concatenate([r_w, r_f])
tag = np.concatenate([tag_w, tag_f])
return r, tag

def _init_pos3D(self, box_size, dx, n_walls):
dxn = n_walls * self.case.dx

# initialize walls
r_w = self._init_walls_3d(dx, n_walls)

def _tag2D(self, r):
dx3 = 3 * self.case.dx
mask_left = jnp.where(r[:, 0] < dx3, True, False)
mask_bottom = jnp.where(r[:, 1] < dx3, True, False)
mask_right = jnp.where(r[:, 0] > self.special.L_wall + dx3, True, False)
mask_top = jnp.where(r[:, 1] > self.special.H_wall + dx3, True, False)
mask_wall = mask_left + mask_bottom + mask_right + mask_top
# initialize fluid phase
r_f = pos_init_cartesian_3d(np.array([self.special.L, self.special.H, 0.3]), dx)
r_f += dxn + np.array(self.special.cube_offset)

tag = jnp.full(len(r), Tag.FLUID, dtype=int)
tag = jnp.where(mask_wall, Tag.SOLID_WALL, tag)
return tag
# set tags
tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int)
tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int)

def _tag3D(self, r):
return self._tag2D(r)
r = np.concatenate([r_w, r_f])
tag = np.concatenate([tag_w, tag_f])
return r, tag

def _init_velocity2D(self, r):
res = jnp.array(self.special.u_init)
Expand Down
101 changes: 69 additions & 32 deletions cases/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from omegaconf import DictConfig

from jax_sph.case_setup import SimulationSetup
from jax_sph.utils import Tag, pos_box_2d, pos_init_cartesian_2d
from jax_sph.utils import (
Tag,
pos_box_2d,
pos_box_3d,
pos_init_cartesian_2d,
pos_init_cartesian_3d,
)


class DB(SimulationSetup):
Expand All @@ -24,6 +30,9 @@ def __init__(self, cfg: DictConfig):
# | --------------------------|
# < L_wall >

# define offset vector
self.offset_vec = self._offset_vec()

# relaxation configurations
if self.case.mode == "rlx":
self.special.L_wall = self.special.L
Expand All @@ -33,55 +42,83 @@ def __init__(self, cfg: DictConfig):
if self.case.r0_type == "relaxed":
self._load_only_fluid = True

def _box_size2D(self):
def _box_size2D(self, n_walls):
dx, bo = self.case.dx, self.special.box_offset
return np.array(
[self.special.L_wall + 6 * dx + bo, self.special.H_wall + 6 * dx + bo]
[
self.special.L_wall + 2 * n_walls * dx + bo,
self.special.H_wall + 2 * n_walls * dx + bo,
]
)

def _box_size3D(self):
def _box_size3D(self, n_walls):
dx, bo = self.case.dx, self.box_offset
sp = self.special
return np.array([sp.L_wall + 6 * dx + bo, sp.H_wall + 6 * dx + bo, sp.W])
return np.array(
[sp.L_wall + 2 * n_walls * dx + bo, sp.H_wall + 2 * n_walls * dx + bo, sp.W]
)

def _init_walls_2d(self, dx, n_walls):
sp = self.special
rw = pos_box_2d(np.array([sp.L_wall, sp.H_wall]), dx, n_walls)
return rw

def _init_walls_3d(self, dx, n_walls):
sp = self.special
rw = pos_box_3d(np.array([sp.L_wall, sp.H_wall, 1.0]), dx, n_walls)
return rw

def _init_pos2D(self, box_size, dx):
def _init_pos2D(self, box_size, dx, n_walls):
sp = self.special

# initialize fluid phase
if self.case.r0_type == "cartesian":
r_fluid = 3 * dx + pos_init_cartesian_2d(np.array([sp.L, sp.H]), dx)
r_f = n_walls * dx + pos_init_cartesian_2d(np.array([sp.L, sp.H]), dx)
else:
r_fluid = self._get_relaxed_r0(None, dx)
r_f = self._get_relaxed_r0(None, dx)

walls = pos_box_2d(sp.L_wall, sp.H_wall, dx)
res = np.concatenate([walls, r_fluid])
return res
# initialize walls
r_w = self._init_walls_2d(dx, n_walls)

def _init_pos3D(self, box_size, dx):
# cartesian coordinates in z
Lz = box_size[2]
zs = np.arange(0, Lz, dx) + 0.5 * dx
# set tags
tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int)
tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int)

# extend 2D points to 3D
xy = self._init_pos2D(box_size, dx)
xy_ext = np.hstack([xy, np.ones((len(xy), 1))])
r = np.concatenate([r_w, r_f])
tag = np.concatenate([tag_w, tag_f])
return r, tag

r_xyz = np.vstack([xy_ext * [1, 1, z] for z in zs])
return r_xyz
def _init_pos3D(self, box_size, dx, n_walls):
# TODO: not validated yet
sp = self.special

def _tag2D(self, r):
dx3 = 3 * self.case.dx
mask_left = jnp.where(r[:, 0] < dx3, True, False)
mask_bottom = jnp.where(r[:, 1] < dx3, True, False)
mask_right = jnp.where(r[:, 0] > self.special.L_wall + dx3, True, False)
mask_top = jnp.where(r[:, 1] > self.special.H_wall + dx3, True, False)
# initialize fluid phase
if self.case.r0_type == "cartesian":
r_f = np.array([1.0, 1.0, 0.0]) * n_walls * dx + pos_init_cartesian_3d(
np.array([sp.L, sp.H, 1.0]), dx
)
else:
r_f = self._get_relaxed_r0(None, dx)

mask_wall = mask_left + mask_bottom + mask_right + mask_top
# initialize walls
r_w = self._init_walls_3d(dx, n_walls)

tag = jnp.full(len(r), Tag.FLUID, dtype=int)
tag = jnp.where(mask_wall, Tag.SOLID_WALL, tag)
return tag
# set tags
tag_f = jnp.full(len(r_f), Tag.FLUID, dtype=int)
tag_w = jnp.full(len(r_w), Tag.SOLID_WALL, dtype=int)

def _tag3D(self, r):
return self._tag2D(r)
r = np.concatenate([r_w, r_f])
tag = np.concatenate([tag_w, tag_f])

return r, tag

def _offset_vec(self):
dim = self.cfg.case.dim
if dim == 2:
res = np.ones(dim) * self.cfg.solver.n_walls * self.cfg.case.dx
elif dim == 3:
res = np.array([1.0, 1.0, 0.0]) * self.cfg.solver.n_walls * self.cfg.case.dx
return res

def _init_velocity2D(self, r):
return jnp.zeros_like(r)
Expand Down
Loading

0 comments on commit 8952d7e

Please sign in to comment.