Skip to content

Commit

Permalink
review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasErbesdobler committed Jul 18, 2024
1 parent 2d3b4cd commit 0713293
Showing 1 changed file with 42 additions and 40 deletions.
82 changes: 42 additions & 40 deletions jax_sph/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def L_fn(r_ab, kernel_grad, V_b):

temp = vmap(L_fn)(r_ij, kernel_grad, V_j)
L_mati = jnp.linalg.inv(ops.segment_sum(temp, i_s, N))
# TODO: check whether this is the same
temp = vmap(L_fn)(-r_ij, kernel_grad, V_i)
L_matj = jnp.linalg.inv(ops.segment_sum(temp, j_s, N))

Expand Down Expand Up @@ -258,8 +257,8 @@ def acceleration_standard_fn(


def acceleration_delta_fn_wrapper(kernel_fn, alpha, support, c_ref, rho_ref):
"""Standard SPH acceleration according to Adami et al. 2012."""
"""Delta SPH acceleration according to Marrone et al. 2011."""
"""Standard SPH acceleration according to Adami et al. 2012., and
Delta SPH relaxation according to Marrone et al. 2011."""

def acceleration_delta_fn(
r_ij,
Expand Down Expand Up @@ -651,21 +650,20 @@ def __init__(
self.is_heat_conduction = is_heat_conduction

_beta_fn = limiter_fn_wrapper(eta_limiter, c_ref)
match kernel:
case "CSK":
self._kernel_fn = CubicKernel(h=h_fac * dx, dim=dim)
case "QSK":
self._kernel_fn = QuinticKernel(h=h_fac * dx, dim=dim)
case "WC2K":
self._kernel_fn = WendlandC2Kernel(h=h_fac * dx, dim=dim)
case "WC4K":
self._kernel_fn = WendlandC4Kernel(h=h_fac * dx, dim=dim)
case "WC6K":
self._kernel_fn = WendlandC6Kernel(h=h_fac * dx, dim=dim)
case "GK":
self._kernel_fn = GaussianKernel(h=h_fac * dx, dim=dim)
case "SGK":
self._kernel_fn = SuperGaussianKernel(h=h_fac * dx, dim=dim)
if kernel == "CSK":
self._kernel_fn = CubicKernel(h=h_fac * dx, dim=dim)
elif kernel == "QSK":
self._kernel_fn = QuinticKernel(h=h_fac * dx, dim=dim)
elif kernel == "WC2K":
self._kernel_fn = WendlandC2Kernel(h=h_fac * dx, dim=dim)
elif kernel == "WC4K":
self._kernel_fn = WendlandC4Kernel(h=h_fac * dx, dim=dim)
elif kernel == "WC6K":
self._kernel_fn = WendlandC6Kernel(h=h_fac * dx, dim=dim)
elif kernel == "GK":
self._kernel_fn = GaussianKernel(h=h_fac * dx, dim=dim)
elif kernel == "SGK":
self._kernel_fn = SuperGaussianKernel(h=h_fac * dx, dim=dim)

self._gwbc_fn = gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos)
(
Expand All @@ -674,27 +672,31 @@ def __init__(
self._heat_bc,
) = gwbc_fn_riemann_wrapper(is_free_slip, is_heat_conduction)
self._acceleration_tvf_fn = acceleration_tvf_fn_wrapper(self._kernel_fn)
self._acceleration_riemann_fn = acceleration_riemann_fn_wrapper(
self._kernel_fn, eos, _beta_fn, eta_limiter
)
self._acceleration_fn = acceleration_standard_fn_wrapper(self._kernel_fn)
self._acceleration_delta_fn = acceleration_delta_fn_wrapper(
self._kernel_fn,
self.diff_alpha,
h_fac * dx,
self.c_ref,
self.eos.rho_ref,
)

if solver == "SPH":
self._acceleration_fn = acceleration_standard_fn_wrapper(self._kernel_fn)
self._rho_evol_fn = rho_evol_fn
elif solver == "RIE":
self._acceleration_fn = acceleration_riemann_fn_wrapper(
self._kernel_fn, eos, _beta_fn, eta_limiter
)
self._rho_evol_fn = rho_evol_riemann_fn_wrapper(self._kernel_fn, eos, c_ref)
elif solver == "DELTA":
self._acceleration_fn = acceleration_delta_fn_wrapper(
self._kernel_fn,
self.diff_alpha,
h_fac * dx,
self.c_ref,
self.eos.rho_ref,
)
self._rho_evol_fn = rho_evol_fn_delta_wrapper(
self.diff_delta, h_fac * dx, self.c_ref, self._kernel_fn
)

self._artificial_viscosity_fn = artificial_viscosity_fn_wrapper(
dx, artificial_alpha
)
self._wall_phi_vec = wall_phi_vec_wrapper(self._kernel_fn)
self._rho_evol_riemann_fn = rho_evol_riemann_fn_wrapper(
self._kernel_fn, eos, c_ref
)
self._rho_evol_detla_fn = rho_evol_fn_delta_wrapper(
self.diff_delta, h_fac * dx, self.c_ref, self._kernel_fn
)
self._temperature_derivative = temperature_derivative_wrapper(self._kernel_fn)

def forward_wrapper(self):
Expand Down Expand Up @@ -746,14 +748,14 @@ def forward(state, neighbors):
# update evolution

if self.is_rho_evol and (self.solver == "SPH"):
rho, drhodt = rho_evol_fn(
rho, drhodt = self._rho_evol_fn(
rho, mass, u, grad_w_dist, i_s, j_s, self.dt, N
)

if self.is_rho_renorm:
rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N)
elif self.is_rho_evol and (self.solver == "DELTA"):
rho, drhodt = self._rho_evol_detla_fn(
rho, drhodt = self._rho_evol_fn(
rho,
mass,
dr_i_j,
Expand All @@ -768,7 +770,7 @@ def forward(state, neighbors):
if self.is_rho_renorm:
rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N)
elif self.is_rho_evol and (self.solver == "RIE"):
temp = vmap(self._rho_evol_riemann_fn)(
temp = vmap(self._rho_evol_fn)(
e_s,
rho[i_s],
rho[j_s],
Expand Down Expand Up @@ -867,7 +869,7 @@ def forward(state, neighbors):
p[j_s],
)
elif self.solver == "DELTA":
out = vmap(self._acceleration_delta_fn)(
out = vmap(self._acceleration_fn)(
dr_i_j,
dist,
rho[i_s],
Expand All @@ -885,7 +887,7 @@ def forward(state, neighbors):
fluid_mask[j_s],
)
elif self.solver == "RIE":
out = vmap(self._acceleration_riemann_fn)(
out = vmap(self._acceleration_fn)(
e_s,
dr_i_j,
dist,
Expand Down

0 comments on commit 0713293

Please sign in to comment.