Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error checking to CUDA version of getNeighborPairs #80

Merged
merged 46 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
657e748
Add error checking to CUDA version of getNeighborPairs
RaulPPelaez Jan 16, 2023
5339556
Add a new bool optional parameter to getNeighborPairs, setting it to …
RaulPPelaez Jan 19, 2023
8c13952
Remove unnecessarily static variable
RaulPPelaez Jan 19, 2023
928f123
Change the error handling of getNeighborPairs.
RaulPPelaez Mar 3, 2023
477e9cd
Make getNeighborPairs CUDA-graph compatible, add test for it
RaulPPelaez Mar 6, 2023
822c691
Remove incorrect comment
RaulPPelaez Mar 6, 2023
e46fe2d
Change not by !
RaulPPelaez Mar 6, 2023
e80cd5e
Move all torch.ops.load calls to the __init__.py scripts
RaulPPelaez Mar 7, 2023
2a7cd3a
Change how the location of libNNPOpsPyTorch.so is found at __init__ s…
RaulPPelaez Mar 9, 2023
e4df3cf
Remove spurious lines in CMakeLists.txt
RaulPPelaez Mar 9, 2023
d6eb763
Update again how libNNPOpsPyTorch.so is found in __init__.py
RaulPPelaez Mar 9, 2023
ca821c3
Remove redundant torch load
RaulPPelaez Mar 9, 2023
46ddf3d
Merge remote-tracking branch 'origin/master'
RaulPPelaez Mar 9, 2023
676f83b
Skip CUDA graph test if no GPU is available
RaulPPelaez Mar 9, 2023
d05656b
Remove incorrect path in __init__
RaulPPelaez Mar 14, 2023
4fb4b2e
Use relative path to load NNPOps library in __init__.py
RaulPPelaez Mar 15, 2023
bf56580
Copy test scripts to build directory, run them there
RaulPPelaez Mar 15, 2023
947f4d8
Remove unnecessary import
RaulPPelaez Mar 15, 2023
a258786
Merge branch 'fix_torch_load' into cuda_graphs
RaulPPelaez Mar 15, 2023
ae82f90
Some fixes for CUDA graph support in getNEighborPairs
RaulPPelaez Mar 17, 2023
4625684
Reverse logic for check_errors in getNeighborPairs.py
RaulPPelaez Mar 22, 2023
1376a5e
Merge branch 'cuda_graphs'
RaulPPelaez Mar 22, 2023
d711a3c
Reverse check_errors flag in the rest of the getNeighborPair-related …
RaulPPelaez Mar 22, 2023
400ceed
Clarify documentation on the error raised by getNeighborPairs
RaulPPelaez Mar 22, 2023
c36243b
Always return the number of found pairs in getNeighborPairs
RaulPPelaez Mar 22, 2023
5552a89
Merge remote-tracking branch 'origin/master'
RaulPPelaez Mar 22, 2023
8da1c5d
Revert "Always return the number of found pairs in getNeighborPairs"
RaulPPelaez Mar 22, 2023
829ee5b
Fix check_error interpretation in getNeighborPairs.py
RaulPPelaez Mar 22, 2023
73c3e58
Add return number of pairs functionality again
RaulPPelaez Mar 23, 2023
c2210f3
Update tests with new getNeighborPairs interface
RaulPPelaez Mar 23, 2023
fba2b46
Fix type decorator preventing jit.script from working on getNeighborP…
RaulPPelaez Mar 23, 2023
562d522
Remove sync_exceptions flag, simplifying the behavior and relation
RaulPPelaez Mar 29, 2023
751ee12
Remove unused function
RaulPPelaez Mar 29, 2023
ad8bbaf
Remove unnecessary synchronization in test
RaulPPelaez Mar 31, 2023
6593331
Clarify documentation of check_errors
RaulPPelaez Mar 31, 2023
75608cf
Clarify documentation of number_found_pairs
RaulPPelaez Mar 31, 2023
4c624e5
Clarify documentation of CUDA graph functionality
RaulPPelaez Mar 31, 2023
355860f
Remove obsolete comment
RaulPPelaez Mar 31, 2023
5ccc98f
Fix formatting
RaulPPelaez Mar 31, 2023
bc78d15
Fix formatting
RaulPPelaez Mar 31, 2023
e1a965a
Update documentation
RaulPPelaez Mar 31, 2023
130b13b
Change the (misleading) num_pairs variable name to max_num_pairs.
RaulPPelaez Mar 31, 2023
2d8d02b
Add test that checks if the max_num_neighbors per particle is
RaulPPelaez Mar 31, 2023
6a67cad
Merge remote-tracking branch 'origin/master'
RaulPPelaez Mar 31, 2023
90a584e
Change the meaning and name from max_num_neighbors (maximum number of…
RaulPPelaez Apr 4, 2023
c97a6f2
Fix typo in comment
RaulPPelaez Apr 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 70 additions & 20 deletions src/pytorch/neighbors/TestNeighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs):

# Find the number of neighbors
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances)))
max_num_neighbors = -1 if all_pairs else max(int(np.ceil(num_neighbors / num_atoms)), 1)
max_num_pairs = -1 if all_pairs else max(int(num_neighbors), 1)

# Compute results
neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors)
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=max_num_pairs)

# Check device
assert neighbors.device == positions.device
Expand All @@ -83,7 +83,7 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs):
neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances)

# Resize the reference
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors)
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, max_num_pairs)

assert np.all(ref_neighbors == neighbors)
assert np.allclose(ref_deltas, deltas, equal_nan=True)
Expand All @@ -94,7 +94,7 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs):
@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100, 1000])
@pytest.mark.parametrize('grad', ['deltas', 'distances', 'combined'])
def test_neighbor_grads(device, dtype, num_atoms, grad):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

Expand All @@ -114,8 +114,8 @@ def test_neighbor_grads(device, dtype, num_atoms, grad):
# Compute values using NNPOps
positions.requires_grad_(True)
print(positions)
neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff)
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff)

assert pt.all(neighbors > -1)
assert pt.all(neighbors == ref_neighbors)
assert pt.allclose(deltas, ref_deltas)
Expand All @@ -133,28 +133,78 @@ def test_neighbor_grads(device, dtype, num_atoms, grad):
(deltas.sum() + distances.sum()).backward()
else:
raise ValueError('grad')

if dtype == pt.float32:
assert pt.allclose(ref_positions.grad, positions.grad, atol=1e-3, rtol=1e-3)
else:
assert pt.allclose(ref_positions.grad, positions.grad, atol=1e-8, rtol=1e-5)


# The following test is only run on the CPU. Running it on the GPU triggers a
# CUDA assertion, which causes all tests run after it to fail.

@pytest.mark.parametrize('device', ['cpu'])
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_too_many_neighbors(device, dtype):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

# 4 points result into 6 pairs, but there is a storage just for 4.
positions = pt.zeros((4, 3,), device=device, dtype=dtype)
with pytest.raises(RuntimeError):
positions = pt.zeros((4, 3,), device=device, dtype=dtype)
getNeighborPairs(positions, cutoff=1, max_num_neighbors=1)
pt.cuda.synchronize()
# checkErrors = True will raise due to exceeding neighbours
getNeighborPairs(positions, cutoff=1, max_num_pairs=1, check_errors=True)

# checkErrors = False will never throw due to exceeding neighbours. In addition, the call will be compatible with CUDA graphs
neighbors, deltas, distances, number_found_pairs = getNeighborPairs(positions, cutoff=1, max_num_pairs=1, check_errors=False)
assert number_found_pairs == 6

@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_max_pairs_means_total(device, dtype):
if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')
# 4 points result into 6 pairs.
positions = pt.zeros((4, 3,), device=device, dtype=dtype)
with pytest.raises(RuntimeError):
# checkErrors = True should raise due to exceeding neighbours
getNeighborPairs(positions, cutoff=1, max_num_pairs=5, check_errors=True)
getNeighborPairs(positions, cutoff=1, max_num_pairs=6, check_errors=True)

def test_is_cuda_graph_compatible():
if not pt.cuda.is_available():
pytest.skip('No GPU')
device = 'cuda'
dtype = pt.float32
num_atoms = 100
# Generate random positions
positions = 10 * pt.randn((num_atoms, 3), device=device, dtype=dtype)
cutoff = 5
# Get neighbor pairs
ref_neighbors = np.vstack(np.tril_indices(num_atoms, -1))
ref_positions = positions.cpu().numpy()
ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]]
ref_distances = np.linalg.norm(ref_deltas, axis=1)

# Filter the neighbor pairs
mask = ref_distances > cutoff
ref_neighbors[:, mask] = -1
ref_deltas[mask, :] = np.nan
ref_distances[mask] = np.nan

# Find the number of neighbors
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances)))

graph = pt.cuda.CUDAGraph()
s = pt.cuda.Stream()
s.wait_stream(pt.cuda.current_stream())
with pt.cuda.stream(s):
for _ in range(3):
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=num_neighbors+1)
pt.cuda.synchronize()

with pt.cuda.graph(graph):
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=num_neighbors+1)

graph.replay()
pt.cuda.synchronize()


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
Expand Down Expand Up @@ -187,10 +237,10 @@ def test_periodic_neighbors(device, dtype):

# Find the number of neighbors
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances)))
max_num_neighbors = max(int(np.ceil(num_neighbors / num_atoms)), 1)
max_num_pairs = max(int(num_neighbors), 1)

# Compute results
neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors, box_vectors=box_vectors)
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=max_num_pairs, box_vectors=box_vectors)

# Check device
assert neighbors.device == positions.device
Expand All @@ -213,7 +263,7 @@ def test_periodic_neighbors(device, dtype):
neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances)

# Resize the reference
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors)
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, max_num_pairs)

assert np.all(ref_neighbors == neighbors)
assert np.allclose(ref_deltas, deltas, equal_nan=True)
Expand All @@ -228,7 +278,7 @@ class ForceModule(pt.nn.Module):

def forward(self, positions):

neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0)
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=1.0)
mask = pt.isnan(distances)
distances = distances[~mask]
return pt.sum(distances**2)
Expand Down
97 changes: 59 additions & 38 deletions src/pytorch/neighbors/getNeighborPairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
from typing import Optional, Tuple


def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
'''
Returns indices and distances of atom pairs within a given cutoff distance.

If `max_num_neighbors == -1` (default), all the atom pairs are returned,
def getNeighborPairs(
positions: Tensor,
cutoff: float,
max_num_pairs: int = -1,
box_vectors: Optional[Tensor] = None,
check_errors: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Returns indices and distances of atom pairs within a given cutoff distance.

If `max_num_pairs == -1` (default), all the atom pairs are returned,
i.e. `num_pairs = num_atoms * (num_atoms + 1) / 2`. This is intended for
the small molecules, where almost all the atoms are within the cutoff
distance of each other.

If `max_num_neighbors > 0`, a fixed number of the atom pair are returned,
i.e. `num_pairs = num_atoms * max_num_neighbors`. This is indeded for large
molecule, where most of the atoms are beyond the cutoff distance of each
other.
If `max_num_pairs > 0`, a fixed number of the atom pairs are
returned. This is intended for large molecule, where most of the
atoms are beyond the cutoff distance of each other.

This function optionally supports periodic boundary conditions with
arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy
Expand All @@ -37,13 +41,20 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
data type has to be`torch.float32` or `torch.float64`.
cutoff: float
Maximum distance between atom pairs.
max_num_neighbors: int, optional
Maximum number of neighbors per atom. If set to `-1` (default),
max_num_pairs: int, optional
Maximum number of pairs (total number of neighbors). If set to `-1` (default),
all possible combinations of atom pairs are included.
box_vectors: `torch.Tensor`, optional
The vectors defining the periodic box. This must have shape `(3, 3)`,
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
check_errors: bool, optional
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved
If True, a RuntimeError is raised if more than max_num_pairs pairs are found.
The error checking requires synchronization, which adds cost and makes this function
incompatible with CUDA graphs. If this argument is False, no error checking is performed.
This makes it faster and compatible with CUDA graphs, but it is your responsibility
to check the return value for number_found_pairs to make sure that no neighbors were missed.
Default: False

Returns
-------
Expand All @@ -63,17 +74,27 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
If an atom pair is separated by a larger distance than the cutoff,
the distance is set to `NaN`.

number_found_pairs: `torch.Tensor`
Contains the total number of pairs found. Be aware that if
check_errors is False, this might be larger than
max_num_pairs. In that case, the output tensors contain
raimis marked this conversation as resolved.
Show resolved Hide resolved
only a subset of the pairs that were found, and the others are
omitted. Which pairs get omitted may vary between invocations.

Exceptions
----------
If `max_num_neighbors > 0` and too small, `RuntimeError` is raised.
If `max_num_pairs > 0` and too small, `RuntimeError` is raised if check_errors=True.

Note
----
The operation is compatible with CUDA Grahps, i.e. the shapes of the output
tensors are independed of the values of input tensors.
The operation can be compatible with CUDA Graphs: the shapes of
the output tensors are independent of the values of input tensors,
and no synchronization is performed.
For this to be true, check_errors must be False.

The CUDA implementation returns the atom pairs in non-determinist order,
if `max_num_neighbors > 0`.
if `max_num_pairs > 0`.


Examples
--------
Expand All @@ -88,39 +109,39 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
tensor([[1., 0., 0.],
[2., 0., 0.],
[1., 0., 0.]]),
tensor([1., 2., 1.]))
tensor([1., 2., 1.]), tensor([3], dtype=torch.int32))

>>> getNeighborPairs(positions, cutoff=1.5) # doctest: +NORMALIZE_WHITESPACE
(tensor([[ 1, -1, 2],
[ 0, -1, 1]], dtype=torch.int32),
tensor([[1., 0., 0.],
[nan, nan, nan],
[1., 0., 0.]]),
tensor([1., nan, 1.]))
tensor([1., nan, 1.]), tensor([3], dtype=torch.int32))

>>> getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=2) # doctest: +NORMALIZE_WHITESPACE
>>> getNeighborPairs(positions, cutoff=3.0, max_num_pairs=6) # doctest: +NORMALIZE_WHITESPACE
(tensor([[ 1, 2, 2, -1, -1, -1],
[ 0, 0, 1, -1, -1, -1]], dtype=torch.int32),
tensor([[1., 0., 0.],
[2., 0., 0.],
[1., 0., 0.],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]),
tensor([1., 2., 1., nan, nan, nan]))

>>> getNeighborPairs(positions, cutoff=1.5, max_num_neighbors=2) # doctest: +NORMALIZE_WHITESPACE
[ 0, 0, 1, -1, -1, -1]], dtype=torch.int32), tensor([[1., 0., 0.],
[2., 0., 0.],
[1., 0., 0.],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]), tensor([1., 2., 1., nan, nan, nan]), tensor([6], dtype=torch.int32))

>>> getNeighborPairs(positions, cutoff=1.5, max_num_pairs=6) # doctest: +NORMALIZE_WHITESPACE
(tensor([[ 1, 2, -1, -1, -1, -1],
[ 0, 1, -1, -1, -1, -1]], dtype=torch.int32),
tensor([[1., 0., 0.],
[1., 0., 0.],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]),
tensor([1., 1., nan, nan, nan, nan]))
'''
[ 0, 1, -1, -1, -1, -1]], dtype=torch.int32), tensor([[1., 0., 0.],
[1., 0., 0.],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]), tensor([1., 1., nan, nan, nan, nan]), tensor([6], dtype=torch.int32))

"""

if box_vectors is None:
box_vectors = empty((0, 0), device=positions.device, dtype=positions.dtype)
return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors, box_vectors)
neighbors, deltas, distances, number_found_pairs = ops.neighbors.getNeighborPairs(
positions, cutoff, max_num_pairs, box_vectors, check_errors
)
return neighbors, deltas, distances, number_found_pairs
39 changes: 23 additions & 16 deletions src/pytorch/neighbors/getNeighborPairsCPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ using torch::Tensor;
using torch::outer;
using torch::round;

static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
const Scalar& cutoff,
const Scalar& max_num_neighbors,
const Tensor& box_vectors) {
static tuple<Tensor, Tensor, Tensor, Tensor> forward(const Tensor& positions,
const Scalar& cutoff,
const Scalar& max_num_pairs,
const Tensor& box_vectors,
bool checkErrors) {

TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0");
Expand Down Expand Up @@ -47,9 +48,9 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
TORCH_CHECK(v[1][1] >= 2*v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]");
}

const int max_num_neighbors_ = max_num_neighbors.to<int>();
TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1,
"Expected \"max_num_neighbors\" to be positive or equal to -1");
const int max_num_pairs_ = max_num_pairs.to<int>();
TORCH_CHECK(max_num_pairs_ > 0 || max_num_pairs_ == -1,
"Expected \"max_num_pairs\" to be positive or equal to -1");

const int num_atoms = positions.size(0);
const int num_pairs = num_atoms * (num_atoms - 1) / 2;
Expand All @@ -68,7 +69,7 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
}
Tensor distances = frobenius_norm(deltas, 1);

if (max_num_neighbors_ == -1) {
if (max_num_pairs_ == -1) {
const Tensor mask = distances > cutoff;
neighbors.index_put_({Slice(), mask}, -1);
deltas = deltas.clone(); // Break an autograd loop
Expand All @@ -82,20 +83,26 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
deltas = deltas.index({mask, Slice()});
distances = distances.index({mask});

const int num_pad = num_atoms * max_num_neighbors_ - distances.size(0);
TORCH_CHECK(num_pad >= 0,
"The maximum number of pairs has been exceed! Increase \"max_num_neighbors\"");

const int num_pad = max_num_pairs_ - distances.size(0);
if (checkErrors) {
TORCH_CHECK(num_pad >= 0,
"The maximum number of pairs has been exceed! Increase \"max_num_pairs\"");
}
if (num_pad > 0) {
neighbors = hstack({neighbors, full({2, num_pad}, -1, neighbors.options())});
deltas = vstack({deltas, full({num_pad, 3}, NAN, deltas.options())});
distances = hstack({distances, full({num_pad}, NAN, distances.options())});
}
}

return {neighbors, deltas, distances};
Tensor num_pairs_found = torch::empty(1, indices.options().dtype(kInt32));
num_pairs_found[0] = distances.size(0);
return {neighbors, deltas, distances, num_pairs_found};
}

TORCH_LIBRARY_IMPL(neighbors, CPU, m) {
m.impl("getNeighborPairs", &forward);
}
m.impl("getNeighborPairs",
[](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_pairs,
const Tensor& box_vectors, const bool &checkErrors){
return forward(positions, cutoff, max_num_pairs, box_vectors, checkErrors);
});
}
Loading