diff --git a/src/pytorch/neighbors/TestNeighbors.py b/src/pytorch/neighbors/TestNeighbors.py index 9b3fbad..2b1365a 100644 --- a/src/pytorch/neighbors/TestNeighbors.py +++ b/src/pytorch/neighbors/TestNeighbors.py @@ -128,11 +128,14 @@ def test_neighbor_grads(dtype, num_atoms, grad): raise ValueError('grad') if dtype == pt.float32: - assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-5, rtol=1e-3) + assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-3, rtol=1e-3) else: assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-8, rtol=1e-5) -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +# The following test is only run on the CPU. Running it on the GPU triggers a +# CUDA assertion, which causes all tests run after it to fail. + +@pytest.mark.parametrize('device', ['cpu']) @pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) def test_too_many_neighbors(device, dtype): @@ -143,4 +146,67 @@ def test_too_many_neighbors(device, dtype): with pytest.raises(RuntimeError): positions = pt.zeros((4, 3,), device=device, dtype=dtype) getNeighborPairs(positions, cutoff=1, max_num_neighbors=1) - pt.cuda.synchronize() \ No newline at end of file + pt.cuda.synchronize() + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) +def test_periodic_neighbors(device, dtype): + + if not pt.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + + # Generate random positions + num_atoms = 100 + positions = (20 * pt.randn((num_atoms, 3), device=device, dtype=dtype)) - 10 + box_vectors = pt.tensor([[10, 0, 0], [2, 12, 0], [0, 1, 11]], device=device, dtype=dtype) + cutoff = 5.0 + + # Get neighbor pairs + ref_neighbors = np.vstack(np.tril_indices(num_atoms, -1)) + ref_positions = positions.cpu().numpy() + ref_vectors = box_vectors.cpu().numpy() + ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]] + ref_deltas -= np.outer(np.round(ref_deltas[:,2]/ref_vectors[2,2]), ref_vectors[2]) + ref_deltas -= np.outer(np.round(ref_deltas[:,1]/ref_vectors[1,1]), ref_vectors[1]) + ref_deltas -= np.outer(np.round(ref_deltas[:,0]/ref_vectors[0,0]), ref_vectors[0]) + ref_distances = np.linalg.norm(ref_deltas, axis=1) + + # Filter the neighbor pairs + mask = ref_distances > cutoff + ref_neighbors[:, mask] = -1 + ref_deltas[mask, :] = np.nan + ref_distances[mask] = np.nan + + # Find the number of neighbors + num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances))) + max_num_neighbors = max(int(np.ceil(num_neighbors / num_atoms)), 1) + + # Compute results + neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors, box_vectors=box_vectors) + + # Check device + assert neighbors.device == positions.device + assert deltas.device == positions.device + assert distances.device == positions.device + + # Check types + assert neighbors.dtype == pt.int32 + assert deltas.dtype == dtype + assert distances.dtype == dtype + + # Covert the results + neighbors = neighbors.cpu().numpy() + deltas = deltas.cpu().numpy() + distances = distances.cpu().numpy() + + # Sort the neighbors + # NOTE: GPU returns the neighbor in a non-deterministic order + ref_neighbors, ref_deltas, ref_distances = sort_neighbors(ref_neighbors, ref_deltas, ref_distances) + neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances) + + # Resize the reference + ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors) + + assert np.all(ref_neighbors == neighbors) + assert np.allclose(ref_deltas, deltas, equal_nan=True) + assert np.allclose(ref_distances, distances, equal_nan=True) diff --git a/src/pytorch/neighbors/getNeighborPairs.py b/src/pytorch/neighbors/getNeighborPairs.py index c6993b1..55d20d4 100644 --- a/src/pytorch/neighbors/getNeighborPairs.py +++ b/src/pytorch/neighbors/getNeighborPairs.py @@ -1,8 +1,8 @@ -from torch import ops, Tensor -from typing import Tuple +from torch import empty, ops, Tensor +from typing import Optional, Tuple -def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1) -> Tuple[Tensor, Tensor]: +def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ''' Returns indices and distances of atom pairs within a given cutoff distance. @@ -16,6 +16,20 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = molecule, where most of the atoms are beyond the cutoff distance of each other. + This function optionally supports periodic boundary conditions with + arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy + certain requirements: + + `a[1] = a[2] = b[2] = 0` + `a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff` + `a[0] >= 2*b[0]` + `a[0] >= 2*c[0]` + `b[1] >= 2*c[1]` + + These requirements correspond to a particular rotation of the system and + reduced form of the vectors, as well as the requirement that the cutoff be + no larger than half the box width. + Parameters ---------- positions: `torch.Tensor` @@ -26,6 +40,10 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = max_num_neighbors: int, optional Maximum number of neighbors per atom. If set to `-1` (default), all possible combinations of atom pairs are included. + box_vectors: `torch.Tensor`, optional + The vectors defining the periodic box. This must have shape `(3, 3)`, + where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. + If this is omitted, periodic boundary conditions are not applied. Returns ------- @@ -103,4 +121,6 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = tensor([1., 1., nan, nan, nan, nan])) ''' - return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors) \ No newline at end of file + if box_vectors is None: + box_vectors = empty((0, 0), device=positions.device, dtype=positions.dtype) + return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors, box_vectors) \ No newline at end of file diff --git a/src/pytorch/neighbors/getNeighborPairsCPU.cpp b/src/pytorch/neighbors/getNeighborPairsCPU.cpp index d9ba2f2..19dfa7d 100644 --- a/src/pytorch/neighbors/getNeighborPairsCPU.cpp +++ b/src/pytorch/neighbors/getNeighborPairsCPU.cpp @@ -13,10 +13,13 @@ using torch::Scalar; using torch::hstack; using torch::vstack; using torch::Tensor; +using torch::outer; +using torch::round; static tuple forward(const Tensor& positions, const Scalar& cutoff, - const Scalar& max_num_neighbors) { + const Scalar& max_num_neighbors, + const Tensor& box_vectors) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); @@ -25,6 +28,25 @@ static tuple forward(const Tensor& positions, TORCH_CHECK(cutoff.to() > 0, "Expected \"cutoff\" to be positive"); + if (box_vectors.size(0) != 0) { + TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); + TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, "Expected \"box_vectors\" to have shape (3, 3)"); + double v[3][3]; + for (int i = 0; i < 3; i++) + for (int j = 0; j < 3; j++) + v[i][j] = box_vectors[i][j].item(); + double c = cutoff.to(); + TORCH_CHECK(v[0][1] == 0, "Invalid box vectors: box_vectors[0][1] != 0"); + TORCH_CHECK(v[0][2] == 0, "Invalid box vectors: box_vectors[0][2] != 0"); + TORCH_CHECK(v[1][2] == 0, "Invalid box vectors: box_vectors[1][2] != 0"); + TORCH_CHECK(v[0][0] >= 2*c, "Invalid box vectors: box_vectors[0][0] < 2*cutoff"); + TORCH_CHECK(v[1][1] >= 2*c, "Invalid box vectors: box_vectors[1][1] < 2*cutoff"); + TORCH_CHECK(v[2][2] >= 2*c, "Invalid box vectors: box_vectors[2][2] < 2*cutoff"); + TORCH_CHECK(v[0][0] >= 2*v[1][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); + TORCH_CHECK(v[0][0] >= 2*v[2][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); + TORCH_CHECK(v[1][1] >= 2*v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); + } + const int max_num_neighbors_ = max_num_neighbors.to(); TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1, "Expected \"max_num_neighbors\" to be positive or equal to -1"); @@ -39,12 +61,17 @@ static tuple forward(const Tensor& positions, Tensor neighbors = vstack({rows, columns}); Tensor deltas = index_select(positions, 0, rows) - index_select(positions, 0, columns); + if (box_vectors.size(0) != 0) { + deltas -= outer(round(deltas.index({Slice(), 2})/box_vectors.index({2, 2})), box_vectors.index({2})); + deltas -= outer(round(deltas.index({Slice(), 1})/box_vectors.index({1, 1})), box_vectors.index({1})); + deltas -= outer(round(deltas.index({Slice(), 0})/box_vectors.index({0, 0})), box_vectors.index({0})); + } Tensor distances = frobenius_norm(deltas, 1); if (max_num_neighbors_ == -1) { const Tensor mask = distances > cutoff; neighbors.index_put_({Slice(), mask}, -1); - deltas = deltas.clone(); // Brake an autograd loop + deltas = deltas.clone(); // Break an autograd loop deltas.index_put_({mask, Slice()}, NAN); distances.index_put_({mask}, NAN); diff --git a/src/pytorch/neighbors/getNeighborPairsCUDA.cu b/src/pytorch/neighbors/getNeighborPairsCUDA.cu index 33e0058..2d820a4 100644 --- a/src/pytorch/neighbors/getNeighborPairsCUDA.cu +++ b/src/pytorch/neighbors/getNeighborPairsCUDA.cu @@ -31,10 +31,12 @@ template __global__ void forward_kernel( const Accessor positions, const scalar_t cutoff2, const bool store_all_pairs, + const bool use_periodic, Accessor i_curr_pair, Accessor neighbors, Accessor deltas, - Accessor distances + Accessor distances, + Accessor box_vectors ) { const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= num_all_pairs) return; @@ -43,9 +45,20 @@ template __global__ void forward_kernel( if (row * (row - 1) > 2 * index) row--; const int32_t column = index - row * (row - 1) / 2; - const scalar_t delta_x = positions[row][0] - positions[column][0]; - const scalar_t delta_y = positions[row][1] - positions[column][1]; - const scalar_t delta_z = positions[row][2] - positions[column][2]; + scalar_t delta_x = positions[row][0] - positions[column][0]; + scalar_t delta_y = positions[row][1] - positions[column][1]; + scalar_t delta_z = positions[row][2] - positions[column][2]; + if (use_periodic) { + scalar_t scale3 = round(delta_z/box_vectors[2][2]); + delta_x -= scale3*box_vectors[2][0]; + delta_y -= scale3*box_vectors[2][1]; + delta_z -= scale3*box_vectors[2][2]; + scalar_t scale2 = round(delta_y/box_vectors[1][1]); + delta_x -= scale2*box_vectors[1][0]; + delta_y -= scale2*box_vectors[1][1]; + scalar_t scale1 = round(delta_x/box_vectors[0][0]); + delta_x -= scale1*box_vectors[0][0]; + } const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z; if (distance2 > cutoff2) return; @@ -89,7 +102,8 @@ public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Scalar& cutoff, - const Scalar& max_num_neighbors) { + const Scalar& max_num_neighbors, + const Tensor& box_vectors) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); @@ -100,6 +114,12 @@ public: TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1, "Expected \"max_num_neighbors\" to be positive or equal to -1"); + const bool use_periodic = (box_vectors.size(0) != 0); + if (use_periodic) { + TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions"); + TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, "Expected \"box_vectors\" to have shape (3, 3)"); + } + // Decide the algorithm const bool store_all_pairs = max_num_neighbors_ == -1; const int num_atoms = positions.size(0); @@ -125,10 +145,12 @@ public: get_accessor(positions), cutoff_ * cutoff_, store_all_pairs, + use_periodic, get_accessor(i_curr_pair), get_accessor(neighbors), get_accessor(deltas), - get_accessor(distances)); + get_accessor(distances), + get_accessor(box_vectors)); }); ctx->save_for_backward({neighbors, deltas, distances}); @@ -165,14 +187,14 @@ public: get_accessor(grad_positions)); }); - return {grad_positions, Tensor(), Tensor()}; + return {grad_positions, Tensor(), Tensor(), Tensor()}; } }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { m.impl("getNeighborPairs", - [](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_neighbors){ - const tensor_list results = Autograd::apply(positions, cutoff, max_num_neighbors); + [](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_neighbors, const Tensor& box_vectors){ + const tensor_list results = Autograd::apply(positions, cutoff, max_num_neighbors, box_vectors); return make_tuple(results[0], results[1], results[2]); }); } \ No newline at end of file diff --git a/src/pytorch/neighbors/neighbors.cpp b/src/pytorch/neighbors/neighbors.cpp index 65e6af5..d8dd5c5 100644 --- a/src/pytorch/neighbors/neighbors.cpp +++ b/src/pytorch/neighbors/neighbors.cpp @@ -1,5 +1,5 @@ #include TORCH_LIBRARY(neighbors, m) { - m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors) -> (Tensor neighbors, Tensor deltas, Tensor distances)"); + m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors, Tensor box_vectors) -> (Tensor neighbors, Tensor deltas, Tensor distances)"); } \ No newline at end of file