Skip to content

Commit

Permalink
getNeighborPairs() supports periodic boundary conditions (#70)
Browse files Browse the repository at this point in the history
* getNeighborPairs() supports periodic boundary conditions

* CUDA implementation of periodic boundary conditions

* Fixed error in autograd

* Skip test that causes CUDA assertion

* Added checks for invalid box vectors
  • Loading branch information
peastman authored Jan 12, 2023
1 parent 7491583 commit 8b2d427
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 19 deletions.
72 changes: 69 additions & 3 deletions src/pytorch/neighbors/TestNeighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,14 @@ def test_neighbor_grads(dtype, num_atoms, grad):
raise ValueError('grad')

if dtype == pt.float32:
assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-5, rtol=1e-3)
assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-3, rtol=1e-3)
else:
assert pt.allclose(positions_cpu.grad, positions_cuda.grad.cpu(), atol=1e-8, rtol=1e-5)

@pytest.mark.parametrize('device', ['cpu', 'cuda'])
# The following test is only run on the CPU. Running it on the GPU triggers a
# CUDA assertion, which causes all tests run after it to fail.

@pytest.mark.parametrize('device', ['cpu'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_too_many_neighbors(device, dtype):

Expand All @@ -143,4 +146,67 @@ def test_too_many_neighbors(device, dtype):
with pytest.raises(RuntimeError):
positions = pt.zeros((4, 3,), device=device, dtype=dtype)
getNeighborPairs(positions, cutoff=1, max_num_neighbors=1)
pt.cuda.synchronize()
pt.cuda.synchronize()

@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_periodic_neighbors(device, dtype):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

# Generate random positions
num_atoms = 100
positions = (20 * pt.randn((num_atoms, 3), device=device, dtype=dtype)) - 10
box_vectors = pt.tensor([[10, 0, 0], [2, 12, 0], [0, 1, 11]], device=device, dtype=dtype)
cutoff = 5.0

# Get neighbor pairs
ref_neighbors = np.vstack(np.tril_indices(num_atoms, -1))
ref_positions = positions.cpu().numpy()
ref_vectors = box_vectors.cpu().numpy()
ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]]
ref_deltas -= np.outer(np.round(ref_deltas[:,2]/ref_vectors[2,2]), ref_vectors[2])
ref_deltas -= np.outer(np.round(ref_deltas[:,1]/ref_vectors[1,1]), ref_vectors[1])
ref_deltas -= np.outer(np.round(ref_deltas[:,0]/ref_vectors[0,0]), ref_vectors[0])
ref_distances = np.linalg.norm(ref_deltas, axis=1)

# Filter the neighbor pairs
mask = ref_distances > cutoff
ref_neighbors[:, mask] = -1
ref_deltas[mask, :] = np.nan
ref_distances[mask] = np.nan

# Find the number of neighbors
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances)))
max_num_neighbors = max(int(np.ceil(num_neighbors / num_atoms)), 1)

# Compute results
neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors, box_vectors=box_vectors)

# Check device
assert neighbors.device == positions.device
assert deltas.device == positions.device
assert distances.device == positions.device

# Check types
assert neighbors.dtype == pt.int32
assert deltas.dtype == dtype
assert distances.dtype == dtype

# Covert the results
neighbors = neighbors.cpu().numpy()
deltas = deltas.cpu().numpy()
distances = distances.cpu().numpy()

# Sort the neighbors
# NOTE: GPU returns the neighbor in a non-deterministic order
ref_neighbors, ref_deltas, ref_distances = sort_neighbors(ref_neighbors, ref_deltas, ref_distances)
neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances)

# Resize the reference
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors)

assert np.all(ref_neighbors == neighbors)
assert np.allclose(ref_deltas, deltas, equal_nan=True)
assert np.allclose(ref_distances, distances, equal_nan=True)
28 changes: 24 additions & 4 deletions src/pytorch/neighbors/getNeighborPairs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from torch import ops, Tensor
from typing import Tuple
from torch import empty, ops, Tensor
from typing import Optional, Tuple


def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1) -> Tuple[Tensor, Tensor]:
def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
'''
Returns indices and distances of atom pairs within a given cutoff distance.
Expand All @@ -16,6 +16,20 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
molecule, where most of the atoms are beyond the cutoff distance of each
other.
This function optionally supports periodic boundary conditions with
arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy
certain requirements:
`a[1] = a[2] = b[2] = 0`
`a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff`
`a[0] >= 2*b[0]`
`a[0] >= 2*c[0]`
`b[1] >= 2*c[1]`
These requirements correspond to a particular rotation of the system and
reduced form of the vectors, as well as the requirement that the cutoff be
no larger than half the box width.
Parameters
----------
positions: `torch.Tensor`
Expand All @@ -26,6 +40,10 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
max_num_neighbors: int, optional
Maximum number of neighbors per atom. If set to `-1` (default),
all possible combinations of atom pairs are included.
box_vectors: `torch.Tensor`, optional
The vectors defining the periodic box. This must have shape `(3, 3)`,
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
Returns
-------
Expand Down Expand Up @@ -103,4 +121,6 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
tensor([1., 1., nan, nan, nan, nan]))
'''

return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors)
if box_vectors is None:
box_vectors = empty((0, 0), device=positions.device, dtype=positions.dtype)
return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors, box_vectors)
31 changes: 29 additions & 2 deletions src/pytorch/neighbors/getNeighborPairsCPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@ using torch::Scalar;
using torch::hstack;
using torch::vstack;
using torch::Tensor;
using torch::outer;
using torch::round;

static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
const Scalar& cutoff,
const Scalar& max_num_neighbors) {
const Scalar& max_num_neighbors,
const Tensor& box_vectors) {

TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0");
Expand All @@ -25,6 +28,25 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,

TORCH_CHECK(cutoff.to<double>() > 0, "Expected \"cutoff\" to be positive");

if (box_vectors.size(0) != 0) {
TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions");
TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, "Expected \"box_vectors\" to have shape (3, 3)");
double v[3][3];
for (int i = 0; i < 3; i++)
for (int j = 0; j < 3; j++)
v[i][j] = box_vectors[i][j].item<double>();
double c = cutoff.to<double>();
TORCH_CHECK(v[0][1] == 0, "Invalid box vectors: box_vectors[0][1] != 0");
TORCH_CHECK(v[0][2] == 0, "Invalid box vectors: box_vectors[0][2] != 0");
TORCH_CHECK(v[1][2] == 0, "Invalid box vectors: box_vectors[1][2] != 0");
TORCH_CHECK(v[0][0] >= 2*c, "Invalid box vectors: box_vectors[0][0] < 2*cutoff");
TORCH_CHECK(v[1][1] >= 2*c, "Invalid box vectors: box_vectors[1][1] < 2*cutoff");
TORCH_CHECK(v[2][2] >= 2*c, "Invalid box vectors: box_vectors[2][2] < 2*cutoff");
TORCH_CHECK(v[0][0] >= 2*v[1][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]");
TORCH_CHECK(v[0][0] >= 2*v[2][0], "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]");
TORCH_CHECK(v[1][1] >= 2*v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]");
}

const int max_num_neighbors_ = max_num_neighbors.to<int>();
TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1,
"Expected \"max_num_neighbors\" to be positive or equal to -1");
Expand All @@ -39,12 +61,17 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,

Tensor neighbors = vstack({rows, columns});
Tensor deltas = index_select(positions, 0, rows) - index_select(positions, 0, columns);
if (box_vectors.size(0) != 0) {
deltas -= outer(round(deltas.index({Slice(), 2})/box_vectors.index({2, 2})), box_vectors.index({2}));
deltas -= outer(round(deltas.index({Slice(), 1})/box_vectors.index({1, 1})), box_vectors.index({1}));
deltas -= outer(round(deltas.index({Slice(), 0})/box_vectors.index({0, 0})), box_vectors.index({0}));
}
Tensor distances = frobenius_norm(deltas, 1);

if (max_num_neighbors_ == -1) {
const Tensor mask = distances > cutoff;
neighbors.index_put_({Slice(), mask}, -1);
deltas = deltas.clone(); // Brake an autograd loop
deltas = deltas.clone(); // Break an autograd loop
deltas.index_put_({mask, Slice()}, NAN);
distances.index_put_({mask}, NAN);

Expand Down
40 changes: 31 additions & 9 deletions src/pytorch/neighbors/getNeighborPairsCUDA.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ template <typename scalar_t> __global__ void forward_kernel(
const Accessor<scalar_t, 2> positions,
const scalar_t cutoff2,
const bool store_all_pairs,
const bool use_periodic,
Accessor<int32_t, 1> i_curr_pair,
Accessor<int32_t, 2> neighbors,
Accessor<scalar_t, 2> deltas,
Accessor<scalar_t, 1> distances
Accessor<scalar_t, 1> distances,
Accessor<scalar_t, 2> box_vectors
) {
const int32_t index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= num_all_pairs) return;
Expand All @@ -43,9 +45,20 @@ template <typename scalar_t> __global__ void forward_kernel(
if (row * (row - 1) > 2 * index) row--;
const int32_t column = index - row * (row - 1) / 2;

const scalar_t delta_x = positions[row][0] - positions[column][0];
const scalar_t delta_y = positions[row][1] - positions[column][1];
const scalar_t delta_z = positions[row][2] - positions[column][2];
scalar_t delta_x = positions[row][0] - positions[column][0];
scalar_t delta_y = positions[row][1] - positions[column][1];
scalar_t delta_z = positions[row][2] - positions[column][2];
if (use_periodic) {
scalar_t scale3 = round(delta_z/box_vectors[2][2]);
delta_x -= scale3*box_vectors[2][0];
delta_y -= scale3*box_vectors[2][1];
delta_z -= scale3*box_vectors[2][2];
scalar_t scale2 = round(delta_y/box_vectors[1][1]);
delta_x -= scale2*box_vectors[1][0];
delta_y -= scale2*box_vectors[1][1];
scalar_t scale1 = round(delta_x/box_vectors[0][0]);
delta_x -= scale1*box_vectors[0][0];
}
const scalar_t distance2 = delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;

if (distance2 > cutoff2) return;
Expand Down Expand Up @@ -89,7 +102,8 @@ public:
static tensor_list forward(AutogradContext* ctx,
const Tensor& positions,
const Scalar& cutoff,
const Scalar& max_num_neighbors) {
const Scalar& max_num_neighbors,
const Tensor& box_vectors) {

TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0");
Expand All @@ -100,6 +114,12 @@ public:
TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1,
"Expected \"max_num_neighbors\" to be positive or equal to -1");

const bool use_periodic = (box_vectors.size(0) != 0);
if (use_periodic) {
TORCH_CHECK(box_vectors.dim() == 2, "Expected \"box_vectors\" to have two dimensions");
TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, "Expected \"box_vectors\" to have shape (3, 3)");
}

// Decide the algorithm
const bool store_all_pairs = max_num_neighbors_ == -1;
const int num_atoms = positions.size(0);
Expand All @@ -125,10 +145,12 @@ public:
get_accessor<scalar_t, 2>(positions),
cutoff_ * cutoff_,
store_all_pairs,
use_periodic,
get_accessor<int32_t, 1>(i_curr_pair),
get_accessor<int32_t, 2>(neighbors),
get_accessor<scalar_t, 2>(deltas),
get_accessor<scalar_t, 1>(distances));
get_accessor<scalar_t, 1>(distances),
get_accessor<scalar_t, 2>(box_vectors));
});

ctx->save_for_backward({neighbors, deltas, distances});
Expand Down Expand Up @@ -165,14 +187,14 @@ public:
get_accessor<scalar_t, 2>(grad_positions));
});

return {grad_positions, Tensor(), Tensor()};
return {grad_positions, Tensor(), Tensor(), Tensor()};
}
};

TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) {
m.impl("getNeighborPairs",
[](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_neighbors){
const tensor_list results = Autograd::apply(positions, cutoff, max_num_neighbors);
[](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_neighbors, const Tensor& box_vectors){
const tensor_list results = Autograd::apply(positions, cutoff, max_num_neighbors, box_vectors);
return make_tuple(results[0], results[1], results[2]);
});
}
2 changes: 1 addition & 1 deletion src/pytorch/neighbors/neighbors.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <torch/extension.h>

TORCH_LIBRARY(neighbors, m) {
m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors) -> (Tensor neighbors, Tensor deltas, Tensor distances)");
m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors, Tensor box_vectors) -> (Tensor neighbors, Tensor deltas, Tensor distances)");
}

0 comments on commit 8b2d427

Please sign in to comment.