From a5d4aea033b2a026e8b2ebb1f567acda6b2bd197 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Fri, 7 Jun 2024 21:56:48 +0200 Subject: [PATCH] get latest partition.py from lagrangebench. From now on this will be the only partition.py copy --- jax_sph/partition.py | 621 +++++++++++++--------------------------- tests/test_neighbors.py | 137 +++++++++ 2 files changed, 338 insertions(+), 420 deletions(-) create mode 100644 tests/test_neighbors.py diff --git a/jax_sph/partition.py b/jax_sph/partition.py index 15869b3..d0256db 100644 --- a/jax_sph/partition.py +++ b/jax_sph/partition.py @@ -1,8 +1,4 @@ -"""Neighbors search backends. - -Source: -https://github.com/tumaer/lagrangebench/blob/main/lagrangebench/case_setup/partition.py -""" +"""Neighbors search backends.""" from functools import partial from typing import Optional @@ -13,23 +9,26 @@ import numpy as np import numpy as onp from jax import jit +from jax_md import space from jax_md.partition import ( - CellList, MaskFn, NeighborFn, NeighborList, NeighborListFns, NeighborListFormat, + PartitionError, + PartitionErrorCode, _displacement_or_metric_to_metric_sq, _neighboring_cells, cell_list, is_format_valid, is_sparse, shift_array, - space, ) from jax_md.partition import neighbor_list as vmap_neighbor_list +PEC = PartitionErrorCode + def get_particle_cells(idx, cl_capacity, N): """ @@ -54,7 +53,7 @@ def get_particle_cells(idx, cl_capacity, N): def _scan_neighbor_list( displacement_or_metric: space.DisplacementOrMetricFn, - box_size: space.Box, + box: space.Box, r_cutoff: float, dr_threshold: float = 0.0, capacity_multiplier: float = 1.25, @@ -63,7 +62,7 @@ def _scan_neighbor_list( custom_mask_function: Optional[MaskFn] = None, fractional_coordinates: bool = False, format: NeighborListFormat = NeighborListFormat.Sparse, - num_partitions: int = 1, + num_partitions: int = 8, **static_kwargs, ) -> NeighborFn: """Modified JAX-MD neighbor list function that uses `lax.scan` to compute the @@ -116,7 +115,7 @@ def 
body_fn(i, state): Args: displacement: A function `d(R_a, R_b)` that computes the displacement between pairs of points. - box_size: Either a float specifying the size of the box or an array of + box: Either a float specifying the size of the box or an array of shape `[spatial_dim]` specifying the box size in each spatial dimension. r_cutoff: A scalar specifying the neighborhood radius. dr_threshold: A scalar specifying the maximum distance particles can move @@ -146,18 +145,17 @@ def body_fn(i, state): A NeighborListFns object that contains a method to allocate a new neighbor list and a method to update an existing neighbor list. """ - assert disable_cell_list is False, "Works only with a cell list" + assert not fractional_coordinates, "Works only with real coordinates" assert format == NeighborListFormat.Sparse, "Works only with sparse neighbor list" assert custom_mask_function is None, "Custom masking not implemented" - # assert mask_self == False, "Self edges cannot be excluded for now" is_format_valid(format) - box_size = lax.stop_gradient(box_size) + box = lax.stop_gradient(box) r_cutoff = lax.stop_gradient(r_cutoff) dr_threshold = lax.stop_gradient(dr_threshold) - box_size = jnp.float32(box_size) + box = jnp.float32(box) cutoff = r_cutoff + dr_threshold cutoff_sq = cutoff**2 @@ -165,430 +163,167 @@ def body_fn(i, state): metric_sq = _displacement_or_metric_to_metric_sq(displacement_or_metric) cell_size = cutoff - if fractional_coordinates: - cell_size = cutoff / box_size - box_size = ( - jnp.float32(box_size) - if onp.isscalar(box_size) - else onp.ones_like(box_size, jnp.float32) - ) + assert jnp.all(cell_size < box / 3.0), "Don't use scan with very few cells" - assert jnp.all(cell_size < box_size / 3.0), "Don't use scan with very few cells" + def neighbor_list_fn( + position: jnp.ndarray, + neighbors: Optional[NeighborList] = None, + extra_capacity: int = 0, + **kwargs, + ) -> NeighborList: + def neighbor_fn(position_and_error, max_occupancy=None): + 
position, err = position_and_error + N, dim = position.shape + cl_fn = None + cl = None + cell_size = None - cl_fn = cell_list(box_size, cell_size, capacity_multiplier) + if neighbors is None: # cl.shape = (nx, ny, nz, cell_capacity, dim) + cell_size = cutoff + cl_fn = cell_list(box, cell_size, capacity_multiplier) + cl = cl_fn.allocate(position, extra_capacity=extra_capacity) + else: + cell_size = neighbors.cell_size + cl_fn = neighbors.cell_list_fn + if cl_fn is not None: + cl = cl_fn.update(position, neighbors.cell_list_capacity) - @jit - def cell_list_candidate_fn(cl: CellList, position: jnp.ndarray) -> jnp.ndarray: - N, dim = position.shape - - idx = cl.id_buffer - - cell_idx = [idx] - - for dindex in _neighboring_cells( - dim - ): # here the expansion happens over all adjacent cells happens - if onp.all(dindex == 0): - continue - cell_idx += [shift_array(idx, dindex)] # 27* (nx,ny,nz,cell_capacity, 1) - - cell_idx = jnp.concatenate(cell_idx, axis=-2) - cell_idx = cell_idx[..., jnp.newaxis, :, :] # (nx,ny,nz,1,27*cell_capacity, 1) - cell_idx = jnp.broadcast_to( - cell_idx, idx.shape[:-1] + cell_idx.shape[-2:] - ) # (nx,ny,nz,cell_capacity,27*cell_capacity) TODO: memory blows up here - - def copy_values_from_cell(value, cell_value, cell_id): - scatter_indices = jnp.reshape(cell_id, (-1,)) # (nx*ny*nz*cell_capacity) - cell_value = jnp.reshape( - cell_value, (-1,) + cell_value.shape[-2:] - ) # (nx*ny*nz*cell_capacity, 27*cell_capacity, 1) - return value.at[scatter_indices].set(cell_value) - - neighbor_idx = jnp.zeros( - (N + 1,) + cell_idx.shape[-2:], jnp.int32 - ) # (N, 27*cell_capacity, 1) TODO: too much memory - neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx) - return neighbor_idx[:-1, :, 0] # shape (N, 27*cell_capacity) + err = err.update(PEC.CELL_LIST_OVERFLOW, cl.did_buffer_overflow) + cl_capacity = cl.cell_capacity - @jit - def prune_neighbor_list_sparse( - position: jnp.ndarray, idx: jnp.ndarray, **kwargs - ) -> jnp.ndarray: - d = 
partial(metric_sq, **kwargs) - d = space.map_bond(d) + idx = cl.id_buffer - N = position.shape[0] - sender_idx = jnp.broadcast_to(jnp.arange(N)[:, None], idx.shape) + cell_idx = [idx] # shape: (nx, ny, nz, cell_capacity, 1) - sender_idx = jnp.reshape( - sender_idx, (-1,) - ) # (N, 27*cell_capacity) -> (N*27*cell_capacity) - receiver_idx = jnp.reshape(idx, (-1,)) - dR = d( - position[sender_idx], position[receiver_idx] - ) # (N*27*cell_capacity) eventually 3x during computation + for dindex in _neighboring_cells(dim): + if onp.all(dindex == 0): + continue + cell_idx += [shift_array(idx, dindex)] - mask = (dR < cutoff_sq) & (receiver_idx < N) - if format is NeighborListFormat.OrderedSparse: - mask = mask & (receiver_idx < sender_idx) + cell_idx = jnp.concatenate(cell_idx, axis=-2) + cell_idx = jnp.reshape(cell_idx, (-1, cell_idx.shape[-2])) + num_cells, considered_neighbors = cell_idx.shape - out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32) + particle_cells = get_particle_cells(idx, cl_capacity, N) - cumsum = jnp.cumsum(mask) - index = jnp.where( - mask, cumsum - 1, len(receiver_idx) - 1 - ) # 7th object of shape (N*27*cell_capacity) - receiver_idx = out_idx.at[index].set(receiver_idx) - sender_idx = out_idx.at[index].set(sender_idx) - max_occupancy = cumsum[-1] + d = partial(metric_sq, **kwargs) + d = space.map_bond(d) - return jnp.stack((receiver_idx, sender_idx)), max_occupancy + # number of particles per partition N_sub + # np.ceil used to pad last partition with < num_partitions entries + N_sub = int(np.ceil(N / num_partitions)) + num_pad = N_sub * num_partitions - N + particle_cells = jnp.pad( + particle_cells, + ( + 0, + num_pad, + ), + constant_values=-1, + ) - def neighbor_list_fn( - position: jnp.ndarray, - neighbors: Optional[NeighborList] = None, - extra_capacity: int = 0, - **kwargs, - ) -> NeighborList: - nbrs = neighbors + if dim == 2: + # the area of a circle with r=1/3 is 0.34907 + volumetric_factor = 0.34907 + elif dim == 3: + # the volume of a 
sphere with r=1/3 is 0.15514 + volumetric_factor = 0.15514 - def neighbor_fn(position_and_overflow, max_occupancy=None): - position, overflow = position_and_overflow - N = position.shape[0] + num_edges_sub = int( + N_sub * considered_neighbors * volumetric_factor * capacity_multiplier + ) - if neighbors is None: # cl.shape = (nx, ny, nz, cell_capacity, dim) - cl = cl_fn.allocate(position, extra_capacity=extra_capacity) - else: - cl = cl_fn.update(position, neighbors.cell_list_capacity) - overflow = overflow | cl.did_buffer_overflow - cl_capacity = cl.cell_capacity + def scan_body(carry, input): + """Compute neighbors over a subset of particles - if num_partitions == 1: - implementation = "original" - elif num_partitions > 1: - implementation = ( - "numcells" # "numcells", "twentyseven", "vanilla", "original" - ) + The largest object here is of size (N_sub*considered_neighbors), where + considered_neighbors in 3D is 27 * cell_capacity. + """ - if implementation == "numcells": - # idx = cell_list_candidate_fn(cl, position) - # # idx.shape = (N, 27*cell_capacity) - # print("82 ", get_gpu_stats()) - # idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs) - ################################################################ - - N, dim = position.shape - - idx = cl.id_buffer - - cell_idx = [idx] - - for dindex in _neighboring_cells( - dim - ): # here the expansion happens over all adjacent cells happens - if onp.all(dindex == 0): - continue - cell_idx += [ - shift_array(idx, dindex) - ] # 27* (nx,ny,nz,cell_capacity, 1) - - cell_idx = jnp.concatenate(cell_idx, axis=-2) - cell_idx = jnp.reshape( - cell_idx, (-1, cell_idx.shape[-2]) - ) # (num_cells, num_potential_connections) - num_cells, considered_neighbors = cell_idx.shape - - # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - # Given is a cell list `cell_idx` of shape (nx, ny, nz, cell_capacity). 
- # Find which cell indices correspond to particle 0, 1, 2, ..., N-1 - # and write the results into a new array of shape (N, nx, ny, nz) - - def scan_body(carry, input): - occupancy = carry - slice_from = input - - _entries = lax.dynamic_slice( - particle_cells, (slice_from,), (N_sub,) - ) - _idx = cell_idx[_entries] + occupancy = carry + slice_from = input - if mask_self: - particle_idx = slice_from + jnp.arange(N_sub) - _idx = jnp.where(_idx == particle_idx[:, None], N, _idx) + _entries = lax.dynamic_slice(particle_cells, (slice_from,), (N_sub,)) + _idx = cell_idx[_entries] - if num_pad > 0: - _idx = jnp.where(_entries[:, None] != -1, _idx, N) + if mask_self: + particle_idx = slice_from + jnp.arange(N_sub) + _idx = jnp.where(_idx == particle_idx[:, None], N, _idx) - sender_idx = ( - jnp.broadcast_to( - jnp.arange(N_sub, dtype="int32")[:, None], _idx.shape - ) - + slice_from - ) - if num_pad > 0: - sender_idx = jnp.clip(sender_idx, a_max=N) - - sender_idx = jnp.reshape( - sender_idx, (-1,) - ) # (N, 27*cell_capacity) -> (N*27*cell_capacity) - receiver_idx = jnp.reshape(_idx, (-1,)) - dR = d( - position[sender_idx], position[receiver_idx] - ) # (N*27*cell_capacity) eventually 3x during computation - - mask = (dR < cutoff_sq) & (receiver_idx < N) - out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32) - - cumsum = jnp.cumsum(mask) # + occupancy - index = jnp.where( - mask, cumsum - 1, considered_neighbors * N - 1 - ) # (N*27*cell_capacity) - receiver_idx = out_idx.at[index].set(receiver_idx) - sender_idx = out_idx.at[index].set(sender_idx) - occupancy += cumsum[-1] - - carry = occupancy - y = jnp.stack( - (receiver_idx[:num_edges_sub], sender_idx[:num_edges_sub]) + if num_pad > 0: + _idx = jnp.where(_entries[:, None] != -1, _idx, N) + + sender_idx = ( + jnp.broadcast_to( + jnp.arange(N_sub, dtype="int32")[:, None], _idx.shape ) - overflow = cumsum[-1] > num_edges_sub - return carry, (y, overflow) - - particle_cells = get_particle_cells(idx, cl_capacity, N) - - d = 
partial(metric_sq, **kwargs) - d = space.map_bond(d) - - N_sub = int( - np.ceil(N / num_partitions) - ) # to pad the last chunk with < num_partitions entries - num_pad = N_sub * num_partitions - N - particle_cells = jnp.pad( - particle_cells, - ( - 0, - num_pad, - ), - constant_values=-1, + + slice_from ) + if num_pad > 0: + sender_idx = jnp.clip(sender_idx, a_max=N) - if dim == 2: - # area of a circle with r=1/3 is 0.15514 of a unit cube volume - volumetric_factor = 0.34907 - elif dim == 3: - # volume of sphere with r=1/3 is 0.15514 of a unit cube volume - volumetric_factor = 0.15514 - - num_edges_sub = int( - N_sub - * considered_neighbors - * volumetric_factor - * capacity_multiplier - ) + sender_idx = jnp.reshape(sender_idx, (-1,)) + receiver_idx = jnp.reshape(_idx, (-1,)) + dR = d(position[sender_idx], position[receiver_idx]) - carry = jnp.array(0) - xs = jnp.array([i * N_sub for i in range(num_partitions)]) - # print("82 (numcells)", get_gpu_stats()) - occupancy, (idx, overflows) = lax.scan( - scan_body, carry, xs, length=num_partitions - ) - # print("83 ", get_gpu_stats()) - overflow = overflow | overflows.sum() + mask = (dR < cutoff_sq) & (receiver_idx < N) + out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32) - # print(f"idx memory: {idx.nbytes / 1e6:.0f}MB, idx.shape={idx.shape}, - # cl.id_buffer.shape={cl.id_buffer.shape}" ) - idx = idx.transpose(1, 2, 0).reshape(2, -1) + cumsum = jnp.cumsum(mask) + index = jnp.where(mask, cumsum - 1, considered_neighbors * N - 1) + receiver_idx = out_idx.at[index].set(receiver_idx) + sender_idx = out_idx.at[index].set(sender_idx) + occupancy += cumsum[-1] - # sort to enable pruning later - ordering = jnp.argsort(idx[1]) - idx = idx[:, ordering] - - if max_occupancy is None: - _extra_capacity = N * extra_capacity - max_occupancy = int( - occupancy * capacity_multiplier + _extra_capacity - ) - if max_occupancy > idx.shape[-1]: - max_occupancy = idx.shape[-1] - if not is_sparse(format): - capacity_limit = N - 1 if 
mask_self else N - elif format is NeighborListFormat.Sparse: - capacity_limit = N * (N - 1) if mask_self else N**2 - else: - capacity_limit = N * (N - 1) // 2 - if max_occupancy > capacity_limit: - max_occupancy = capacity_limit - - # prune neighbors list to max_occupancy by removing paddings - idx = idx[:, :max_occupancy] - elif implementation == "original_expanded": - # TODO: here we expand on the 27 adjacent cells - #################################################################### - ### - # idx = cell_list_candidate_fn(cl, position) - # # shape (N, 27*cell_capacity) -> 19M too much! - N, dim = position.shape - - idx = cl.id_buffer # (5, 5, 5, 88, 1) - - cell_idx = [idx] - - for dindex in _neighboring_cells( - dim - ): # here the expansion happens over all adjacent cells happens - if onp.all(dindex == 0): - continue - cell_idx += [ - shift_array(idx, dindex) - ] # 27* (nx,ny,nz,cell_capacity) - - cell_idx = jnp.concatenate(cell_idx, axis=-2) # (5, 5, 5, 2376, 1) - cell_idx = cell_idx[..., jnp.newaxis, :, :] # (5, 5, 5, 1, 2376, 1) - # TODO: memory blows up here by factor "cell_capacity" - cell_idx = jnp.broadcast_to( - cell_idx, idx.shape[:-1] + cell_idx.shape[-2:] - ) # 1.2*X (nx,ny,nz,cell_capacity,27*cell_capacity) - - # def copy_values_from_cell(value, cell_value, cell_id): - # scatter_indices = jnp.reshape(cell_id, (-1,)) - # cell_value = jnp.reshape(cell_value, (-1,) +cell_value.shape[-2:]) - # return value.at[scatter_indices].set(cell_value) - # TODO: further memory increase in the next two lines - neighbor_idx = jnp.zeros( - (N + 1,) + cell_idx.shape[-2:], jnp.int32 - ) # X (N, 27*cell_capacity, 1) - # neighbor_idx = copy_values_from_cell(neighbor_idx, cell_idx, idx) # X - scatter_indices = jnp.reshape( - idx, (-1,) - ) # (11000,) each cell allocation over all cells expanded - cell_value = jnp.reshape( - cell_idx, (-1,) + cell_idx.shape[-2:] - ) # (nx*ny*nz*cell_capacity, 27*cell_capacity, 1) - neighbor_idx = neighbor_idx.at[scatter_indices].set( - 
cell_value - ) # X (N, 27*cell_capacity, 1) - - idx = neighbor_idx[ - :-1, :, 0 - ] # X shape (N, 27*cell_capacity) this only removes the 8001th element - # this is just expanded over all cells indices. Should work with - # arbitrary pices over the last dimension - - #################################################################### - # idx.shape = (nx*ny*nz*cell_capacity**2*27) - # -> 26M (or actually just 19M) too much! - # idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs) - d = partial(metric_sq, **kwargs) - d = space.map_bond(d) - - N = position.shape[0] - sender_idx = jnp.broadcast_to( - jnp.arange(N)[:, None], idx.shape - ) # 2X (N, 27*cell_capacity) - - sender_idx = jnp.reshape( - sender_idx, (-1,) - ) # (N, 27*cell_capacity) -> (N*27*cell_capacity) - # [0,0,0,0...0, 1,1,1,1...1, ....] - receiver_idx = jnp.reshape( - idx, (-1,) - ) # flatten the stuff with all possible neighbors (27*cell_size) of - # particle 0, of 1, .... - dR = d( - position[sender_idx], position[receiver_idx] - ) # (N*27*cell_capacity) eventually 3x during computation - - mask = (dR < cutoff_sq) & (receiver_idx < N) # negligible - # if format is NeighborListFormat.OrderedSparse: - # mask = mask & (receiver_idx < sender_idx) - - out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32) # X - cumsum = jnp.cumsum(mask) # 2X - index = jnp.where( - mask, cumsum - 1, len(receiver_idx) - 1 - ) # 2X 7th object of shape (N*27*cell_capacity) - receiver_idx = out_idx.at[index].set( - receiver_idx - ) # X # this operation sorts the entries - sender_idx = out_idx.at[index].set( - sender_idx - ) # 2X -> X # this operation also sorts the entries - max_occupancy_ = cumsum[-1] - - idx, occupancy = jnp.stack((receiver_idx, sender_idx)), max_occupancy_ - # Memory: idx 2X, neihbor_idx X, cell_idx 0.8X, sender_idx X, - # receiver_idx X, dR 2X, out_idx X, cumsum 2X, index 2X - # idx_final = jnp.zeros((N, max_occupancy), jnp.int32) # X -> 2X - # print("max occupancy2 ", occupancy) - - if 
max_occupancy is None: - _extra_capacity = N * extra_capacity - max_occupancy = int( - occupancy * capacity_multiplier + _extra_capacity - ) - if max_occupancy > idx.shape[-1]: - max_occupancy = idx.shape[-1] - if not is_sparse(format): - capacity_limit = N - 1 if mask_self else N - elif format is NeighborListFormat.Sparse: - capacity_limit = N * (N - 1) if mask_self else N**2 - else: - capacity_limit = N * (N - 1) // 2 - if max_occupancy > capacity_limit: - max_occupancy = capacity_limit - idx = idx[ - :, :max_occupancy - ] # shape (N, max_occupancy) -> 2M much smaller - # TODO: from here on the size is ~10x smaller after - # idx=idx[:, :max_occupancy] - # how can we run the previous part sequentially? - ### - #################################################################### - elif implementation == "original": - # print("82 (original)", get_gpu_stats()) - idx = cell_list_candidate_fn(cl, position) - - idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs) - - if max_occupancy is None: - _extra_capacity = ( - extra_capacity if not is_sparse(format) else N * extra_capacity - ) - max_occupancy = int( - occupancy * capacity_multiplier + _extra_capacity - ) - if max_occupancy > idx.shape[-1]: - max_occupancy = idx.shape[-1] - if not is_sparse(format): - capacity_limit = N - 1 if mask_self else N - elif format is NeighborListFormat.Sparse: - capacity_limit = N * (N - 1) if mask_self else N**2 - else: - capacity_limit = N * (N - 1) // 2 - if max_occupancy > capacity_limit: - max_occupancy = capacity_limit - idx = idx[:, :max_occupancy] - - # print("83 ", get_gpu_stats()) - # print(f"idx memory: {idx.nbytes / 1e6:.0f}MB, idx.shape={idx.shape}, - # cl.id_buffer.shape={cl.id_buffer.shape}" ) - - # print("##### max occupancy", max_occupancy, "occupancy", occupancy) + carry = occupancy + y = jnp.stack( + (receiver_idx[:num_edges_sub], sender_idx[:num_edges_sub]) + ) + overflow = cumsum[-1] > num_edges_sub + return carry, (y, overflow) + carry = jnp.array(0) 
+ xs = jnp.array([i * N_sub for i in range(num_partitions)]) + occupancy, (idx, overflows) = lax.scan( + scan_body, carry, xs, length=num_partitions + ) + err = err.update(PEC.CELL_LIST_OVERFLOW, overflows.sum()) + idx = idx.transpose(1, 2, 0).reshape(2, -1) + + # sort to enable pruning later + ordering = jnp.argsort(idx[1]) + idx = idx[:, ordering] + + if max_occupancy is None: + _extra_capacity = N * extra_capacity + max_occupancy = int(occupancy * capacity_multiplier + _extra_capacity) + if max_occupancy > idx.shape[-1]: + max_occupancy = idx.shape[-1] + if not is_sparse(format): + capacity_limit = N - 1 if mask_self else N + elif format is NeighborListFormat.Sparse: + capacity_limit = N * (N - 1) if mask_self else N**2 + else: + capacity_limit = N * (N - 1) // 2 + if max_occupancy > capacity_limit: + max_occupancy = capacity_limit + idx = idx[:, :max_occupancy] update_fn = neighbor_list_fn if neighbors is None else neighbors.update_fn return NeighborList( idx, position, - overflow | (occupancy > max_occupancy), + err.update(PEC.NEIGHBOR_LIST_OVERFLOW, occupancy > max_occupancy), cl_capacity, max_occupancy, format, + cell_size, + cl_fn, update_fn, ) # pytype: disable=wrong-arg-count + nbrs = neighbors if nbrs is None: - return neighbor_fn((position, False)) + return neighbor_fn((position, PartitionError(jnp.zeros((), jnp.uint8)))) neighbor_fn = partial(neighbor_fn, max_occupancy=nbrs.max_occupancy) @@ -597,7 +332,7 @@ def scan_body(carry, input): return lax.cond( jnp.any(d(position, nbrs.reference_position) > threshold_sq), - (position, nbrs.did_buffer_overflow), + (position, nbrs.error), neighbor_fn, nbrs, lambda x: x, @@ -646,19 +381,23 @@ def _matscipy_neighbor_list( else: pbc = np.asarray(pbc, dtype=bool) + dtype_idx = jnp.arange(0).dtype # just to get the correct dtype + def matscipy_wrapper(position, idx_shape, num_particles): position = position[:num_particles] + if position.shape[1] == 2: position = np.pad( position, ((0, 0), (0, 1)), mode="constant", 
constant_values=0.5 ) + edge_list = matscipy_nl( "ij", cutoff=r_cutoff, positions=position, cell=box_size, pbc=pbc ) - edge_list = np.asarray(edge_list, dtype=np.int32) + edge_list = np.asarray(edge_list, dtype=dtype_idx) if not mask_self: # add self connection, which matscipy does not do - self_connect = np.arange(num_particles, dtype=np.int32) + self_connect = np.arange(num_particles, dtype=dtype_idx) self_connect = np.array([self_connect, self_connect]) edge_list = np.concatenate((self_connect, edge_list), axis=-1) @@ -666,7 +405,7 @@ def matscipy_wrapper(position, idx_shape, num_particles): idx_new = np.asarray(edge_list[:, : idx_shape[1]]) buffer_overflow = np.array(True) else: - idx_new = np.ones(idx_shape, dtype=np.int32) * num_particles_max + idx_new = np.ones(idx_shape, dtype=dtype_idx) * num_particles_max idx_new[:, : edge_list.shape[1]] = edge_list buffer_overflow = np.array(False) @@ -686,10 +425,13 @@ def update_fn( idx, buffer_overflow = jax.pure_callback( matscipy_wrapper, shape_out, position, neighbors.idx.shape, num_particles ) + return NeighborList( idx, position, - buffer_overflow, + neighbors.error.update(PEC.NEIGHBOR_LIST_OVERFLOW, buffer_overflow), + None, + None, None, None, None, @@ -700,7 +442,6 @@ def allocate_fn( position: jnp.ndarray, extra_capacity: int = 0, **kwargs ) -> NeighborList: num_particles = kwargs["num_particles"] - position = position[:num_particles] if position.shape[1] == 2: @@ -711,24 +452,26 @@ def allocate_fn( edge_list = matscipy_nl( "ij", cutoff=r_cutoff, positions=position, cell=box_size, pbc=pbc ) - edge_list = np.asarray(edge_list, dtype=np.int32) + edge_list = jnp.asarray(edge_list, dtype=dtype_idx) if not mask_self: # add self connection, which matscipy does not do - self_connect = np.arange(num_particles, dtype=np.int32) - self_connect = np.array([self_connect, self_connect]) - edge_list = np.concatenate((self_connect, edge_list), axis=-1) + self_connect = jnp.arange(num_particles, dtype=dtype_idx) + 
self_connect = jnp.array([self_connect, self_connect]) + edge_list = jnp.concatenate((self_connect, edge_list), axis=-1) # in case this is a (2,M) pair list, we pad with N and capacity_multiplier factor = capacity_multiplier * num_particles_max / num_particles res = num_particles * jnp.ones( (2, round(edge_list.shape[1] * factor + extra_capacity)), - np.int32, + dtype_idx, ) res = res.at[:, : edge_list.shape[1]].set(edge_list) return NeighborList( res, position, - jnp.array(False), + PartitionError(jnp.zeros((), jnp.uint8)), + None, + None, None, None, None, @@ -761,14 +504,52 @@ def neighbor_list( num_partitions: int = 1, pbc: jnp.ndarray = None, ) -> NeighborFn: - """Neighbor lists wrapper. + """Neighbor lists wrapper. Its arguments are mainly based on the jax-md ones. Args: - backend: The backend to use. One of "jaxmd_vmap", "jaxmd_scan", "matscipy". - - - "jaxmd_vmap": Default jax-md neighbor list. Uses vmap. Fast. - - "jaxmd_scan": Modified jax-md neighbor list. Uses scan. Memory efficient. - - "matscipy": Matscipy neighbor list. Runs on cpu, allows dynamic shapes. + displacement: A function `d(R_a, R_b)` that computes the displacement + between pairs of points. + box_size: Either a float specifying the size of the box or an array of + shape `[spatial_dim]` specifying the box size in each spatial dimension. + r_cutoff: A scalar specifying the neighborhood radius. + dr_threshold: A scalar specifying the maximum distance particles can move + before rebuilding the neighbor list. + backend: The backend to use. Can be one of: 1) ``jaxmd_vmap`` - the default + jax-md neighbor list which vectorizes the computations. 2) ``jaxmd_scan`` - + a modified jax-md neighbor list which serializes the search into + ``num_partitions`` chunks to improve the memory efficiency. 3) ``matscipy`` + - a jit-able implementation with the matscipy neighbor list backend, which + runs on CPU and takes variable number of particles smaller or equal to + ``num_particles``. 
+ capacity_multiplier: A floating point scalar specifying the fractional + increase in maximum neighborhood occupancy we allocate compared with the + maximum in the example positions. + disable_cell_list: An optional boolean. If set to `True` then the neighbor + list is constructed using only distances. This can be useful for + debugging but should generally be left as `False`. + mask_self: An optional boolean. Determines whether points can consider + themselves to be their own neighbors. + custom_mask_function: An optional function. Takes the neighbor array + and masks selected elements. Note: The input array to the function is + `(n_particles, m)` where the index of particle 1 is the index in the first + dimension of the array and the index of particle 2 is given by the value in + the array. + fractional_coordinates: An optional boolean. Specifies whether positions will + be supplied in fractional coordinates in the unit cube, :math:`[0, 1]^d`. + If this is set to True then the `box_size` will be set to `1.0` and the + cell size used in the cell list will be set to `cutoff / box_size`. + format: The format of the neighbor list; see the :meth:`NeighborListFormat` enum + for details about the different choices for formats. Defaults to `Dense`. + num_particles_max: only used with the ``matscipy`` backend. Based + on the largest particle system in a dataset. + num_partitions: only used with the ``jaxmd_scan`` backend + pbc: only used with the ``matscipy`` backend. Defines the boundary conditions + for each dimension individually. Can have shape (2,) or (3,). + **static_kwargs: kwargs that get threaded through the calculation of + example positions. + Returns: + A NeighborListFns object that contains a method to allocate a new neighbor + list and a method to update an existing neighbor list. 
""" assert backend in BACKENDS, f"Unknown backend {backend}" diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py new file mode 100644 index 0000000..9e2ca71 --- /dev/null +++ b/tests/test_neighbors.py @@ -0,0 +1,137 @@ +import unittest + +import numpy as np +from jax import config + +config.update("jax_enable_x64", True) +import jax.numpy as jnp +from jax import jit +from jax_md import space + +from jax_sph import partition + + +@jit +def updater(nbrs_old, r_new, **kwargs): + nbrs_new = nbrs_old.update(r_new, **kwargs) + return nbrs_new + + +class BaseTest(unittest.TestCase): + def body(self, args, backend, num_partitions, verbose=False): + r = args["r"] + box_size = args["box_size"] + cutoff = args["cutoff"] + mask_self = args["mask_self"] + target = args["target"] + + if verbose: + print(f"Start with {backend} backend and {num_partitions} partition(s)") + + N, dim = r.shape + num_particles_max = r.shape[0] + displacement_fn, _ = space.periodic(side=box_size) + neighbor_fn = partition.neighbor_list( + displacement_fn, + box_size, + r_cutoff=cutoff, + backend=backend, + dr_threshold=0.0, + capacity_multiplier=1.25, + mask_self=mask_self, + format=partition.NeighborListFormat.Sparse, + num_particles_max=num_particles_max, + num_partitions=num_partitions, + pbc=np.array([True] * dim), + ) + + nbrs = neighbor_fn.allocate(r, num_particles=N) + + if backend == "matscipy": + nbrs2 = updater(nbrs_old=nbrs, r_new=r, num_particles=N) + else: + nbrs2 = updater(nbrs, r) + mask_real = nbrs.idx[0] < N + idx_real = nbrs.idx[:, mask_real] + + if verbose: + print("Idx: \n", nbrs.idx) + print("Idx_real: \n", idx_real) + + self.assertFalse(nbrs.did_buffer_overflow, "Buffer overflow (allocate)") + self.assertFalse(nbrs2.did_buffer_overflow, "Buffer overflow (update)") + + self.assertTrue((nbrs.idx == nbrs2.idx).all(), "allocate differes from update") + + self.assertTrue( + ((nbrs.idx[0] == N) == (nbrs.idx[1] == N)).all(), "One sided edges" + ) + + self_edges_mask = 
idx_real[0] == idx_real[1] + if mask_self: + self.assertEqual(sum(self_edges_mask), 0.0, "Self edges b/n real particles") + else: + self_edges = idx_real[:, self_edges_mask] + self.assertEqual(len(np.unique(self_edges[0])), N, "Self edges are broken") + + # sorted edge list based on second edge row (first sort by first row) + sort_idx = np.argsort(idx_real[0]) + idx_real_sorted = idx_real[:, sort_idx] + sort_idx = np.argsort(idx_real_sorted[1]) + idx_real_sorted = idx_real_sorted[:, sort_idx] + self.assertTrue((idx_real_sorted == target).all(), "Wrong edge list") + + if verbose: + print(f"Finish with {backend} backend and {num_partitions} partition(s)") + + def cases(self, backend, num_partitions=1, verbose=False): + # Simple test with pbc and with/without self-masking + args = { + "mask_self": False, + "cutoff": 0.33, + "box_size": np.array([1.0, 1.0]), + "r": jnp.array([[0.1, 0.1], [0.1, 0.3], [0.1, 0.9], [0.6, 0.5]]), + "target": jnp.array([[0, 1, 2, 0, 1, 0, 2, 3], [0, 0, 0, 1, 1, 2, 2, 3]]), + } + self.body(args, backend, num_partitions, verbose) + + args["mask_self"] = True + args["target"] = jnp.array([[1, 2, 0, 0], [0, 0, 1, 2]]) + self.body(args, backend, num_partitions, verbose) + + # Edge case at which the scan implementation almost breaks + args = { + "mask_self": False, + "cutoff": 0.33, + "box_size": np.array([1.0, 1.0]), + "r": jnp.array( + [[0.5, 0.2], [0.2, 0.5], [0.5, 0.5], [0.8, 0.5], [0.5, 0.8]] + ), + "target": jnp.array( + [ + [0, 2, 1, 2, 0, 1, 2, 3, 4, 2, 3, 2, 4], + [0, 0, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4], + ] + ), + } + self.body(args, backend, num_partitions, verbose) + + args["mask_self"] = True + args["target"] = jnp.array([[2, 2, 0, 1, 3, 4, 2, 2], [0, 1, 2, 2, 2, 2, 3, 4]]) + self.body(args, backend, num_partitions, verbose) + + def test_vmap(self): + self.cases("jaxmd_vmap") + + def test_scan1(self): + self.cases("jaxmd_scan") + + def test_scan2(self): + self.cases("jaxmd_scan", 2) + + def test_matscipy(self): + 
self.cases("matscipy") + + +if __name__ == "__main__": + unittest.main()