Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a helper function for bc related warp functions #98

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xlb/operator/boundary_condition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC
from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition
from xlb.operator.boundary_condition.boundary_condition_registry import BoundaryConditionRegistry
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class DoNothingBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class EquilibriumBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class ExtrapolationOutflowBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
BoundaryCondition,
ImplementationStep,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class FullwayBounceBackBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_grads_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class GradsApproximationBC(BoundaryCondition):
Expand Down
3 changes: 0 additions & 3 deletions xlb/operator/boundary_condition/bc_halfway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)


class HalfwayBounceBackBC(BoundaryCondition):
Expand Down
100 changes: 13 additions & 87 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
from xlb.precision_policy import PrecisionPolicy
from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC
from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux
from xlb.operator.boundary_condition import ZouHeBC, HelperFunctionsBC
from xlb.operator.macroscopic import SecondMoment as MomentumFlux


class RegularizedBC(ZouHeBC):
Expand Down Expand Up @@ -64,7 +62,6 @@ def __init__(
indices,
mesh_vertices,
)
# Overwrite the boundary condition registry id with the bc_type in the name
self.momentum_flux = MomentumFlux()

@partial(jit, static_argnums=(0,), inline=True)
Expand Down Expand Up @@ -127,83 +124,12 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
# compute Qi tensor and store it in self
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_w = self.velocity_set.w
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
_qi = self.velocity_set.qi
# TODO: related to _c_float: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def regularize_fpop(
fpop: Any,
feq: Any,
):
"""
Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop.
"""
# Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
f_neq = fpop - feq
PiNeq = self.momentum_flux.warp_functional(f_neq)

# Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
nt = _d * (_d + 1) // 2
for l in range(_q):
QiPi1 = self.compute_dtype(0.0)
for t in range(nt):
QiPi1 += _qi[l, t] * PiNeq[t]

# assign all populations based on eq 45 of Latt et al (2008)
# fneq ~ f^1
fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1
fpop[l] = feq[l] + fpop1
return fpop

@wp.func
def functional_velocity(
Expand All @@ -219,7 +145,7 @@ def functional_velocity(
_f = f_post

# Find normal vector
normals = get_normal_vectors(missing_mask)
normals = bc_helper.get_normal_vectors(missing_mask)

# Find the value of u from the missing directions
for l in range(_q):
Expand All @@ -231,18 +157,18 @@ def functional_velocity(
break

# calculate rho
fsum = _get_fsum(_f, missing_mask)
fsum = bc_helper.get_bc_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
_rho = fsum / (self.compute_dtype(1.0) + unormal)

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask)

# Regularize the boundary fpop
_f = regularize_fpop(_f, feq)
_f = bc_helper.regularize_fpop(_f, feq)
return _f

@wp.func
Expand All @@ -259,7 +185,7 @@ def functional_pressure(
_f = f_post

# Find normal vector
normals = get_normal_vectors(missing_mask)
normals = bc_helper.get_normal_vectors(missing_mask)

# Find the value of rho from the missing directions
for q in range(_q):
Expand All @@ -269,16 +195,16 @@ def functional_pressure(
break

# calculate velocity
fsum = _get_fsum(_f, missing_mask)
fsum = bc_helper.get_bc_fsum(_f, missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask)

# Regularize the boundary fpop
_f = regularize_fpop(_f, feq)
_f = bc_helper.regularize_fpop(_f, feq)
return _f

if self.bc_type == "velocity":
Expand Down
66 changes: 10 additions & 56 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
ImplementationStep,
BoundaryCondition,
)
from xlb.operator.boundary_condition.boundary_condition_registry import (
boundary_condition_registry,
)
from xlb.operator.boundary_condition import HelperFunctionsBC
from xlb.operator.equilibrium import QuadraticEquilibrium
import jax


class ZouHeBC(BoundaryCondition):
Expand Down Expand Up @@ -277,55 +274,12 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post

def _construct_warp(self):
# assign placeholders for both u and rho based on prescribed_value
# load helper functions
bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
# Set local constants
_d = self.velocity_set.d
_q = self.velocity_set.q

# Set local constants TODO: This is a hack and should be fixed with warp update
# _u_vec = wp.vec(_d, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_opp_indices = self.velocity_set.opp_indices
_c = self.velocity_set.c
_c_float = self.velocity_set.c_float
# TODO: this is way less than ideal. we should not be making new types

@wp.func
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
fsum_known = self.compute_dtype(0.0)
fsum_middle = self.compute_dtype(0.0)
for l in range(_q):
if missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += self.compute_dtype(2.0) * fpop[l]
elif missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle

@wp.func
def get_normal_vectors(
missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])

@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
missing_mask: Any,
):
for l in range(_q):
if missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop

@wp.func
def functional_velocity(
Expand All @@ -341,10 +295,10 @@ def functional_velocity(
_f = _f_post

# Find normal vector
normals = get_normal_vectors(_missing_mask)
normals = bc_helper.get_normal_vectors(_missing_mask)

# calculate rho
fsum = _get_fsum(_f, _missing_mask)
fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
unormal = self.compute_dtype(0.0)

# Find the value of u from the missing directions
Expand All @@ -364,7 +318,7 @@ def functional_velocity(

# impose non-equilibrium bounceback
_feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, _feq, _missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, _feq, _missing_mask)
return _f

@wp.func
Expand All @@ -381,7 +335,7 @@ def functional_pressure(
_f = _f_post

# Find normal vector
normals = get_normal_vectors(_missing_mask)
normals = bc_helper.get_normal_vectors(_missing_mask)

# Find the value of rho from the missing directions
for q in range(_q):
Expand All @@ -391,13 +345,13 @@ def functional_pressure(
break

# calculate velocity
fsum = _get_fsum(_f, _missing_mask)
fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals

# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
_f = bounceback_nonequilibrium(_f, feq, _missing_mask)
_f = bc_helper.bounceback_nonequilibrium(_f, feq, _missing_mask)
return _f

if self.bc_type == "velocity":
Expand Down
Loading
Loading