diff --git a/jax_sph/solver.py b/jax_sph/solver.py index 8acac63..826297c 100644 --- a/jax_sph/solver.py +++ b/jax_sph/solver.py @@ -47,6 +47,7 @@ def rho_evol_riemann_fn( wall_mask_j, n_w_j, g_ext_i, + u_tilde_j, **kwargs, ): # Compute unit vector, above eq. (6), Zhang (2017) @@ -61,9 +62,12 @@ def rho_evol_riemann_fn( rho_L = rho_i # u_w from eq. (15), Yang (2020) + # u_d = 2 * u_i - u_j + u_d = 2 * u_i - u_tilde_j u_R = jnp.where( wall_mask_j == 1, - -u_L + 2 * jnp.dot(u_j, n_w_j), + # -u_L + 2 * jnp.dot(u_j, n_w_j), + jnp.dot(u_d, -n_w_j), jnp.dot(u_j, -e_ij), ) p_R = jnp.where(wall_mask_j == 1, p_L + rho_L * jnp.dot(g_ext_i, -r_ij), p_j) @@ -595,6 +599,10 @@ def forward(state, neighbors): ) n_w = jnp.where(jnp.absolute(n_w) < EPS, 0.0, n_w) + ##### Riemann velocity BCs + if self.is_bc_trick and (self.solver == "RIE"): + u_tilde = self._Riemann_velocities(u, w_dist, fluid_mask, i_s, j_s, N) + ##### Density summation or evolution # update evolution @@ -621,6 +629,7 @@ def forward(state, neighbors): wall_mask[j_s], n_w[j_s], g_ext[i_s], + u_tilde[j_s], ) drhodt = ops.segment_sum(temp, i_s, N) * fluid_mask rho = rho + self.dt * drhodt @@ -644,7 +653,7 @@ def forward(state, neighbors): ) elif self.is_bc_trick and (self.solver == "RIE"): mask = self._free_weight(fluid_mask[i_s], tag[i_s]) - u_tilde = self._Riemann_velocities(u, w_dist, fluid_mask, i_s, j_s, N) + # u_tilde = self._Riemann_velocities(u, w_dist, fluid_mask, i_s, j_s, N) temperature = self._heat_bc( fluid_mask[j_s], w_dist, temperature, i_s, j_s, tag, N )