Skip to content

Commit

Permalink
test and fix nb mat
Browse files Browse the repository at this point in the history
  • Loading branch information
zubatyuk committed Oct 30, 2024
1 parent e4c3a1c commit ed29ceb
Show file tree
Hide file tree
Showing 11 changed files with 14,290 additions and 244 deletions.
19 changes: 12 additions & 7 deletions aimnet/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import Tensor, nn

from .model_registry import get_model_path
from .nbmat import calc_nbmat
from .nbmat import TooManyNeighborsError, calc_nbmat


class AIMNet2Calculator:
Expand Down Expand Up @@ -76,7 +76,7 @@ def set_lrcoulomb_method(
mod.dsf_rc = cutoff # type: ignore
elif method == "ewald":
# current implementaion of Ewald does not use nb mat
self.cutoff_lr = None
self.cutoff_lr = cutoff
self._coulomb_method = method

def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]:
Expand Down Expand Up @@ -178,8 +178,11 @@ def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:

while True:
try:
maxnb1 = min(calc_max_nb(self.cutoff, self.max_density), self._max_mol_size)
maxnb2 = min(calc_max_nb(self.cutoff_lr, self.max_density), self._max_mol_size) if self.lr else None # type: ignore
maxnb1 = calc_max_nb(self.cutoff, self.max_density)
maxnb2 = calc_max_nb(self.cutoff_lr, self.max_density) if self.lr else None # type: ignore
if cell is None:
maxnb1 = min(maxnb1, self._max_mol_size)
maxnb2 = min(maxnb2, self._max_mol_size) if self.lr else None # type: ignore
maxnb = (maxnb1, maxnb2)
nbmat1, nbmat2, shifts1, shifts2 = calc_nbmat(
data["coord"],
Expand All @@ -189,8 +192,9 @@ def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
data.get("mol_idx"), # type: ignore
)
break
except ValueError:
except TooManyNeighborsError:
self.max_density *= 1.2
assert self.max_density <= 4, "Something went wrong in nbmat calculation"
data["nbmat"] = nbmat1
if self.lr:
assert nbmat2 is not None
Expand All @@ -205,9 +209,10 @@ def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:

def pad_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
N = data["nbmat"].shape[0]
data["coord"] = maybe_pad_dim0(data["coord"], N)
data["numbers"] = maybe_pad_dim0(data["numbers"], N)
data["mol_idx"] = maybe_pad_dim0(data["mol_idx"], N, value=data["mol_idx"][-1].item())
for k in ("coord", "numbers"):
if k in data:
data[k] = maybe_pad_dim0(data[k], N)
return data

def unpad_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
Expand Down
137 changes: 80 additions & 57 deletions aimnet/calculators/nb_kernel_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np


@numba.njit(cache=True, parallel=True, fastmath=True)
@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_dual_cpu(
coord: np.ndarray, # float, (N, 3)
cutoff1_squared: float,
Expand All @@ -17,7 +17,7 @@ def _nbmat_dual_cpu(
maxnb1 = nbmat1.shape[1]
maxnb2 = nbmat2.shape[1]
N = coord.shape[0]
for i in numba.prange(N):
for i in range(N):
c_i = coord[i]
_mol_idx = mol_idx[i]
_j_start = i + 1
Expand All @@ -40,7 +40,7 @@ def _nbmat_dual_cpu(
_expand_nb(nnb2, nbmat2)


@numba.njit(cache=True, parallel=True, fastmath=True)
@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_cpu(
coord: np.ndarray, # float, (N, 3)
cutoff1_squared: float,
Expand All @@ -51,7 +51,7 @@ def _nbmat_cpu(
):
maxnb1 = nbmat1.shape[1]
N = coord.shape[0]
for i in numba.prange(N):
for i in range(N):
c_i = coord[i]
_mol_idx = mol_idx[i]
_j_start = i + 1
Expand All @@ -74,11 +74,14 @@ def _expand_nb(nnb, nbmat):
N = nnb.shape[0]
for i in range(N):
for m in range(nnb_copy[i]):
if m >= nbmat.shape[1]:
continue
j = nbmat[i, m]
pos = nnb[j]
nnb[j] += 1
if pos < nbmat.shape[1]:
nbmat[j, pos] = i
if j < N:
pos = nnb[j]
nnb[j] += 1
if pos < nbmat.shape[1]:
nbmat[j, pos] = i


@numba.njit(cache=True, inline="always")
Expand All @@ -87,19 +90,22 @@ def _expand_nb_pbc(nnb, nbmat, shifts):
N = nnb.shape[0]
for i in range(N):
for m in range(nnb_copy[i]):
if m >= nbmat.shape[1]:
continue
j = nbmat[i, m]
pos = nnb[j]
nnb[j] += 1
if pos < nbmat.shape[1]:
nbmat[j, pos] = i
shift = shifts[i, m]
shifts[j, pos] = -shift
if j < N:
pos = nnb[j]
nnb[j] += 1
if pos < nbmat.shape[1]:
nbmat[j, pos] = i
shift = shifts[i, m]
shifts[j, pos] = -shift


@numba.njit(cache=True)
def _expand_shifts(nshift):
tot_shifts = ((nshift[0] + 1) * (2 * nshift[1] + 1) * (2 * nshift[2] + 1)).sum()
shifts = np.zeros((tot_shifts, 3), dtype=np.int8)
tot_shifts = (nshift[0] + 1) * (2 * nshift[1] + 1) * (2 * nshift[2] + 1)
shifts = np.zeros((tot_shifts, 3), dtype=np.float32)
i = 0
for k1 in range(-nshift[0], nshift[0] + 1):
for k2 in range(-nshift[1], nshift[1] + 1):
Expand All @@ -113,87 +119,104 @@ def _expand_shifts(nshift):
return shifts


@numba.njit(cache=True, parallel=True, fastmath=True)
def _nbmat_dual_pbc_cpu(
@numba.njit(cache=True, parallel=False, fastmath=True)
def shift_coords(coord, cell, shifts):
N = coord.shape[0]
S = shifts.shape[0]
# pre-compute shifted coords
coord_shifted = np.empty((N, S, 3), dtype=coord.dtype)
for i in range(N):
for s in range(S):
shift = shifts[s]
c_x = coord[i, 0] + shift[0] * cell[0, 0] + shift[1] * cell[1, 0] + shift[2] * cell[2, 0]
c_y = coord[i, 1] + shift[0] * cell[0, 1] + shift[1] * cell[1, 1] + shift[2] * cell[2, 1]
c_z = coord[i, 2] + shift[0] * cell[0, 2] + shift[1] * cell[1, 2] + shift[2] * cell[2, 2]
coord_shifted[i, s] = c_x, c_y, c_z
return coord_shifted


@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_pbc_cpu(
coord: np.ndarray, # float, (N, 3)
cell: np.ndarray, # float, (3, 3)
cutoff1_squared: float,
cutoff2_squared: float,
shifts: np.ndarray, # float, (S, 3)
nnb1: np.ndarray, # int, zeros, (N,)
nnb2: np.ndarray, # int, zeros, (N,)
nbmat1: np.ndarray, # int, (N, M)
nbmat2: np.ndarray, # int, (N, K)
shifts1: np.ndarray, # int, (N, M, 3)
shifts2: np.ndarray, # int (N, K, 3)
):
maxnb1 = nbmat1.shape[1]
maxnb2 = nbmat2.shape[1]
N = coord.shape[0]
S = shifts.shape[0]

for i in numba.prange(N):
coord_shifted = shift_coords(coord, cell, shifts)

for i in range(N):
c_i = coord[i]
for s in range(S):
shift = shifts[s]
zero_shift = np.all(shift == 0)
c_i = coord[i] if zero_shift else coord[i] + shift @ cell
_j_start = i + 1 if zero_shift else 0
_j_end = N
for j in range(_j_start, _j_end):
c_j = coord[j]
diff = c_i - c_j
# hint for numba to vectorize the op
dx, dy, dz = diff[0], diff[1], diff[2]
dist2 = dx * dx + dy * dy + dz * dz
if dist2 < cutoff1_squared:
zero_shift = shift[0] == 0 and shift[1] == 0 and shift[2] == 0
_j_end = i if zero_shift else N
for j in range(_j_end):
c_j = coord_shifted[j, s]
dx = c_i[0] - c_j[0]
dy = c_i[1] - c_j[1]
dz = c_i[2] - c_j[2]
r2 = dx * dx + dy * dy + dz * dz
if r2 < cutoff1_squared:
pos = nnb1[i]
nnb1[i] += 1
if pos < maxnb1:
nbmat1[i, pos] = j
shifts1[i, pos] = shift
if dist2 < cutoff2_squared:
pos = nnb2[i]
nnb2[i] += 1
if pos < maxnb2:
nbmat2[i, pos] = j
shifts2[i, pos] = shift

_expand_nb_pbc(nnb1, nbmat1, shifts1)
_expand_nb_pbc(nnb2, nbmat2, shifts2)


@numba.njit(cache=True, parallel=True, fastmath=True)
def _nbmat_pbc_cpu(
@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_dual_pbc_cpu(
coord: np.ndarray, # float, (N, 3)
cell: np.ndarray, # float, (3, 3)
cutoff1_squared: float,
cutoff2_squared: float,
shifts: np.ndarray, # float, (S, 3)
nnb1: np.ndarray, # int, zeros, (N,)
nnb2: np.ndarray, # int, zeros, (N,)
nbmat1: np.ndarray, # int, (N, M)
nbmat2: np.ndarray, # int, (N, M)
shifts1: np.ndarray, # int, (N, M, 3)
shifts2: np.ndarray, # int, (N, M, 3)
):
maxnb1 = nbmat1.shape[1]
maxnb2 = nbmat2.shape[1]
N = coord.shape[0]
S = shifts.shape[0]

for i in numba.prange(N):
coord_shifted = shift_coords(coord, cell, shifts)

for i in range(N):
c_i = coord[i]
for s in range(S):
shift = shifts[s]
zero_shift = np.all(shift == 0)
c_i = coord[i] if zero_shift else coord[i] + shift @ cell
_j_start = i + 1 if zero_shift else 0
_j_end = N
for j in range(_j_start, _j_end):
c_j = coord[j]
diff = c_i - c_j
# hint for numba to vectorize the op
dx, dy, dz = diff[0], diff[1], diff[2]
dist2 = dx * dx + dy * dy + dz * dz
if dist2 < cutoff1_squared:
zero_shift = shift[0] == 0 and shift[1] == 0 and shift[2] == 0
_j_end = i if zero_shift else N
for j in range(_j_end):
c_j = coord_shifted[j, s]
dx = c_i[0] - c_j[0]
dy = c_i[1] - c_j[1]
dz = c_i[2] - c_j[2]
r2 = dx * dx + dy * dy + dz * dz
if r2 < cutoff1_squared:
pos = nnb1[i]
nnb1[i] += 1
if pos < maxnb1:
nbmat1[i, pos] = j
shifts1[i, pos] = shift
if r2 < cutoff2_squared:
pos = nnb2[i]
nnb2[i] += 1
if pos < maxnb2:
nbmat2[i, pos] = j
shifts2[i, pos] = shift

_expand_nb_pbc(nnb1, nbmat1, shifts1)
_expand_nb_pbc(nnb2, nbmat2, shifts2)
24 changes: 7 additions & 17 deletions aimnet/calculators/nb_kernel_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _nbmat_pbc_dual_cuda(
return

maxnb1 = nbmat1.shape[1]
maxnb2 = nbmat2.shape[2]
maxnb2 = nbmat2.shape[1]

shift_x = shifts[shift_idx, 0]
shift_y = shifts[shift_idx, 1]
Expand All @@ -114,14 +114,9 @@ def _nbmat_pbc_dual_cuda(
shift_y = numba.float32(shift_y)
shift_z = numba.float32(shift_z)

if zero_shift:
coord_shifted_x = coord[atom_idx, 0]
coord_shifted_y = coord[atom_idx, 1]
coord_shifted_z = coord[atom_idx, 2]
else:
coord_shifted_x = coord[atom_idx, 0] + shift_x * cell[0, 0] + shift_y * cell[1, 0] + shift_z * cell[2, 0]
coord_shifted_y = coord[atom_idx, 1] + shift_x * cell[0, 1] + shift_y * cell[1, 1] + shift_z * cell[2, 1]
coord_shifted_z = coord[atom_idx, 2] + shift_x * cell[0, 2] + shift_y * cell[1, 2] + shift_z * cell[2, 2]
coord_shifted_x = coord[atom_idx, 0] + shift_x * cell[0, 0] + shift_y * cell[1, 0] + shift_z * cell[2, 0]
coord_shifted_y = coord[atom_idx, 1] + shift_x * cell[0, 1] + shift_y * cell[1, 1] + shift_z * cell[2, 1]
coord_shifted_z = coord[atom_idx, 2] + shift_x * cell[0, 2] + shift_y * cell[1, 2] + shift_z * cell[2, 2]

for i in range(_n):
if zero_shift and i >= atom_idx:
Expand Down Expand Up @@ -193,14 +188,9 @@ def _nbmat_pbc_cuda(
shift_y = numba.float32(shift_y)
shift_z = numba.float32(shift_z)

if zero_shift:
coord_shifted_x = coord[atom_idx, 0]
coord_shifted_y = coord[atom_idx, 1]
coord_shifted_z = coord[atom_idx, 2]
else:
coord_shifted_x = coord[atom_idx, 0] + shift_x * cell[0, 0] + shift_y * cell[1, 0] + shift_z * cell[2, 0]
coord_shifted_y = coord[atom_idx, 1] + shift_x * cell[0, 1] + shift_y * cell[1, 1] + shift_z * cell[2, 1]
coord_shifted_z = coord[atom_idx, 2] + shift_x * cell[0, 2] + shift_y * cell[1, 2] + shift_z * cell[2, 2]
coord_shifted_x = coord[atom_idx, 0] + shift_x * cell[0, 0] + shift_y * cell[1, 0] + shift_z * cell[2, 0]
coord_shifted_y = coord[atom_idx, 1] + shift_x * cell[0, 1] + shift_y * cell[1, 1] + shift_z * cell[2, 1]
coord_shifted_z = coord[atom_idx, 2] + shift_x * cell[0, 2] + shift_y * cell[1, 2] + shift_z * cell[2, 2]

for i in range(_n):
if zero_shift and i >= atom_idx:
Expand Down
Loading

0 comments on commit ed29ceb

Please sign in to comment.