From 2d3b4cd9c5e5cfd3dcccbcc77cc63ba698ff12a6 Mon Sep 17 00:00:00 2001 From: Jonas Erbesdobler Date: Thu, 27 Jun 2024 21:42:59 +0200 Subject: [PATCH 1/3] working delta-sph and fix usage of "h_factor" --- jax_sph/case_setup.py | 2 +- jax_sph/defaults.py | 13 ++- jax_sph/simulate.py | 3 + jax_sph/solver.py | 198 ++++++++++++++++++++++++++++++++++++++++-- tests/test_pf2d.py | 4 +- validation/db2d.sh | 5 ++ validation/pf2d.sh | 2 + 7 files changed, 214 insertions(+), 13 deletions(-) diff --git a/jax_sph/case_setup.py b/jax_sph/case_setup.py index 46d2d01..a05a563 100644 --- a/jax_sph/case_setup.py +++ b/jax_sph/case_setup.py @@ -187,7 +187,7 @@ def initialize(self): if k not in cfg.case.state0_keys: continue assert k in _state, ValueError(f"Key {k} not found in state0 file.") - mask, _mask = state["tag"]==Tag.FLUID, _state["tag"]==Tag.FLUID + mask, _mask = state["tag"] == Tag.FLUID, _state["tag"] == Tag.FLUID assert state[k][mask].shape == _state[k][_mask].shape, ValueError( f"Shape mismatch for key {k} in state0 file." ) diff --git a/jax_sph/defaults.py b/jax_sph/defaults.py index 4574c2b..121c693 100644 --- a/jax_sph/defaults.py +++ b/jax_sph/defaults.py @@ -64,7 +64,8 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig: ### solver cfg.solver = OmegaConf.create({}) - # Solver name. One of "SPH" (standard SPH) or "RIE" (Riemann SPH) + # Solver name. One of "SPH" (standard SPH) or "RIE" (Riemann SPH) or + # "DELTA" (Delta SPH) cfg.solver.name = "SPH" # Transport velocity inclusion factor [0,...,1] cfg.solver.tvf = 0.0 @@ -92,6 +93,10 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig: cfg.solver.heat_conduction = False # Whether to apply boundaty conditions cfg.solver.is_bc_trick = False # new + # Delta SPH density diffusion weighting factor + cfg.solver.diff_delta = 0.1 + # Delta SPH acceleratin diffusion weighting factor + cfg.solver.diff_alpha = 0.01 ### kernel cfg.kernel = OmegaConf.create({}) @@ -105,8 +110,10 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig: # "GK" (gaussian kernel) # "SGK" (super gaussian kernel) cfg.kernel.name = "QSK" - # Smoothing length factor - cfg.kernel.h_factor = 1.0 # new. Should default to 1.3 WC2K and 1.0 QSK + # Smoothing length factor, should default to + # 1.3 for WC2K, WC4K, WC6K + # 1.0 for CSK, QSK, GK, SGK + cfg.kernel.h_factor = 1.0 ### equation of state cfg.eos = OmegaConf.create({}) diff --git a/jax_sph/simulate.py b/jax_sph/simulate.py index 5395666..3f648f4 100644 --- a/jax_sph/simulate.py +++ b/jax_sph/simulate.py @@ -55,8 +55,11 @@ def simulate(cfg: DictConfig): cfg.solver.dt, cfg.case.c_ref, cfg.solver.eta_limiter, + cfg.solver.diff_delta, + cfg.solver.diff_alpha, cfg.solver.name, cfg.kernel.name, + cfg.kernel.h_factor, cfg.solver.is_bc_trick, cfg.solver.density_evolution, cfg.solver.artificial_alpha, diff --git a/jax_sph/solver.py b/jax_sph/solver.py index ba434f1..3a892d6 100644 --- a/jax_sph/solver.py +++ b/jax_sph/solver.py @@ -30,6 +30,82 @@ def rho_evol_fn(rho, mass, u, grad_w_dist, i_s, j_s, dt, N, **kwargs): return rho, drhodt +def rho_evol_fn_delta_wrapper(delta, support, c_ref, kernel_fn): + """Density evolution according to Marrone et al. 2011.""" + + def rho_evol_fn_delta( + rho, mass, r_ij, d_ij, u, i_s, j_s, dt, N, fluidmask_j, **kwargs + ): + # compute common quantities + def quantities_fn(r_ab, d_ab, m_a, m_b, rho_a, rho_b): + e_ab = r_ab / (d_ab + EPS) + kernel_grad = kernel_fn.grad_w(d_ab) * (e_ab) + V_a = m_a / rho_a + V_b = m_b / rho_b + return kernel_grad, V_a, V_b + + kernel_grad, V_i, V_j = vmap(quantities_fn)( + r_ij, d_ij, mass[i_s], mass[j_s], rho[i_s], rho[j_s] + ) + + def L_fn(r_ab, kernel_grad, V_b): + return jnp.tensordot(-r_ab, kernel_grad * V_b, axes=0) + + temp = vmap(L_fn)(r_ij, kernel_grad, V_j) + L_mati = jnp.linalg.inv(ops.segment_sum(temp, i_s, N)) + # TODO: check whether this is the same + temp = vmap(L_fn)(-r_ij, kernel_grad, V_i) + L_matj = jnp.linalg.inv(ops.segment_sum(temp, j_s, N)) + + def rho_grad_fn(rho_a, rho_b, L_a, kernel_grad, V_b): + return (rho_b - rho_a) * jnp.dot(L_a, kernel_grad * V_b) + + temp = vmap(rho_grad_fn)(rho[i_s], rho[j_s], L_mati[i_s], kernel_grad, V_j) + rho_grad_term_i = ops.segment_sum(temp, i_s, N) + temp = vmap(rho_grad_fn)(rho[j_s], rho[i_s], L_matj[j_s], -kernel_grad, V_i) + rho_grad_term_j = ops.segment_sum(temp, i_s, N) + + def rho_diff_fn( + rho_i, + rho_j, + r_ij, + d_ij, + rho_grad_term_i, + rho_grad_term_j, + fluidmask_j, + kernel_grad, + V_j, + ): + rho_term = 2 * (rho_j - rho_i) * (-r_ij) / (d_ij + EPS) ** 2 + psi_ij = rho_term - rho_grad_term_i - rho_grad_term_j + return jnp.dot(psi_ij, kernel_grad) * V_j * fluidmask_j + + temp = vmap(rho_diff_fn)( + rho[i_s], + rho[j_s], + r_ij, + d_ij, + rho_grad_term_i[i_s], + rho_grad_term_j[j_s], + fluidmask_j, + kernel_grad, + V_j, + ) + diff_term = ops.segment_sum(temp, i_s, N) + + def cont_eq(u_i, u_j, kernel_grad, V_j): + return jnp.dot(u_i - u_j, kernel_grad) * V_j + + temp = vmap(cont_eq)(u[i_s], u[j_s], kernel_grad, V_j) + drhodt = ( + rho * ops.segment_sum(temp, i_s, N) + c_ref * delta * support * diff_term + ) + rho = rho + dt * drhodt + return rho, drhodt + + return rho_evol_fn_delta + + def rho_evol_riemann_fn_wrapper(kernel_fn, eos, c_ref): """Density evolution according to Zhang et al. 2017.""" @@ -181,6 +257,63 @@ def acceleration_standard_fn( return acceleration_standard_fn +def acceleration_delta_fn_wrapper(kernel_fn, alpha, support, c_ref, rho_ref): + """Standard SPH acceleration according to Adami et al. 2012.""" + """Delta SPH acceleration according to Marrone et al. 2011.""" + + def acceleration_delta_fn( + r_ij, + d_ij, + rho_i, + rho_j, + u_i, + u_j, + v_i, + v_j, + m_i, + m_j, + eta_i, + eta_j, + p_i, + p_j, + fluidmask_j, + ): + # (Eq. 6) - inter-particle-averaged shear viscosity (harmonic mean) + eta_ij = 2 * eta_i * eta_j / (eta_i + eta_j + EPS) + # (Eq. 7) - density-weighted pressure (weighted arithmetic mean) + p_ij = (rho_j * p_i + rho_i * p_j) / (rho_i + rho_j) + + # compute the common prefactor `c` + weighted_volume = ((m_i / rho_i) ** 2 + (m_j / rho_j) ** 2) / m_i + kernel_grad = kernel_fn.grad_w(d_ij) + c = weighted_volume * kernel_grad / (d_ij + EPS) + + # (Eq. 8): \boldsymbol{e}_{ij} is computed as r_ij/d_ij here. + _A = (tvf_stress_fn(rho_i, u_i, v_i) + tvf_stress_fn(rho_j, u_j, v_j)) / 2 + _u_ij = u_i - u_j + a_eq_8 = c * (-p_ij * r_ij + jnp.dot(_A, r_ij) + eta_ij * _u_ij) + + e_ij = r_ij / (d_ij + EPS) + kernel_grad = kernel_grad * (e_ij) + V_j = m_j / rho_j + pi_ij = jnp.dot((u_j - u_i), (-r_ij)) / (d_ij + EPS) ** 2 + acceleration_diff = ( + V_j + * pi_ij + * kernel_grad + * alpha + * support + * c_ref + * rho_ref + / rho_i + * fluidmask_j + ) + + return a_eq_8 + acceleration_diff + + return acceleration_delta_fn + + def acceleration_riemann_fn_wrapper(kernel_fn, eos, beta_fn, eta_limiter): """Riemann solver acceleration according to Zhang et al. 2017.""" @@ -491,8 +624,11 @@ def __init__( dt: float, c_ref: float, eta_limiter: float = 3, + diff_delta: float = 0.02, + diff_alpha: float = 0.1, solver: str = "SPH", kernel: str = "QSK", + h_fac: float = 1.0, is_bc_trick: bool = False, is_rho_evol: bool = False, artificial_alpha: float = 0.0, @@ -508,25 +644,28 @@ def __init__( self.is_rho_renorm = is_rho_renorm self.dt = dt self.eos = eos + self.c_ref = c_ref + self.diff_delta = diff_delta + self.diff_alpha = diff_alpha self.artificial_alpha = artificial_alpha self.is_heat_conduction = is_heat_conduction _beta_fn = limiter_fn_wrapper(eta_limiter, c_ref) match kernel: case "CSK": - self._kernel_fn = CubicKernel(h=dx, dim=dim) + self._kernel_fn = CubicKernel(h=h_fac * dx, dim=dim) case "QSK": - self._kernel_fn = QuinticKernel(h=dx, dim=dim) + self._kernel_fn = QuinticKernel(h=h_fac * dx, dim=dim) case "WC2K": - self._kernel_fn = WendlandC2Kernel(h=1.3 * dx, dim=dim) + self._kernel_fn = WendlandC2Kernel(h=h_fac * dx, dim=dim) case "WC4K": - self._kernel_fn = WendlandC4Kernel(h=1.3 * dx, dim=dim) + self._kernel_fn = WendlandC4Kernel(h=h_fac * dx, dim=dim) case "WC6K": - self._kernel_fn = WendlandC6Kernel(h=1.3 * dx, dim=dim) + self._kernel_fn = WendlandC6Kernel(h=h_fac * dx, dim=dim) case "GK": - self._kernel_fn = GaussianKernel(h=dx, dim=dim) + self._kernel_fn = GaussianKernel(h=h_fac * dx, dim=dim) case "SGK": - self._kernel_fn = SuperGaussianKernel(h=dx, dim=dim) + self._kernel_fn = SuperGaussianKernel(h=h_fac * dx, dim=dim) self._gwbc_fn = gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos) ( @@ -539,6 +678,13 @@ def __init__( self._kernel_fn, eos, _beta_fn, eta_limiter ) self._acceleration_fn = acceleration_standard_fn_wrapper(self._kernel_fn) + self._acceleration_delta_fn = acceleration_delta_fn_wrapper( + self._kernel_fn, + self.diff_alpha, + h_fac * dx, + self.c_ref, + self.eos.rho_ref, + ) self._artificial_viscosity_fn = artificial_viscosity_fn_wrapper( dx, artificial_alpha ) @@ -546,6 +692,9 @@ def __init__( self._rho_evol_riemann_fn = rho_evol_riemann_fn_wrapper( self._kernel_fn, eos, c_ref ) + self._rho_evol_detla_fn = rho_evol_fn_delta_wrapper( + self.diff_delta, h_fac * dx, self.c_ref, self._kernel_fn + ) self._temperature_derivative = temperature_derivative_wrapper(self._kernel_fn) def forward_wrapper(self): @@ -601,6 +750,21 @@ def forward(state, neighbors): rho, mass, u, grad_w_dist, i_s, j_s, self.dt, N ) + if self.is_rho_renorm: + rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N) + elif self.is_rho_evol and (self.solver == "DELTA"): + rho, drhodt = self._rho_evol_detla_fn( + rho, + mass, + dr_i_j, + dist, + u, + i_s, + j_s, + self.dt, + N, + fluid_mask[j_s], + ) if self.is_rho_renorm: rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N) elif self.is_rho_evol and (self.solver == "RIE"): @@ -637,7 +801,7 @@ def forward(state, neighbors): ##### Apply BC trick - if self.is_bc_trick and (self.solver == "SPH"): + if self.is_bc_trick and (self.solver == "SPH" or self.solver == "DELTA"): p, rho, u, v, temperature = self._gwbc_fn( temperature, rho, @@ -702,6 +866,24 @@ def forward(state, neighbors): p[i_s], p[j_s], ) + elif self.solver == "DELTA": + out = vmap(self._acceleration_delta_fn)( + dr_i_j, + dist, + rho[i_s], + rho[j_s], + u[i_s], + u[j_s], + v[i_s], + v[j_s], + mass[i_s], + mass[j_s], + eta[i_s], + eta[j_s], + p[i_s], + p[j_s], + fluid_mask[j_s], + ) elif self.solver == "RIE": out = vmap(self._acceleration_riemann_fn)( e_s, diff --git a/tests/test_pf2d.py b/tests/test_pf2d.py index 6c8b7cb..29be385 100644 --- a/tests/test_pf2d.py +++ b/tests/test_pf2d.py @@ -103,7 +103,9 @@ def get_solution(data_path, t_dimless, y_axis): return solutions -@pytest.mark.parametrize("tvf, solver", [(0.0, "SPH"), (1.0, "SPH"), (0.0, "RIE")]) +@pytest.mark.parametrize( + "tvf, solver", [(0.0, "SPH"), (1.0, "SPH"), (0.0, "RIE"), (0.0, "DELTA")] +) def test_pf2d(tvf, solver, tmp_path, setup_simulation): """Test whether the poiseuille flow simulation matches the analytical solution""" y_axis, t_dimless, ref_solutions = setup_simulation diff --git a/validation/db2d.sh b/validation/db2d.sh index 57b7b43..2cbf6f2 100644 --- a/validation/db2d.sh +++ b/validation/db2d.sh @@ -1,7 +1,12 @@ #!/bin/bash ####### 2D Dam break +# Riemann SPH python main.py config=cases/db.yaml case.u_ref=2 case.viscosity=0.0 solver.name=RIE solver.t_end=7.5 solver.dt=0.0002 solver.free_slip=True solver.artificial_alpha=0.0 solver.eta_limiter=3 io.write_every=50 io.data_path=data_valid/db2d_Riemann/ +# Delta SPH +python main.py config=cases/db.yaml case.u_ref=1.95 case.special.H_wall=5.366 case.viscosity=0.0 eos.gamma=7.0 solver.name=DELTA solver.t_end=7.5 solver.dt=0.0002 solver.free_slip=True solver.artificial_alpha=0.0 io.write_every=50 io.data_path=data_valid/db2d_Delta/ + # Run validation script python validation/validate.py --case=2D_DB --src_dir=data_valid/db2d_Riemann/ +python validation/validate.py --case=2D_DB --src_dir=data_valid/db2d_Delta/ diff --git a/validation/pf2d.sh b/validation/pf2d.sh index 7ffa69f..db2a7f7 100755 --- a/validation/pf2d.sh +++ b/validation/pf2d.sh @@ -6,9 +6,11 @@ # Generate data python main.py config=cases/pf.yaml solver.tvf=1.0 io.data_path=data_valid/pf2d_tvf/ python main.py config=cases/pf.yaml solver.tvf=0.0 io.data_path=data_valid/pf2d_notvf/ +python main.py config=cases/pf.yaml solver.tvf=0.0 solver.name=DELTA io.data_path=data_valid/pf2d_delta/ python main.py config=cases/pf.yaml solver.tvf=0.0 solver.name=RIE solver.density_evolution=True io.data_path=data_valid/pf2d_Rie/ # Run validation script python validation/validate.py --case=2D_PF --src_dir=data_valid/pf2d_tvf/ python validation/validate.py --case=2D_PF --src_dir=data_valid/pf2d_notvf/ +python validation/validate.py --case=2D_PF --src_dir=data_valid/pf2d_delta/ python validation/validate.py --case=2D_PF --src_dir=data_valid/pf2d_Rie/ From 0561f513255b0f2f969f6df1f61b4012a455397d Mon Sep 17 00:00:00 2001 From: Artur Toshev <45920489+arturtoshev@users.noreply.github.com> Date: Mon, 8 Jul 2024 01:01:31 +0200 Subject: [PATCH 2/3] bump version to 0.1.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7760f4e..8bd9485 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "jax-sph" -version = "0.0.1" +version = "0.1.0" description = "JAX-SPH: Smoothed Particle Hydrodynamics in JAX" authors = ["Artur Toshev ",] maintainers = ["Artur Toshev ",] From 071329334db777cd95850bc6892602f4748f71b1 Mon Sep 17 00:00:00 2001 From: Jonas Erbesdobler Date: Thu, 18 Jul 2024 19:51:28 +0200 Subject: [PATCH 3/3] review fixes --- jax_sph/solver.py | 82 ++++++++++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 40 deletions(-) diff --git a/jax_sph/solver.py b/jax_sph/solver.py index 3a892d6..fc4dfba 100644 --- a/jax_sph/solver.py +++ b/jax_sph/solver.py @@ -53,7 +53,6 @@ def L_fn(r_ab, kernel_grad, V_b): temp = vmap(L_fn)(r_ij, kernel_grad, V_j) L_mati = jnp.linalg.inv(ops.segment_sum(temp, i_s, N)) - # TODO: check whether this is the same temp = vmap(L_fn)(-r_ij, kernel_grad, V_i) L_matj = jnp.linalg.inv(ops.segment_sum(temp, j_s, N)) @@ -258,8 +257,8 @@ def acceleration_standard_fn( def acceleration_delta_fn_wrapper(kernel_fn, alpha, support, c_ref, rho_ref): - """Standard SPH acceleration according to Adami et al. 2012.""" - """Delta SPH acceleration according to Marrone et al. 2011.""" + """Standard SPH acceleration according to Adami et al. 2012., and + Delta SPH relaxation according to Marrone et al. 2011.""" def acceleration_delta_fn( r_ij, @@ -651,21 +650,20 @@ def __init__( self.is_heat_conduction = is_heat_conduction _beta_fn = limiter_fn_wrapper(eta_limiter, c_ref) - match kernel: - case "CSK": - self._kernel_fn = CubicKernel(h=h_fac * dx, dim=dim) - case "QSK": - self._kernel_fn = QuinticKernel(h=h_fac * dx, dim=dim) - case "WC2K": - self._kernel_fn = WendlandC2Kernel(h=h_fac * dx, dim=dim) - case "WC4K": - self._kernel_fn = WendlandC4Kernel(h=h_fac * dx, dim=dim) - case "WC6K": - self._kernel_fn = WendlandC6Kernel(h=h_fac * dx, dim=dim) - case "GK": - self._kernel_fn = GaussianKernel(h=h_fac * dx, dim=dim) - case "SGK": - self._kernel_fn = SuperGaussianKernel(h=h_fac * dx, dim=dim) + if kernel == "CSK": + self._kernel_fn = CubicKernel(h=h_fac * dx, dim=dim) + elif kernel == "QSK": + self._kernel_fn = QuinticKernel(h=h_fac * dx, dim=dim) + elif kernel == "WC2K": + self._kernel_fn = WendlandC2Kernel(h=h_fac * dx, dim=dim) + elif kernel == "WC4K": + self._kernel_fn = WendlandC4Kernel(h=h_fac * dx, dim=dim) + elif kernel == "WC6K": + self._kernel_fn = WendlandC6Kernel(h=h_fac * dx, dim=dim) + elif kernel == "GK": + self._kernel_fn = GaussianKernel(h=h_fac * dx, dim=dim) + elif kernel == "SGK": + self._kernel_fn = SuperGaussianKernel(h=h_fac * dx, dim=dim) self._gwbc_fn = gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos) ( @@ -674,27 +672,31 @@ def __init__( self._heat_bc, ) = gwbc_fn_riemann_wrapper(is_free_slip, is_heat_conduction) self._acceleration_tvf_fn = acceleration_tvf_fn_wrapper(self._kernel_fn) - self._acceleration_riemann_fn = acceleration_riemann_fn_wrapper( - self._kernel_fn, eos, _beta_fn, eta_limiter - ) - self._acceleration_fn = acceleration_standard_fn_wrapper(self._kernel_fn) - self._acceleration_delta_fn = acceleration_delta_fn_wrapper( - self._kernel_fn, - self.diff_alpha, - h_fac * dx, - self.c_ref, - self.eos.rho_ref, - ) + + if solver == "SPH": + self._acceleration_fn = acceleration_standard_fn_wrapper(self._kernel_fn) + self._rho_evol_fn = rho_evol_fn + elif solver == "RIE": + self._acceleration_fn = acceleration_riemann_fn_wrapper( + self._kernel_fn, eos, _beta_fn, eta_limiter + ) + self._rho_evol_fn = rho_evol_riemann_fn_wrapper(self._kernel_fn, eos, c_ref) + elif solver == "DELTA": + self._acceleration_fn = acceleration_delta_fn_wrapper( + self._kernel_fn, + self.diff_alpha, + h_fac * dx, + self.c_ref, + self.eos.rho_ref, + ) + self._rho_evol_fn = rho_evol_fn_delta_wrapper( + self.diff_delta, h_fac * dx, self.c_ref, self._kernel_fn + ) + self._artificial_viscosity_fn = artificial_viscosity_fn_wrapper( dx, artificial_alpha ) self._wall_phi_vec = wall_phi_vec_wrapper(self._kernel_fn) - self._rho_evol_riemann_fn = rho_evol_riemann_fn_wrapper( - self._kernel_fn, eos, c_ref - ) - self._rho_evol_detla_fn = rho_evol_fn_delta_wrapper( - self.diff_delta, h_fac * dx, self.c_ref, self._kernel_fn - ) self._temperature_derivative = temperature_derivative_wrapper(self._kernel_fn) def forward_wrapper(self): @@ -746,14 +748,14 @@ def forward(state, neighbors): # update evolution if self.is_rho_evol and (self.solver == "SPH"): - rho, drhodt = rho_evol_fn( + rho, drhodt = self._rho_evol_fn( rho, mass, u, grad_w_dist, i_s, j_s, self.dt, N ) if self.is_rho_renorm: rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N) elif self.is_rho_evol and (self.solver == "DELTA"): - rho, drhodt = self._rho_evol_detla_fn( + rho, drhodt = self._rho_evol_fn( rho, mass, dr_i_j, @@ -768,7 +770,7 @@ def forward(state, neighbors): if self.is_rho_renorm: rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N) elif self.is_rho_evol and (self.solver == "RIE"): - temp = vmap(self._rho_evol_riemann_fn)( + temp = vmap(self._rho_evol_fn)( e_s, rho[i_s], rho[j_s], @@ -867,7 +869,7 @@ def forward(state, neighbors): p[j_s], ) elif self.solver == "DELTA": - out = vmap(self._acceleration_delta_fn)( + out = vmap(self._acceleration_fn)( dr_i_j, dist, rho[i_s], @@ -885,7 +887,7 @@ def forward(state, neighbors): fluid_mask[j_s], ) elif self.solver == "RIE": - out = vmap(self._acceleration_riemann_fn)( + out = vmap(self._acceleration_fn)( e_s, dr_i_j, dist,