Skip to content

Commit

Permalink
Added a helper function for BCs to avoid repeated definition of ident…
Browse files Browse the repository at this point in the history
…ical warp functions.
  • Loading branch information
hsalehipour committed Jan 2, 2025
1 parent 237ac22 commit 74c5cdf
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 189 deletions.
97 changes: 12 additions & 85 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC
from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux
from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC


class RegularizedBC(ZouHeBC):
Expand Down Expand Up @@ -64,7 +63,6 @@ def __init__(
indices,
mesh_vertices,
)
# Overwrite the boundary condition registry id with the bc_type in the name
self.momentum_flux = MomentumFlux()

@partial(jit, static_argnums=(0,), inline=True)
Expand Down Expand Up @@ -127,83 +125,12 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_w = self.velocity_set.w
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
_qi = self.velocity_set.qi
# TODO: related to _c_float: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def regularize_fpop(
fpop: Any,
feq: Any,
):
"""
Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop.
"""
# Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
f_neq = fpop - feq
PiNeq = self.momentum_flux.warp_functional(f_neq)

# Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
nt = _d * (_d + 1) // 2
for l in range(_q):
QiPi1 = self.compute_dtype(0.0)
for t in range(nt):
QiPi1 += _qi[l, t] * PiNeq[t]

# assign all populations based on eq 45 of Latt et al (2008)
# fneq ~ f^1
fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1
fpop[l] = feq[l] + fpop1
return fpop

@wp.func
def functional_velocity(
Expand All @@ -219,7 +146,7 @@ def functional_velocity(
_f = f_post

# Find normal vector
normals = get_normal_vectors(missing_mask)
normals = bc_helper.get_normal_vectors(missing_mask)

# Find the value of u from the missing directions
for l in range(_q):
Expand All @@ -231,18 +158,18 @@ def functional_velocity(
break

# calculate rho
fsum = _get_fsum(_f, missing_mask)
fsum = bc_helper.get_bc_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
_rho = fsum / (self.compute_dtype(1.0) + unormal)

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask)

# Regularize the boundary fpop
_f = regularize_fpop(_f, feq)
_f = bc_helper.regularize_fpop(_f, feq)
return _f

@wp.func
Expand All @@ -259,7 +186,7 @@ def functional_pressure(
_f = f_post

# Find normal vector
normals = get_normal_vectors(missing_mask)
normals = bc_helper.get_normal_vectors(missing_mask)

# Find the value of rho from the missing directions
for q in range(_q):
Expand All @@ -269,16 +196,16 @@ def functional_pressure(
break

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
fsum = bc_helper.get_bc_fsum(_f, missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask)

# Regularize the boundary fpop
_f = regularize_fpop(_f, feq)
_f = bc_helper.regularize_fpop(_f, feq)
return _f

if self.bc_type == "velocity":
Expand Down
65 changes: 10 additions & 55 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)
from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC
from xlb.operator.equilibrium import QuadraticEquilibrium
import jax

Expand Down Expand Up @@ -277,55 +275,12 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
# TODO: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def functional_velocity(
Expand All @@ -341,10 +296,10 @@ def functional_velocity(
_f = _f_post

# Find normal vector
normals = get_normal_vectors(_missing_mask)
normals = bc_helper.get_normal_vectors(_missing_mask)

# calculate rho
fsum = _get_fsum(_f, _missing_mask)
fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
unormal = self.compute_dtype(0.0)

# Find the value of u from the missing directions
Expand All @@ -364,7 +319,7 @@ def functional_velocity(

# impose non-equilibrium bounceback
_feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, _feq, _missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, _feq, _missing_mask)
return _f

@wp.func
Expand All @@ -381,7 +336,7 @@ def functional_pressure(
_f = _f_post

# Find normal vector
normals = get_normal_vectors(_missing_mask)
normals = bc_helper.get_normal_vectors(_missing_mask)

# Find the value of rho from the missing directions
for q in range(_q):
Expand All @@ -391,13 +346,13 @@ def functional_pressure(
break

# calculate velocity
fsum = _get_fsum(_f, _missing_mask)
fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, _missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, _missing_mask)
return _f

if self.bc_type == "velocity":
Expand Down
55 changes: 6 additions & 49 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from xlb.operator.operator import Operator
from xlb import DefaultConfig
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC


# Enum for implementation step
Expand Down Expand Up @@ -71,53 +72,6 @@ def __init__(
# A flag for BCs that need auxilary data recovery after streaming
self.needs_aux_recovery = False

if self.compute_backend == ComputeBackend.WARP:
# Set local constants TODO: This is a hack and should be fixed with warp update
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

@wp.func
def update_bc_auxilary_data(
index: Any,
timestep: Any,
missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
f_post: Any,
):
return f_post

@wp.func
def _get_thread_data(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
index: wp.vec3i,
):
# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = bc_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1], index[2]])
_f_post[l] = self.compute_dtype(f_post[l, index[0], index[1], index[2]])

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
else:
_missing_mask[l] = wp.uint8(0)
return _f_pre, _f_post, _boundary_id, _missing_mask

# Construct some helper warp functions for getting tid data
if self.compute_backend == ComputeBackend.WARP:
self._get_thread_data = _get_thread_data
self.update_bc_auxilary_data = update_bc_auxilary_data

@partial(jit, static_argnums=(0,), inline=True)
def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
"""
Expand All @@ -131,6 +85,7 @@ def _construct_kernel(self, functional):
Constructs the warp kernel for the boundary condition.
The functional is specific to each boundary condition and should be passed as an argument.
"""
bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
_id = wp.uint8(self.id)

# Construct the warp kernel
Expand All @@ -146,7 +101,7 @@ def kernel(
index = wp.vec3i(i, j, k)

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data(f_pre, f_post, bc_mask, missing_mask, index)
_f_pre, _f_post, _boundary_id, _missing_mask = bc_helper.get_thread_data(f_pre, f_post, bc_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == _id:
Expand All @@ -165,6 +120,8 @@ def _construct_aux_data_init_kernel(self, functional):
"""
Constructs the warp kernel for the auxilary data recovery.
"""
bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)

_id = wp.uint8(self.id)
_opp_indices = self.velocity_set.opp_indices
_num_of_aux_data = self.num_of_aux_data
Expand All @@ -182,7 +139,7 @@ def aux_data_init_kernel(
index = wp.vec3i(i, j, k)

# read tid data
_f_0, _f_1, _boundary_id, _missing_mask = self._get_thread_data(f_0, f_1, bc_mask, missing_mask, index)
_f_0, _f_1, _boundary_id, _missing_mask = bc_helper.get_thread_data(f_0, f_1, bc_mask, missing_mask, index)

# Apply the functional
if _boundary_id == _id:
Expand Down
Loading

0 comments on commit 74c5cdf

Please sign in to comment.