Skip to content

Commit

Permalink
Merge branch 'main' into couette-flow
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasErbesdobler committed Jul 18, 2024
2 parents 6626abd + 2623413 commit 7ddde8f
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 31 deletions.
13 changes: 10 additions & 3 deletions jax_sph/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
### solver
cfg.solver = OmegaConf.create({})

# Solver name. One of "SPH" (standard SPH) or "RIE" (Riemann SPH)
# Solver name. One of "SPH" (standard SPH) or "RIE" (Riemann SPH) or
# "DELTA" (Delta SPH)
cfg.solver.name = "SPH"
# Transport velocity inclusion factor [0,...,1]
cfg.solver.tvf = 0.0
Expand Down Expand Up @@ -92,6 +93,10 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
cfg.solver.heat_conduction = False
# Whether to apply boundaty conditions
cfg.solver.is_bc_trick = False # new
# Delta SPH density diffusion weighting factor
cfg.solver.diff_delta = 0.1
# Delta SPH acceleratin diffusion weighting factor
cfg.solver.diff_alpha = 0.01

### kernel
cfg.kernel = OmegaConf.create({})
Expand All @@ -105,8 +110,10 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
# "GK" (gaussian kernel)
# "SGK" (super gaussian kernel)
cfg.kernel.name = "QSK"
# Smoothing length factor
cfg.kernel.h_factor = 1.0 # new. Should default to 1.3 WC2K and 1.0 QSK
# Smoothing length factor, should default to
# 1.3 for WC2K, WC4K, WC6K
# 1.0 for CSK, QSK, GK, SGK
cfg.kernel.h_factor = 1.0

### equation of state
cfg.eos = OmegaConf.create({})
Expand Down
3 changes: 3 additions & 0 deletions jax_sph/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ def simulate(cfg: DictConfig):
cfg.solver.dt,
cfg.case.c_ref,
cfg.solver.eta_limiter,
cfg.solver.diff_delta,
cfg.solver.diff_alpha,
cfg.solver.name,
cfg.kernel.name,
cfg.kernel.h_factor,
cfg.solver.is_bc_trick,
cfg.solver.density_evolution,
cfg.solver.artificial_alpha,
Expand Down
236 changes: 210 additions & 26 deletions jax_sph/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,81 @@ def rho_evol_fn(rho, mass, u, grad_w_dist, i_s, j_s, dt, N, **kwargs):
return rho, drhodt


def rho_evol_fn_delta_wrapper(delta, support, c_ref, kernel_fn):
"""Density evolution according to Marrone et al. 2011."""

def rho_evol_fn_delta(
rho, mass, r_ij, d_ij, u, i_s, j_s, dt, N, fluidmask_j, **kwargs
):
# compute common quantities
def quantities_fn(r_ab, d_ab, m_a, m_b, rho_a, rho_b):
e_ab = r_ab / (d_ab + EPS)
kernel_grad = kernel_fn.grad_w(d_ab) * (e_ab)
V_a = m_a / rho_a
V_b = m_b / rho_b
return kernel_grad, V_a, V_b

kernel_grad, V_i, V_j = vmap(quantities_fn)(
r_ij, d_ij, mass[i_s], mass[j_s], rho[i_s], rho[j_s]
)

def L_fn(r_ab, kernel_grad, V_b):
return jnp.tensordot(-r_ab, kernel_grad * V_b, axes=0)

temp = vmap(L_fn)(r_ij, kernel_grad, V_j)
L_mati = jnp.linalg.inv(ops.segment_sum(temp, i_s, N))
temp = vmap(L_fn)(-r_ij, kernel_grad, V_i)
L_matj = jnp.linalg.inv(ops.segment_sum(temp, j_s, N))

def rho_grad_fn(rho_a, rho_b, L_a, kernel_grad, V_b):
return (rho_b - rho_a) * jnp.dot(L_a, kernel_grad * V_b)

temp = vmap(rho_grad_fn)(rho[i_s], rho[j_s], L_mati[i_s], kernel_grad, V_j)
rho_grad_term_i = ops.segment_sum(temp, i_s, N)
temp = vmap(rho_grad_fn)(rho[j_s], rho[i_s], L_matj[j_s], -kernel_grad, V_i)
rho_grad_term_j = ops.segment_sum(temp, i_s, N)

def rho_diff_fn(
rho_i,
rho_j,
r_ij,
d_ij,
rho_grad_term_i,
rho_grad_term_j,
fluidmask_j,
kernel_grad,
V_j,
):
rho_term = 2 * (rho_j - rho_i) * (-r_ij) / (d_ij + EPS) ** 2
psi_ij = rho_term - rho_grad_term_i - rho_grad_term_j
return jnp.dot(psi_ij, kernel_grad) * V_j * fluidmask_j

temp = vmap(rho_diff_fn)(
rho[i_s],
rho[j_s],
r_ij,
d_ij,
rho_grad_term_i[i_s],
rho_grad_term_j[j_s],
fluidmask_j,
kernel_grad,
V_j,
)
diff_term = ops.segment_sum(temp, i_s, N)

def cont_eq(u_i, u_j, kernel_grad, V_j):
return jnp.dot(u_i - u_j, kernel_grad) * V_j

temp = vmap(cont_eq)(u[i_s], u[j_s], kernel_grad, V_j)
drhodt = (
rho * ops.segment_sum(temp, i_s, N) + c_ref * delta * support * diff_term
)
rho = rho + dt * drhodt
return rho, drhodt

return rho_evol_fn_delta


def rho_evol_riemann_fn_wrapper(kernel_fn, eos, c_ref):
"""Density evolution according to Zhang et al. 2017."""

Expand Down Expand Up @@ -181,6 +256,63 @@ def acceleration_standard_fn(
return acceleration_standard_fn


def acceleration_delta_fn_wrapper(kernel_fn, alpha, support, c_ref, rho_ref):
"""Standard SPH acceleration according to Adami et al. 2012., and
Delta SPH relaxation according to Marrone et al. 2011."""

def acceleration_delta_fn(
r_ij,
d_ij,
rho_i,
rho_j,
u_i,
u_j,
v_i,
v_j,
m_i,
m_j,
eta_i,
eta_j,
p_i,
p_j,
fluidmask_j,
):
# (Eq. 6) - inter-particle-averaged shear viscosity (harmonic mean)
eta_ij = 2 * eta_i * eta_j / (eta_i + eta_j + EPS)
# (Eq. 7) - density-weighted pressure (weighted arithmetic mean)
p_ij = (rho_j * p_i + rho_i * p_j) / (rho_i + rho_j)

# compute the common prefactor `c`
weighted_volume = ((m_i / rho_i) ** 2 + (m_j / rho_j) ** 2) / m_i
kernel_grad = kernel_fn.grad_w(d_ij)
c = weighted_volume * kernel_grad / (d_ij + EPS)

# (Eq. 8): \boldsymbol{e}_{ij} is computed as r_ij/d_ij here.
_A = (tvf_stress_fn(rho_i, u_i, v_i) + tvf_stress_fn(rho_j, u_j, v_j)) / 2
_u_ij = u_i - u_j
a_eq_8 = c * (-p_ij * r_ij + jnp.dot(_A, r_ij) + eta_ij * _u_ij)

e_ij = r_ij / (d_ij + EPS)
kernel_grad = kernel_grad * (e_ij)
V_j = m_j / rho_j
pi_ij = jnp.dot((u_j - u_i), (-r_ij)) / (d_ij + EPS) ** 2
acceleration_diff = (
V_j
* pi_ij
* kernel_grad
* alpha
* support
* c_ref
* rho_ref
/ rho_i
* fluidmask_j
)

return a_eq_8 + acceleration_diff

return acceleration_delta_fn


def acceleration_riemann_fn_wrapper(kernel_fn, eos, beta_fn, eta_limiter):
"""Riemann solver acceleration according to Zhang et al. 2017."""

Expand Down Expand Up @@ -491,8 +623,11 @@ def __init__(
dt: float,
c_ref: float,
eta_limiter: float = 3,
diff_delta: float = 0.02,
diff_alpha: float = 0.1,
solver: str = "SPH",
kernel: str = "QSK",
h_fac: float = 1.0,
is_bc_trick: bool = False,
is_rho_evol: bool = False,
artificial_alpha: float = 0.0,
Expand All @@ -508,25 +643,27 @@ def __init__(
self.is_rho_renorm = is_rho_renorm
self.dt = dt
self.eos = eos
self.c_ref = c_ref
self.diff_delta = diff_delta
self.diff_alpha = diff_alpha
self.artificial_alpha = artificial_alpha
self.is_heat_conduction = is_heat_conduction

_beta_fn = limiter_fn_wrapper(eta_limiter, c_ref)
match kernel:
case "CSK":
self._kernel_fn = CubicKernel(h=dx, dim=dim)
case "QSK":
self._kernel_fn = QuinticKernel(h=dx, dim=dim)
case "WC2K":
self._kernel_fn = WendlandC2Kernel(h=1.3 * dx, dim=dim)
case "WC4K":
self._kernel_fn = WendlandC4Kernel(h=1.3 * dx, dim=dim)
case "WC6K":
self._kernel_fn = WendlandC6Kernel(h=1.3 * dx, dim=dim)
case "GK":
self._kernel_fn = GaussianKernel(h=dx, dim=dim)
case "SGK":
self._kernel_fn = SuperGaussianKernel(h=dx, dim=dim)
if kernel == "CSK":
self._kernel_fn = CubicKernel(h=h_fac * dx, dim=dim)
elif kernel == "QSK":
self._kernel_fn = QuinticKernel(h=h_fac * dx, dim=dim)
elif kernel == "WC2K":
self._kernel_fn = WendlandC2Kernel(h=h_fac * dx, dim=dim)
elif kernel == "WC4K":
self._kernel_fn = WendlandC4Kernel(h=h_fac * dx, dim=dim)
elif kernel == "WC6K":
self._kernel_fn = WendlandC6Kernel(h=h_fac * dx, dim=dim)
elif kernel == "GK":
self._kernel_fn = GaussianKernel(h=h_fac * dx, dim=dim)
elif kernel == "SGK":
self._kernel_fn = SuperGaussianKernel(h=h_fac * dx, dim=dim)

self._gwbc_fn = gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos)
(
Expand All @@ -535,17 +672,31 @@ def __init__(
self._heat_bc,
) = gwbc_fn_riemann_wrapper(is_free_slip, is_heat_conduction)
self._acceleration_tvf_fn = acceleration_tvf_fn_wrapper(self._kernel_fn)
self._acceleration_riemann_fn = acceleration_riemann_fn_wrapper(
self._kernel_fn, eos, _beta_fn, eta_limiter
)
self._acceleration_fn = acceleration_standard_fn_wrapper(self._kernel_fn)

if solver == "SPH":
self._acceleration_fn = acceleration_standard_fn_wrapper(self._kernel_fn)
self._rho_evol_fn = rho_evol_fn
elif solver == "RIE":
self._acceleration_fn = acceleration_riemann_fn_wrapper(
self._kernel_fn, eos, _beta_fn, eta_limiter
)
self._rho_evol_fn = rho_evol_riemann_fn_wrapper(self._kernel_fn, eos, c_ref)
elif solver == "DELTA":
self._acceleration_fn = acceleration_delta_fn_wrapper(
self._kernel_fn,
self.diff_alpha,
h_fac * dx,
self.c_ref,
self.eos.rho_ref,
)
self._rho_evol_fn = rho_evol_fn_delta_wrapper(
self.diff_delta, h_fac * dx, self.c_ref, self._kernel_fn
)

self._artificial_viscosity_fn = artificial_viscosity_fn_wrapper(
dx, artificial_alpha
)
self._wall_phi_vec = wall_phi_vec_wrapper(self._kernel_fn)
self._rho_evol_riemann_fn = rho_evol_riemann_fn_wrapper(
self._kernel_fn, eos, c_ref
)
self._temperature_derivative = temperature_derivative_wrapper(self._kernel_fn)

def forward_wrapper(self):
Expand Down Expand Up @@ -597,14 +748,29 @@ def forward(state, neighbors):
# update evolution

if self.is_rho_evol and (self.solver == "SPH"):
rho, drhodt = rho_evol_fn(
rho, drhodt = self._rho_evol_fn(
rho, mass, u, grad_w_dist, i_s, j_s, self.dt, N
)

if self.is_rho_renorm:
rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N)
elif self.is_rho_evol and (self.solver == "DELTA"):
rho, drhodt = self._rho_evol_fn(
rho,
mass,
dr_i_j,
dist,
u,
i_s,
j_s,
self.dt,
N,
fluid_mask[j_s],
)
if self.is_rho_renorm:
rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N)
elif self.is_rho_evol and (self.solver == "RIE"):
temp = vmap(self._rho_evol_riemann_fn)(
temp = vmap(self._rho_evol_fn)(
e_s,
rho[i_s],
rho[j_s],
Expand Down Expand Up @@ -637,7 +803,7 @@ def forward(state, neighbors):

##### Apply BC trick

if self.is_bc_trick and (self.solver == "SPH"):
if self.is_bc_trick and (self.solver == "SPH" or self.solver == "DELTA"):
p, rho, u, v, temperature = self._gwbc_fn(
temperature,
rho,
Expand Down Expand Up @@ -702,8 +868,26 @@ def forward(state, neighbors):
p[i_s],
p[j_s],
)
elif self.solver == "DELTA":
out = vmap(self._acceleration_fn)(
dr_i_j,
dist,
rho[i_s],
rho[j_s],
u[i_s],
u[j_s],
v[i_s],
v[j_s],
mass[i_s],
mass[j_s],
eta[i_s],
eta[j_s],
p[i_s],
p[j_s],
fluid_mask[j_s],
)
elif self.solver == "RIE":
out = vmap(self._acceleration_riemann_fn)(
out = vmap(self._acceleration_fn)(
e_s,
dr_i_j,
dist,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jax-sph"
version = "0.0.1"
version = "0.1.0"
description = "JAX-SPH: Smoothed Particle Hydrodynamics in JAX"
authors = ["Artur Toshev <[email protected]>",]
maintainers = ["Artur Toshev <[email protected]>",]
Expand Down
4 changes: 3 additions & 1 deletion tests/test_pf2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def get_solution(data_path, t_dimless, y_axis):
return solutions


@pytest.mark.parametrize("tvf, solver", [(0.0, "SPH"), (1.0, "SPH"), (0.0, "RIE")])
@pytest.mark.parametrize(
"tvf, solver", [(0.0, "SPH"), (1.0, "SPH"), (0.0, "RIE"), (0.0, "DELTA")]
)
def test_pf2d(tvf, solver, tmp_path, setup_simulation):
"""Test whether the poiseuille flow simulation matches the analytical solution"""
y_axis, t_dimless, ref_solutions = setup_simulation
Expand Down
5 changes: 5 additions & 0 deletions validation/db2d.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
#!/bin/bash

####### 2D Dam break
# Riemann SPH
python main.py config=cases/db.yaml case.u_ref=2 case.viscosity=0.0 solver.name=RIE solver.t_end=7.5 solver.dt=0.0002 solver.free_slip=True solver.artificial_alpha=0.0 solver.eta_limiter=3 io.write_every=50 io.data_path=data_valid/db2d_Riemann/

# Delta SPH
python main.py config=cases/db.yaml case.u_ref=1.95 case.special.H_wall=5.366 case.viscosity=0.0 eos.gamma=7.0 solver.name=DELTA solver.t_end=7.5 solver.dt=0.0002 solver.free_slip=True solver.artificial_alpha=0.0 io.write_every=50 io.data_path=data_valid/db2d_Delta/

# Run validation script
python validation/validate.py --case=2D_DB --src_dir=data_valid/db2d_Riemann/
python validation/validate.py --case=2D_DB --src_dir=data_valid/db2d_Delta/
2 changes: 2 additions & 0 deletions validation/pf2d.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
# Generate data
python main.py config=cases/pf.yaml solver.tvf=1.0 io.data_path=data_valid/pf2d_tvf/
python main.py config=cases/pf.yaml solver.tvf=0.0 io.data_path=data_valid/pf2d_notvf/
python main.py config=cases/pf.yaml solver.tvf=0.0 solver.name=DELTA io.data_path=data_valid/pf2d_delta/
python main.py config=cases/pf.yaml solver.tvf=0.0 solver.name=RIE solver.density_evolution=True io.data_path=data_valid/pf2d_Rie/

# Run validation script
python validation/validate.py --case=2D_PF --src_dir=data_valid/pf2d_tvf/
python validation/validate.py --case=2D_PF --src_dir=data_valid/pf2d_notvf/
python validation/validate.py --case=2D_PF --src_dir=data_valid/pf2d_delta/
python validation/validate.py --case=2D_PF --src_dir=data_valid/pf2d_Rie/

0 comments on commit 7ddde8f

Please sign in to comment.