Skip to content

Commit

Permalink
Merge pull request Autodesk#99 from hsalehipour/improved_bc_encoding
Browse files Browse the repository at this point in the history
Used center of f_1 as an additional storage and also fixed some bugs
  • Loading branch information
hsalehipour authored Jan 2, 2025
2 parents c48e0ec + 77ecdf6 commit 5340e6c
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 55 deletions.
4 changes: 2 additions & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def bc_profile(self):
@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = self.precision_policy.store_precision.wp_dtype(index[1])
z = self.precision_policy.store_precision.wp_dtype(index[2])
y = wp.float32(index[1])
z = wp.float32(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
Expand Down
18 changes: 6 additions & 12 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,10 @@ def functional_velocity(
normals = get_normal_vectors(missing_mask)

# Find the value of u from the missing directions
for l in range(_q):
# Since we are only considering normal velocity, we only need to find one value
if missing_mask[l] == wp.uint8(1):
# Create velocity vector by multiplying the prescribed value with the normal vector
prescribed_value = f_1[_opp_indices[l], index[0], index[1], index[2]]
_u = -prescribed_value * normals
break
# Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1)
# Create velocity vector by multiplying the prescribed value with the normal vector
prescribed_value = f_1[0, index[0], index[1], index[2]]
_u = -prescribed_value * normals

# calculate rho
fsum = _get_fsum(_f, missing_mask)
Expand Down Expand Up @@ -262,11 +259,8 @@ def functional_pressure(
normals = get_normal_vectors(missing_mask)

# Find the value of rho from the missing directions
for q in range(_q):
# Since we need only one scalar value, we only need to find one value
if missing_mask[q] == wp.uint8(1):
_rho = f_1[_opp_indices[q], index[0], index[1], index[2]]
break
# Since we need only one scalar value, we only need to find one value (stored at the center of f_1)
_rho = f_1[0, index[0], index[1], index[2]]

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
Expand Down
62 changes: 25 additions & 37 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ def __init__(
if non_zero_count > 1:
raise ValueError("This BC only supports normal prescribed values (only one non-zero element allowed)")

# Prescribed value for this BC must be:
# a single non-zero number associated with normal velocity magnitude for velocity BC OR
# a single non-zero number associated with pressure BC OR
# a vector of zeros associated with no-slip BC.
# Accounting for all scenarios here.
if self.compute_backend is ComputeBackend.WARP:
idx = np.nonzero(prescribed_value)[0]
prescribed_value = prescribed_value[idx][0] if idx.size else 0.0
prescribed_value = self.precision_policy.store_precision.wp_dtype(prescribed_value)
self.prescribed_value = prescribed_value
self.profile = self._create_constant_prescribed_profile()

Expand All @@ -107,28 +116,14 @@ def __init__(
self.needs_padding = True

def _create_constant_prescribed_profile(self):
if self.bc_type == "velocity":

@wp.func
def prescribed_profile_warp(index: wp.vec3i):
# Get the non-zero value from prescribed_value
value = wp.static(
self.precision_policy.store_precision.wp_dtype(float(self.prescribed_value[np.nonzero(self.prescribed_value)[0][0]]))
)
return wp.vec(value, length=1)

def prescribed_profile_jax():
return jnp.array(self.prescribed_value, dtype=self.precision_policy.store_precision.jax_dtype).reshape(-1, 1)
_prescribed_value = self.prescribed_value

else: # pressure

@wp.func
def prescribed_profile_warp(index: wp.vec3i):
value = wp.static(self.precision_policy.store_precision.wp_dtype(self.prescribed_value))
return wp.vec(value, length=1)
@wp.func
def prescribed_profile_warp(index: wp.vec3i):
return wp.vec(_prescribed_value, length=1)

def prescribed_profile_jax():
return jnp.array(self.prescribed_value)
def prescribed_profile_jax():
return jnp.array(_prescribed_value, dtype=self.precision_policy.store_precision.jax_dtype).reshape(-1, 1)

if self.compute_backend == ComputeBackend.JAX:
return prescribed_profile_jax
Expand Down Expand Up @@ -332,8 +327,8 @@ def functional_velocity(
index: Any,
timestep: Any,
_missing_mask: Any,
f_pre: Any,
f_post: Any,
f_0: Any,
f_1: Any,
_f_pre: Any,
_f_post: Any,
):
Expand All @@ -348,14 +343,10 @@ def functional_velocity(
unormal = self.compute_dtype(0.0)

# Find the value of u from the missing directions
for l in range(_q):
# Since we are only considering normal velocity, we only need to find one value (all values are the same in the missing directions)
if _missing_mask[l] == wp.uint8(1):
# Create velocity vector by multiplying the prescribed value with the normal vector
# TODO: This can be optimized by saving _missing_mask[l] in the bc class later since it is the same for all boundary cells
prescribed_value = f_post[_opp_indices[l], index[0], index[1], index[2]]
_u = -prescribed_value * normals
break
# Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1)
# Create velocity vector by multiplying the prescribed value with the normal vector
prescribed_value = f_1[0, index[0], index[1], index[2]]
_u = -prescribed_value * normals

for d in range(_d):
unormal += _u[d] * normals[d]
Expand All @@ -372,8 +363,8 @@ def functional_pressure(
index: Any,
timestep: Any,
_missing_mask: Any,
f_pre: Any,
f_post: Any,
f_0: Any,
f_1: Any,
_f_pre: Any,
_f_post: Any,
):
Expand All @@ -384,11 +375,8 @@ def functional_pressure(
normals = get_normal_vectors(_missing_mask)

# Find the value of rho from the missing directions
for q in range(_q):
# Since we need only one scalar value, we only need to find one value (all values are the same in the missing directions)
if _missing_mask[q] == wp.uint8(1):
_rho = f_post[_opp_indices[q], index[0], index[1], index[2]]
break
# Since we need only one scalar value, we only need to find one value (stored at the center of f_1)
_rho = f_1[0, index[0], index[1], index[2]]

# calculate velocity
fsum = _get_fsum(_f, _missing_mask)
Expand Down
9 changes: 7 additions & 2 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,13 @@ def aux_data_init_kernel(
prescribed_values = functional(index)
# Write the result for all q directions, but only store up to num_of_aux_data
# TODO: Somehow raise an error if the number of prescribed values does not match the number of missing directions
counter = wp.int32(0)
for l in range(self.velocity_set.q):

# The first BC auxiliary value is stored at the zeroth index of f_1, i.e. the one associated with its center.
f_1[0, index[0], index[1], index[2]] = self.store_dtype(prescribed_values[0])
counter = wp.int32(1)

# The remaining BC auxiliary data are stored in the missing directions of f_1.
for l in range(1, self.velocity_set.q):
if _missing_mask[l] == wp.uint8(1) and counter < _num_of_aux_data:
f_1[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(prescribed_values[counter])
counter += 1
Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/equilibrium/quadratic_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class QuadraticEquilibrium(Equilibrium):
def jax_implementation(self, rho, u):
cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0))
usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True)
w = self.velocity_set.w.reshape((-1,) + (1,) * (len(rho.shape) - 1))
w = self.velocity_set.w.reshape((-1,) + (1,) * self.velocity_set.d)
feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr)
return feq

Expand Down
11 changes: 10 additions & 1 deletion xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,21 @@ def apply_aux_recovery_bc(
f_0: Any,
_f1_thread: Any,
):
# Note:
# In XLB, the BC auxiliary data (e.g. prescribed values of pressure or normal velocity) are stored in (i) the central index of f_1 and/or
# (ii) the missing directions of f_1. A given BC may not need all of this available storage space. This function checks whether
# the BC needs recovery of auxiliary data and, if so, recovers that information for the next iteration (needed because of buffer swapping) by
# writing the thread values of f_1 (i.e. _f1_thread) into f_0.

# Unroll the loop over boundary conditions
for i in range(wp.static(len(self.boundary_conditions))):
if wp.static(self.boundary_conditions[i].needs_aux_recovery):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
# Perform the swapping of data
for l in range(self.velocity_set.q):
# (i) Recover the values stored in the central index of f_1
f_0[0, index[0], index[1], index[2]] = self.store_dtype(_f1_thread[0])
# (ii) Recover the values stored in the missing directions of f_1
for l in range(1, self.velocity_set.q):
if _missing_mask[l] == wp.uint8(1):
f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]])

Expand Down

0 comments on commit 5340e6c

Please sign in to comment.