diff --git a/src/pytorch/neighbors/TestNeighbors.py b/src/pytorch/neighbors/TestNeighbors.py index 5d4c4512..7d830812 100644 --- a/src/pytorch/neighbors/TestNeighbors.py +++ b/src/pytorch/neighbors/TestNeighbors.py @@ -56,10 +56,10 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs): # Find the number of neighbors num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances))) - max_num_neighbors = -1 if all_pairs else max(int(np.ceil(num_neighbors / num_atoms)), 1) + max_num_pairs = -1 if all_pairs else max(int(num_neighbors), 1) # Compute results - neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors) + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=max_num_pairs) # Check device assert neighbors.device == positions.device @@ -83,7 +83,7 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs): neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances) # Resize the reference - ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors) + ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, max_num_pairs) assert np.all(ref_neighbors == neighbors) assert np.allclose(ref_deltas, deltas, equal_nan=True) @@ -94,7 +94,7 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs): @pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100, 1000]) @pytest.mark.parametrize('grad', ['deltas', 'distances', 'combined']) def test_neighbor_grads(device, dtype, num_atoms, grad): - + if not pt.cuda.is_available() and device == 'cuda': pytest.skip('No GPU') @@ -114,8 +114,8 @@ def test_neighbor_grads(device, dtype, num_atoms, grad): # Compute values using NNPOps positions.requires_grad_(True) print(positions) - neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff) - + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff) + assert pt.all(neighbors > -1) assert pt.all(neighbors == ref_neighbors) assert pt.allclose(deltas, ref_deltas) @@ -133,28 +133,78 @@ def test_neighbor_grads(device, dtype, num_atoms, grad): (deltas.sum() + distances.sum()).backward() else: raise ValueError('grad') - + if dtype == pt.float32: assert pt.allclose(ref_positions.grad, positions.grad, atol=1e-3, rtol=1e-3) else: assert pt.allclose(ref_positions.grad, positions.grad, atol=1e-8, rtol=1e-5) -# The following test is only run on the CPU. Running it on the GPU triggers a -# CUDA assertion, which causes all tests run after it to fail. - -@pytest.mark.parametrize('device', ['cpu']) +@pytest.mark.parametrize('device', ['cpu', 'cuda']) @pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) def test_too_many_neighbors(device, dtype): - if not pt.cuda.is_available() and device == 'cuda': pytest.skip('No GPU') - # 4 points result into 6 pairs, but there is a storage just for 4. + positions = pt.zeros((4, 3,), device=device, dtype=dtype) with pytest.raises(RuntimeError): - positions = pt.zeros((4, 3,), device=device, dtype=dtype) - getNeighborPairs(positions, cutoff=1, max_num_neighbors=1) - pt.cuda.synchronize() + # checkErrors = True will raise due to exceeding neighbours + getNeighborPairs(positions, cutoff=1, max_num_pairs=1, check_errors=True) + + # checkErrors = False will never throw due to exceeding neighbours. In addition, the call will be compatible with CUDA graphs + neighbors, deltas, distances, number_found_pairs = getNeighborPairs(positions, cutoff=1, max_num_pairs=1, check_errors=False) + assert number_found_pairs == 6 + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) +def test_max_pairs_means_total(device, dtype): + if not pt.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + # 4 points result into 6 pairs. + positions = pt.zeros((4, 3,), device=device, dtype=dtype) + with pytest.raises(RuntimeError): + # checkErrors = True should raise due to exceeding neighbours + getNeighborPairs(positions, cutoff=1, max_num_pairs=5, check_errors=True) + getNeighborPairs(positions, cutoff=1, max_num_pairs=6, check_errors=True) + +def test_is_cuda_graph_compatible(): + if not pt.cuda.is_available(): + pytest.skip('No GPU') + device = 'cuda' + dtype = pt.float32 + num_atoms = 100 + # Generate random positions + positions = 10 * pt.randn((num_atoms, 3), device=device, dtype=dtype) + cutoff = 5 + # Get neighbor pairs + ref_neighbors = np.vstack(np.tril_indices(num_atoms, -1)) + ref_positions = positions.cpu().numpy() + ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]] + ref_distances = np.linalg.norm(ref_deltas, axis=1) + + # Filter the neighbor pairs + mask = ref_distances > cutoff + ref_neighbors[:, mask] = -1 + ref_deltas[mask, :] = np.nan + ref_distances[mask] = np.nan + + # Find the number of neighbors + num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances))) + + graph = pt.cuda.CUDAGraph() + s = pt.cuda.Stream() + s.wait_stream(pt.cuda.current_stream()) + with pt.cuda.stream(s): + for _ in range(3): + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=num_neighbors+1) + pt.cuda.synchronize() + + with pt.cuda.graph(graph): + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=num_neighbors+1) + + graph.replay() + pt.cuda.synchronize() + @pytest.mark.parametrize('device', ['cpu', 'cuda']) @pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) @@ -187,10 +237,10 @@ def test_periodic_neighbors(device, dtype): # Find the number of neighbors num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances))) - max_num_neighbors = max(int(np.ceil(num_neighbors / num_atoms)), 1) + max_num_pairs = max(int(num_neighbors), 1) # Compute results - neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors, box_vectors=box_vectors) + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=max_num_pairs, box_vectors=box_vectors) # Check device assert neighbors.device == positions.device @@ -213,7 +263,7 @@ def test_periodic_neighbors(device, dtype): neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances) # Resize the reference - ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors) + ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, max_num_pairs) assert np.all(ref_neighbors == neighbors) assert np.allclose(ref_deltas, deltas, equal_nan=True) @@ -228,7 +278,7 @@ class ForceModule(pt.nn.Module): def forward(self, positions): - neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0) + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=1.0) mask = pt.isnan(distances) distances = distances[~mask] return pt.sum(distances**2) diff --git a/src/pytorch/neighbors/getNeighborPairs.py b/src/pytorch/neighbors/getNeighborPairs.py index b6f30072..12a4b03c 100644 --- a/src/pytorch/neighbors/getNeighborPairs.py +++ b/src/pytorch/neighbors/getNeighborPairs.py @@ -2,19 +2,23 @@ from typing import Optional, Tuple -def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: - ''' - Returns indices and distances of atom pairs within a given cutoff distance. - - If `max_num_neighbors == -1` (default), all the atom pairs are returned, +def getNeighborPairs( + positions: Tensor, + cutoff: float, + max_num_pairs: int = -1, + box_vectors: Optional[Tensor] = None, + check_errors: bool = False +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Returns indices and distances of atom pairs within a given cutoff distance. + + If `max_num_pairs == -1` (default), all the atom pairs are returned, i.e. `num_pairs = num_atoms * (num_atoms + 1) / 2`. This is intended for the small molecules, where almost all the atoms are within the cutoff distance of each other. - If `max_num_neighbors > 0`, a fixed number of the atom pair are returned, - i.e. `num_pairs = num_atoms * max_num_neighbors`. This is indeded for large - molecule, where most of the atoms are beyond the cutoff distance of each - other. + If `max_num_pairs > 0`, a fixed number of the atom pairs are + returned. This is intended for large molecule, where most of the + atoms are beyond the cutoff distance of each other. This function optionally supports periodic boundary conditions with arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy @@ -37,13 +41,20 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = data type has to be`torch.float32` or `torch.float64`. cutoff: float Maximum distance between atom pairs. - max_num_neighbors: int, optional - Maximum number of neighbors per atom. If set to `-1` (default), + max_num_pairs: int, optional + Maximum number of pairs (total number of neighbors). If set to `-1` (default), all possible combinations of atom pairs are included. box_vectors: `torch.Tensor`, optional The vectors defining the periodic box. This must have shape `(3, 3)`, where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. If this is omitted, periodic boundary conditions are not applied. + check_errors: bool, optional + If True, a RuntimeError is raised if more than max_num_pairs pairs are found. + The error checking requires synchronization, which adds cost and makes this function + incompatible with CUDA graphs. If this argument is False, no error checking is performed. + This makes it faster and compatible with CUDA graphs, but it is your responsibility + to check the return value for number_found_pairs to make sure that no neighbors were missed. + Default: False Returns ------- @@ -63,17 +74,27 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = If an atom pair is separated by a larger distance than the cutoff, the distance is set to `NaN`. + number_found_pairs: `torch.Tensor` + Contains the total number of pairs found. Be aware that if + check_errors is False, this might be larger than + max_num_pairs. In that case, the output tensors contain + only a subset of the pairs that were found, and the others are + omitted. Which pairs get omitted may vary between invocations. + Exceptions ---------- - If `max_num_neighbors > 0` and too small, `RuntimeError` is raised. + If `max_num_pairs > 0` and too small, `RuntimeError` is raised if check_errors=True. Note ---- - The operation is compatible with CUDA Grahps, i.e. the shapes of the output - tensors are independed of the values of input tensors. + The operation can be compatible with CUDA Graphs: the shapes of + the output tensors are independent of the values of input tensors, + and no synchronization is performed. + For this to be true, check_errors must be False. The CUDA implementation returns the atom pairs in non-determinist order, - if `max_num_neighbors > 0`. + if `max_num_pairs > 0`. + Examples -------- @@ -88,7 +109,7 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = tensor([[1., 0., 0.], [2., 0., 0.], [1., 0., 0.]]), - tensor([1., 2., 1.])) + tensor([1., 2., 1.]), tensor([3], dtype=torch.int32)) >>> getNeighborPairs(positions, cutoff=1.5) # doctest: +NORMALIZE_WHITESPACE (tensor([[ 1, -1, 2], @@ -96,31 +117,31 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = tensor([[1., 0., 0.], [nan, nan, nan], [1., 0., 0.]]), - tensor([1., nan, 1.])) + tensor([1., nan, 1.]), tensor([3], dtype=torch.int32)) - >>> getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=2) # doctest: +NORMALIZE_WHITESPACE + >>> getNeighborPairs(positions, cutoff=3.0, max_num_pairs=6) # doctest: +NORMALIZE_WHITESPACE (tensor([[ 1, 2, 2, -1, -1, -1], - [ 0, 0, 1, -1, -1, -1]], dtype=torch.int32), - tensor([[1., 0., 0.], - [2., 0., 0.], - [1., 0., 0.], - [nan, nan, nan], - [nan, nan, nan], - [nan, nan, nan]]), - tensor([1., 2., 1., nan, nan, nan])) - - >>> getNeighborPairs(positions, cutoff=1.5, max_num_neighbors=2) # doctest: +NORMALIZE_WHITESPACE + [ 0, 0, 1, -1, -1, -1]], dtype=torch.int32), tensor([[1., 0., 0.], + [2., 0., 0.], + [1., 0., 0.], + [nan, nan, nan], + [nan, nan, nan], + [nan, nan, nan]]), tensor([1., 2., 1., nan, nan, nan]), tensor([6], dtype=torch.int32)) + + >>> getNeighborPairs(positions, cutoff=1.5, max_num_pairs=6) # doctest: +NORMALIZE_WHITESPACE (tensor([[ 1, 2, -1, -1, -1, -1], - [ 0, 1, -1, -1, -1, -1]], dtype=torch.int32), - tensor([[1., 0., 0.], - [1., 0., 0.], - [nan, nan, nan], - [nan, nan, nan], - [nan, nan, nan], - [nan, nan, nan]]), - tensor([1., 1., nan, nan, nan, nan])) - ''' + [ 0, 1, -1, -1, -1, -1]], dtype=torch.int32), tensor([[1., 0., 0.], + [1., 0., 0.], + [nan, nan, nan], + [nan, nan, nan], + [nan, nan, nan], + [nan, nan, nan]]), tensor([1., 1., nan, nan, nan, nan]), tensor([6], dtype=torch.int32)) + + """ if box_vectors is None: box_vectors = empty((0, 0), device=positions.device, dtype=positions.dtype) - return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors, box_vectors) \ No newline at end of file + neighbors, deltas, distances, number_found_pairs = ops.neighbors.getNeighborPairs( + positions, cutoff, max_num_pairs, box_vectors, check_errors + ) + return neighbors, deltas, distances, number_found_pairs diff --git a/src/pytorch/neighbors/getNeighborPairsCPU.cpp b/src/pytorch/neighbors/getNeighborPairsCPU.cpp index 9df95b85..d63e24b2 100644 --- a/src/pytorch/neighbors/getNeighborPairsCPU.cpp +++ b/src/pytorch/neighbors/getNeighborPairsCPU.cpp @@ -16,10 +16,11 @@ using torch::Tensor; using torch::outer; using torch::round; -static tuple forward(const Tensor& positions, - const Scalar& cutoff, - const Scalar& max_num_neighbors, - const Tensor& box_vectors) { +static tuple forward(const Tensor& positions, + const Scalar& cutoff, + const Scalar& max_num_pairs, + const Tensor& box_vectors, + bool checkErrors) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); @@ -47,9 +48,9 @@ static tuple forward(const Tensor& positions, TORCH_CHECK(v[1][1] >= 2*v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); } - const int max_num_neighbors_ = max_num_neighbors.to(); - TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1, - "Expected \"max_num_neighbors\" to be positive or equal to -1"); + const int max_num_pairs_ = max_num_pairs.to(); + TORCH_CHECK(max_num_pairs_ > 0 || max_num_pairs_ == -1, + "Expected \"max_num_pairs\" to be positive or equal to -1"); const int num_atoms = positions.size(0); const int num_pairs = num_atoms * (num_atoms - 1) / 2; @@ -68,7 +69,7 @@ static tuple forward(const Tensor& positions, } Tensor distances = frobenius_norm(deltas, 1); - if (max_num_neighbors_ == -1) { + if (max_num_pairs_ == -1) { const Tensor mask = distances > cutoff; neighbors.index_put_({Slice(), mask}, -1); deltas = deltas.clone(); // Break an autograd loop @@ -82,20 +83,26 @@ static tuple forward(const Tensor& positions, deltas = deltas.index({mask, Slice()}); distances = distances.index({mask}); - const int num_pad = num_atoms * max_num_neighbors_ - distances.size(0); - TORCH_CHECK(num_pad >= 0, - "The maximum number of pairs has been exceed! Increase \"max_num_neighbors\""); - + const int num_pad = max_num_pairs_ - distances.size(0); + if (checkErrors) { + TORCH_CHECK(num_pad >= 0, + "The maximum number of pairs has been exceed! Increase \"max_num_pairs\""); + } if (num_pad > 0) { neighbors = hstack({neighbors, full({2, num_pad}, -1, neighbors.options())}); deltas = vstack({deltas, full({num_pad, 3}, NAN, deltas.options())}); distances = hstack({distances, full({num_pad}, NAN, distances.options())}); } } - - return {neighbors, deltas, distances}; + Tensor num_pairs_found = torch::empty(1, indices.options().dtype(kInt32)); + num_pairs_found[0] = distances.size(0); + return {neighbors, deltas, distances, num_pairs_found}; } TORCH_LIBRARY_IMPL(neighbors, CPU, m) { - m.impl("getNeighborPairs", &forward); -} \ No newline at end of file + m.impl("getNeighborPairs", + [](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_pairs, + const Tensor& box_vectors, const bool &checkErrors){ + return forward(positions, cutoff, max_num_pairs, box_vectors, checkErrors); + }); +} diff --git a/src/pytorch/neighbors/getNeighborPairsCUDA.cu b/src/pytorch/neighbors/getNeighborPairsCUDA.cu index 2d820a4a..23540af6 100644 --- a/src/pytorch/neighbors/getNeighborPairsCUDA.cu +++ b/src/pytorch/neighbors/getNeighborPairsCUDA.cu @@ -1,5 +1,7 @@ #include #include +#include +#include #include #include #include @@ -64,14 +66,15 @@ template __global__ void forward_kernel( if (distance2 > cutoff2) return; const int32_t i_pair = store_all_pairs ? index : atomicAdd(&i_curr_pair[0], 1); - assert(i_pair < neighbors.size(1)); - - neighbors[0][i_pair] = row; - neighbors[1][i_pair] = column; - deltas[i_pair][0] = delta_x; - deltas[i_pair][1] = delta_y; - deltas[i_pair][2] = delta_z; - distances[i_pair] = sqrt_(distance2); + //We handle too many neighbors outside of the kernel + if (i_pair < neighbors.size(1)) { + neighbors[0][i_pair] = row; + neighbors[1][i_pair] = column; + deltas[i_pair][0] = delta_x; + deltas[i_pair][1] = delta_y; + deltas[i_pair][2] = delta_z; + distances[i_pair] = sqrt_(distance2); + } } template __global__ void backward_kernel( @@ -102,17 +105,17 @@ public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Scalar& cutoff, - const Scalar& max_num_neighbors, - const Tensor& box_vectors) { - + const Scalar& max_num_pairs, + const Tensor& box_vectors, + bool checkErrors) { + const auto stream = getCurrentCUDAStream(positions.get_device()); + const CUDAStreamGuard guard(stream); TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); - - const int max_num_neighbors_ = max_num_neighbors.to(); - TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1, - "Expected \"max_num_neighbors\" to be positive or equal to -1"); + TORCH_CHECK(max_num_pairs.toInt() > 0 || max_num_pairs.toInt() == -1, + "Expected \"max_num_pairs\" to be positive or equal to -1"); const bool use_periodic = (box_vectors.size(0) != 0); if (use_periodic) { @@ -121,25 +124,23 @@ public: } // Decide the algorithm - const bool store_all_pairs = max_num_neighbors_ == -1; + const bool store_all_pairs = max_num_pairs.toInt() == -1; const int num_atoms = positions.size(0); const int num_all_pairs = num_atoms * (num_atoms - 1) / 2; - const int num_pairs = store_all_pairs ? num_all_pairs : num_atoms * max_num_neighbors_; + const int max_num_pairs_ = store_all_pairs ? num_all_pairs : (max_num_pairs.toInt()); const int num_threads = 128; const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); - const auto stream = getCurrentCUDAStream(positions.get_device()); const TensorOptions options = positions.options(); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); - const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); - const Tensor deltas = full({num_pairs, 3}, NAN, options); - const Tensor distances = full(num_pairs, NAN, options); + const Tensor neighbors = full({2, max_num_pairs_}, -1, options.dtype(kInt32)); + const Tensor deltas = full({max_num_pairs_, 3}, NAN, options); + const Tensor distances = full(max_num_pairs_, NAN, options); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "getNeighborPairs::forward", [&]() { - const CUDAStreamGuard guard(stream); - const scalar_t cutoff_ = cutoff.to(); - TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); + const scalar_t cutoff_ = cutoff.to(); + TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); forward_kernel<<>>( num_all_pairs, get_accessor(positions), @@ -152,11 +153,14 @@ public: get_accessor(distances), get_accessor(box_vectors)); }); - + // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA graphs + if (checkErrors) { + int num_found_pairs = i_curr_pair.item(); + TORCH_CHECK(num_found_pairs <= max_num_pairs_, "Too many neighbor pairs found. Maximum is " + std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); + } ctx->save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; - - return {neighbors, deltas, distances}; + return {neighbors, deltas, distances, i_curr_pair}; } static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { @@ -187,14 +191,16 @@ public: get_accessor(grad_positions)); }); - return {grad_positions, Tensor(), Tensor(), Tensor()}; + return {grad_positions, Tensor(), Tensor(), Tensor(), Tensor(), Tensor()}; } }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("getNeighborPairs", - [](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_neighbors, const Tensor& box_vectors){ - const tensor_list results = Autograd::apply(positions, cutoff, max_num_neighbors, box_vectors); - return make_tuple(results[0], results[1], results[2]); - }); -} \ No newline at end of file + m.impl("getNeighborPairs", + [](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_pairs, + const Tensor& box_vectors, const bool &checkErrors){ + const tensor_list results = Autograd::apply(positions, cutoff, max_num_pairs, + box_vectors, checkErrors); + return make_tuple(results[0], results[1], results[2], results[3]); + }); +} diff --git a/src/pytorch/neighbors/neighbors.cpp b/src/pytorch/neighbors/neighbors.cpp index d8dd5c5b..e5911907 100644 --- a/src/pytorch/neighbors/neighbors.cpp +++ b/src/pytorch/neighbors/neighbors.cpp @@ -1,5 +1,5 @@ #include TORCH_LIBRARY(neighbors, m) { - m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors, Tensor box_vectors) -> (Tensor neighbors, Tensor deltas, Tensor distances)"); -} \ No newline at end of file + m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors, Tensor box_vectors, bool checkErrors) -> (Tensor neighbors, Tensor deltas, Tensor distances, Tensor num_pairs)"); +}