Skip to content

Commit

Permalink
Merge branch 'PR_review' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Aug 29, 2024
2 parents fbcf525 + 8a2255f commit 31895f9
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 30 deletions.
24 changes: 12 additions & 12 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(

# Unpack the two warp functionals needed for this BC!
if self.compute_backend == ComputeBackend.WARP:
self.warp_functional_poststream, self.warp_functional_postcollision = self.warp_functional
self.warp_functional, self.prepare_bc_auxilary_data = self.warp_functional

def _get_normal_vec(self, indices):
# Get the frequency count and most common element directly
Expand Down Expand Up @@ -157,10 +157,10 @@ def get_normal_vectors_3d(

# Construct the functionals for this BC
@wp.func
def functional_poststream(
def functional(
f_pre: Any,
f_post: Any,
f_nbr: Any,
f_aux: Any,
missing_mask: Any,
):
# Post-streaming values are only modified at missing direction
Expand All @@ -173,7 +173,7 @@ def functional_poststream(
return _f

@wp.func
def functional_postcollision(
def prepare_bc_auxilary_data(
f_pre: Any,
f_post: Any,
f_aux: Any,
Expand Down Expand Up @@ -202,7 +202,7 @@ def kernel2d(

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index)
_faux = _f_vec()
_f_aux = _f_vec()

# special preparation of auxiliary data
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
Expand All @@ -215,13 +215,13 @@ def kernel2d(
for d in range(self.velocity_set.d):
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
_faux[l] = _f_pre[l, pull_index[0], pull_index[1]]
_f_aux[l] = _f_pre[l, pull_index[0], pull_index[1]]

# Apply the boundary condition
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
# TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both
# TODO: is there any way for this BC to have a meaningful kernel given that it has two steps after both
# collision and streaming?
_f = functional_poststream(_f_pre, _f_post, _faux, _missing_mask)
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask)
else:
_f = _f_post

Expand All @@ -243,7 +243,7 @@ def kernel3d(

# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index)
_faux = _f_vec()
_f_aux = _f_vec()

# special preparation of auxiliary data
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
Expand All @@ -256,13 +256,13 @@ def kernel3d(
for d in range(self.velocity_set.d):
pull_index[d] = index[d] - (_c[d, l] + nv[d])
# The following is the post-streaming values of the neighbor cell
_faux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]]
_f_aux[l] = _f_pre[l, pull_index[0], pull_index[1], pull_index[2]]

# Apply the boundary condition
if _boundary_id == wp.uint8(ExtrapolationOutflowBC.id):
# TODO: is there any way for this BC to have a meaninful kernel given that it has two steps after both
# collision and streaming?
_f = functional_poststream(_f_pre, _f_post, _faux, _missing_mask)
_f = functional(_f_pre, _f_post, _f_aux, _missing_mask)
else:
_f = _f_post

Expand All @@ -272,7 +272,7 @@ def kernel3d(

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return [functional_poststream, functional_postcollision], kernel
return (functional, prepare_bc_auxilary_data), kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask):
Expand Down
14 changes: 2 additions & 12 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,7 @@ def __init__(
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

@wp.func
def functional_postcollision(
f_pre: Any,
f_post: Any,
f_aux: Any,
missing_mask: Any,
):
return f_post

@wp.func
def functional_poststream(
def prepare_bc_auxilary_data(
f_pre: Any,
f_post: Any,
f_aux: Any,
Expand Down Expand Up @@ -123,8 +114,7 @@ def _get_thread_data_3d(
if self.compute_backend == ComputeBackend.WARP:
self._get_thread_data_2d = _get_thread_data_2d
self._get_thread_data_3d = _get_thread_data_3d
self.warp_functional_poststream = functional_poststream
self.warp_functional_postcollision = functional_postcollision
self.prepare_bc_auxilary_data = prepare_bc_auxilary_data

@partial(jit, static_argnums=(0,), inline=True)
def prepare_bc_auxilary_data(self, f_pre, f_post, boundary_mask, missing_mask):
Expand Down
13 changes: 7 additions & 6 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def apply_post_streaming_bc(
f_post = self.RegularizedBC_pressure.warp_functional(f_pre, f_post, f_aux, missing_mask)
elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC:
# Regularized boundary condition (bc type = velocity)
f_post = self.ExtrapolationOutflowBC.warp_functional_poststream(f_pre, f_post, f_aux, missing_mask)
f_post = self.ExtrapolationOutflowBC.warp_functional(f_pre, f_post, f_aux, missing_mask)
return f_post

@wp.func
Expand All @@ -160,7 +160,7 @@ def apply_post_collision_bc(
elif _boundary_id == bc_struct.id_ExtrapolationOutflowBC:
# f_aux is the neighbour's post-streaming values
# Storing post-streaming data in directions that leave the domain
f_post = self.ExtrapolationOutflowBC.warp_functional_postcollision(f_pre, f_post, f_aux, missing_mask)
f_post = self.ExtrapolationOutflowBC.prepare_bc_auxilary_data(f_pre, f_post, f_aux, missing_mask)

return f_post

Expand Down Expand Up @@ -221,7 +221,7 @@ def get_thread_data_3d(
return f_post_collision, _missing_mask

@wp.func
def prepare_bc_auxilary_data_2d(
def get_bc_auxilary_data_2d(
f_0: wp.array3d(dtype=Any),
index: Any,
_boundary_id: Any,
Expand All @@ -244,7 +244,7 @@ def prepare_bc_auxilary_data_2d(
return f_auxiliary

@wp.func
def prepare_bc_auxilary_data_3d(
def get_bc_auxilary_data_3d(
f_0: wp.array4d(dtype=Any),
index: Any,
_boundary_id: Any,
Expand Down Expand Up @@ -287,7 +287,7 @@ def kernel2d(

# Prepare auxilary data for BC (if applicable)
_boundary_id = boundary_mask[0, index[0], index[1]]
f_auxiliary = prepare_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct)
f_auxiliary = get_bc_auxilary_data_2d(f_0, index, _boundary_id, _missing_mask, bc_struct)

# Apply post-streaming type boundary conditions
f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct)
Expand Down Expand Up @@ -335,7 +335,7 @@ def kernel3d(

# Prepare auxilary data for BC (if applicable)
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]
f_auxiliary = prepare_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct)
f_auxiliary = get_bc_auxilary_data_3d(f_0, index, _boundary_id, _missing_mask, bc_struct)

# Apply post-streaming type boundary conditions
f_post_stream = apply_post_streaming_bc(f_post_collision, f_post_stream, f_auxiliary, _missing_mask, _boundary_id, bc_struct)
Expand Down Expand Up @@ -380,6 +380,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):

# Setting the Struct attributes and active BC classes based on the BC class names
bc_fallback = self.boundary_conditions[0]
# TODO: what if self.boundary_conditions is an empty list e.g. when we have periodic BC all around!
for var in vars(bc_struct):
if var not in active_bc_list and not var.startswith("_"):
# set unassigned boundaries to the maximum integer in uint8
Expand Down

0 comments on commit 31895f9

Please sign in to comment.