diff --git a/CMakeLists.txt b/CMakeLists.txt index d372bdf..5529ff6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,9 @@ set(SRC_FILES src/ani/CpuANISymmetryFunctions.cpp src/pytorch/neighbors/getNeighborPairsCPU.cpp src/pytorch/neighbors/getNeighborPairsCUDA.cu src/pytorch/neighbors/neighbors.cpp + src/pytorch/pme/pmeCPU.cpp + src/pytorch/pme/pmeCUDA.cu + src/pytorch/pme/pme.cpp src/schnet/CpuCFConv.cpp src/schnet/CudaCFConv.cu) @@ -63,6 +66,7 @@ add_custom_target(copy_test ALL ${CMAKE_SOURCE_DIR}/src/pytorch/Test*.py ${CMAKE_SOURCE_DIR}/src/pytorch/neighbors/Test*.py ${CMAKE_SOURCE_DIR}/src/pytorch/neighbors/getNeighborPairs.py + ${CMAKE_SOURCE_DIR}/src/pytorch/pme/Test*.py ${CMAKE_BINARY_DIR}/test COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/src/pytorch/molecules @@ -91,3 +95,6 @@ install(FILES src/pytorch/__init__.py install(FILES src/pytorch/neighbors/__init__.py src/pytorch/neighbors/getNeighborPairs.py DESTINATION ${Python3_SITEARCH}/${NAME}/neighbors) +install(FILES src/pytorch/pme/__init__.py + src/pytorch/pme/pme.py + DESTINATION ${Python3_SITEARCH}/${NAME}/pme) diff --git a/src/pytorch/pme/TestPme.py b/src/pytorch/pme/TestPme.py new file mode 100644 index 0000000..b0d2a2a --- /dev/null +++ b/src/pytorch/pme/TestPme.py @@ -0,0 +1,318 @@ +import torch +import pytest +import numpy as np +from NNPOps.pme import PME + +class PmeModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pme = PME(14, 15, 16, 5, 4.985823141035867, 138.935, torch.zeros(9, 0, dtype=torch.int32)) + + def forward(self, positions, charges, box_vectors): + edir = self.pme.compute_direct(positions, charges, 0.5, box_vectors) + erecip = self.pme.compute_reciprocal(positions, charges, box_vectors) + return edir+erecip + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +def test_rectangular(device): + """Test PME on a rectangular box.""" + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + pme = PME(14, 15, 16, 5, 4.985823141035867, 138.935, torch.zeros(9, 0, dtype=torch.int32)) + pos = [[0.7713206433, 0.02075194936, 0.6336482349], + [0.7488038825, 0.4985070123, 0.2247966455], + [0.1980628648, 0.7605307122, 0.1691108366], + [0.08833981417, 0.6853598184, 0.9533933462], + [0.003948266328, 0.5121922634, 0.8126209617], + [0.6125260668, 0.7217553174, 0.2918760682], + [0.9177741225, 0.7145757834, 0.542544368], + [0.1421700476, 0.3733407601, 0.6741336151], + [0.4418331744, 0.4340139933, 0.6177669785]] + positions = torch.tensor(pos, dtype=torch.float32, requires_grad=True, device=device) + charges = torch.tensor([(i-4)*0.1 for i in range(9)], dtype=torch.float32, device=device) + box_vectors = torch.tensor([[1, 0, 0], [0, 1.1, 0], [0, 0, 1.2]], dtype=torch.float32, device=device) + + # Compare forces and energies to values computed with OpenMM. + + edirect = pme.compute_direct(positions, charges, 0.5, box_vectors) + assert np.allclose(0.5811535194516182, edirect.detach().cpu().numpy()) + erecip = pme.compute_reciprocal(positions, charges, box_vectors) + assert np.allclose(-90.92361028496651, erecip.detach().cpu().numpy()) + expected_ddirect = [[-0.4068958163, 1.128490567, 0.2531163692], + [8.175477028, -15.20702648, -5.499810219], + [-0.2548360825, 0.003096142784, -0.67370224], + [0.09854402393, 0.5804504156, 1.063418627], + [-0, -0, -0], + [-7.859698296, 14.16478539, 5.236941814], + [0.684042871, -1.312145352, 0.7057141662], + [30.47141075, 6.726415634, -6.697656631], + [-30.90804291, -6.084065914, 5.611977577]] + expected_drecip = [[-0.6407046318, -27.59628105, -3.745499372], + [30.76446915, -27.10591507, -82.14082336], + [-15.06353951, 10.37030602, -38.38755035], + [-7.421859741, 21.9861393, 39.86354828], + [-0, -0, -0], + [-13.09759808, 6.393665314, 34.15939713], + [19.53832817, -59.55260849, 33.96843338], + [122.5542908, 60.35510254, -27.44270515], + [-136.679245, 15.14429855, 43.89074326]] + edirect.backward() + assert np.allclose(expected_ddirect, positions.grad.cpu().numpy(), rtol=1e-4) + positions.grad.zero_() + erecip.backward() + assert np.allclose(expected_drecip, positions.grad.cpu().numpy(), rtol=1e-4) + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +def test_triclinic(device): + """Test PME on a triclinic box.""" + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + pme = PME(14, 16, 15, 5, 5.0, 138.935, torch.zeros(9, 0, dtype=torch.int32)) + pos = [[1.31396193, -0.9377441519, 0.9009447048], + [1.246411648, 0.4955210369, -0.3256100634], + [-0.4058114057, 1.281592137, -0.4926674903], + [-0.7349805575, 1.056079455, 1.860180039], + [-0.988155201, 0.5365767902, 1.437862885], + [0.8375782005, 1.165265952, -0.1243717955], + [1.753322368, 1.14372735, 0.627633104], + [-0.5734898572, 0.1200222802, 1.022400845], + [0.3254995233, 0.30204198, 0.8533009354]] + positions = torch.tensor(pos, dtype=torch.float32, requires_grad=True, device=device) + charges = torch.tensor([(i-4)*0.1 for i in range(9)], dtype=torch.float32, device=device) + box_vectors = torch.tensor([[1, 0, 0], [-0.1, 1.2, 0], [0.2, -0.15, 1.1]], dtype=torch.float32, device=device) + + # Compare forces and energies to values computed with OpenMM. + + edirect = pme.compute_direct(positions, charges, 0.5, box_vectors) + assert np.allclose(-178.86083489656448, edirect.detach().cpu().numpy()) + erecip = pme.compute_reciprocal(positions, charges, box_vectors) + assert np.allclose(-200.9420623172533, erecip.detach().cpu().numpy()) + expected_ddirect = [[-1000.97644, -326.2085571, 373.3143005], + [401.765686, 153.7181702, -278.0073242], + [2140.490723, -633.4395752, -1059.523071], + [-1.647740602, 10.02025795, 0.2182842493], + [-0, -0, -0], + [0.05209997296, -2.530653, 3.196420431], + [-2139.176758, 633.9973145, 1060.562622], + [13.49786377, 11.52490139, -10.12783146], + [585.994812, 152.9181519, -89.63345337]] + expected_drecip = [[-162.9051514, 32.17734528, -77.43495178], + [11.11517906, 52.98329163, -83.18161011], + [34.50453186, 8.428194046, -4.691772938], + [-12.71308613, 20.7514267, -13.68377304], + [-0, -0, -0], + [8.277475357, -3.927520275, 13.88403988], + [-34.93006897, -7.739934444, 8.986465454], + [45.33776474, -36.9358139, 40.34444809], + [111.2698975, -65.63329315, 115.8478012]] + edirect.backward() + assert np.allclose(expected_ddirect, positions.grad.cpu().numpy(), rtol=1e-4) + positions.grad.zero_() + erecip.backward() + assert np.allclose(expected_drecip, positions.grad.cpu().numpy(), rtol=1e-4) + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +def test_exclusions(device): + """Test PME with exclusions.""" + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + pos = [[1.31396193, -0.9377441519, 0.9009447048], + [1.246411648, 0.4955210369, -0.3256100634], + [-0.4058114057, 1.281592137, -0.4926674903], + [-0.7349805575, 1.056079455, 1.860180039], + [-0.988155201, 0.5365767902, 1.437862885], + [0.8375782005, 1.165265952, -0.1243717955], + [1.753322368, 1.14372735, 0.627633104], + [-0.5734898572, 0.1200222802, 1.022400845], + [0.3254995233, 0.30204198, 0.8533009354]] + excl = [[3, -1], + [-1, -1], + [-1, 3], + [0, 2], + [-1, -1], + [-1, -1], + [-1, -1], + [-1, 8], + [7, -1]] + positions = torch.tensor(pos, dtype=torch.float32, requires_grad=True, device=device) + exclusions = torch.tensor(excl, dtype=torch.int32, device=device) + charges = torch.tensor([(i-4)*0.1 for i in range(9)], dtype=torch.float32, device=device) + box_vectors = torch.tensor([[1, 0, 0], [-0.1, 1.2, 0], [0.2, -0.15, 1.1]], dtype=torch.float32, device=device) + pme = PME(14, 16, 15, 5, 5.0, 138.935, exclusions) + + # Compare forces and energies to values computed with OpenMM. + + edirect = pme.compute_direct(positions, charges, 0.5, box_vectors) + assert np.allclose(-204.22671127319336, edirect.detach().cpu().numpy()) + erecip = pme.compute_reciprocal(positions, charges, box_vectors) + assert np.allclose(-200.9420623172533, erecip.detach().cpu().numpy()) + expected_ddirect = [[-998.2406773, -314.4639407, 379.7956738], + [401.7656421, 153.7181283, -278.0072042], + [2136.789297, -634.4331203, -1062.13192], + [-0.6838558404, -0.7345126528, -3.655667043], + [-0, -0, -0], + [0.05210044985, -2.530651058, 3.196419874], + [-2139.175743, 634.0007806, 1060.564263], + [21.9532636, -40.74009123, 38.42738517], + [577.5399728, 205.183407, -138.1889512]] + expected_drecip = [[-162.9051514, 32.17734528, -77.43495178], + [11.11517906, 52.98329163, -83.18161011], + [34.50453186, 8.428194046, -4.691772938], + [-12.71308613, 20.7514267, -13.68377304], + [-0, -0, -0], + [8.277475357, -3.927520275, 13.88403988], + [-34.93006897, -7.739934444, 8.986465454], + [45.33776474, -36.9358139, 40.34444809], + [111.2698975, -65.63329315, 115.8478012]] + edirect.backward() + assert np.allclose(expected_ddirect, positions.grad.cpu().numpy(), rtol=1e-4) + positions.grad.zero_() + erecip.backward() + assert np.allclose(expected_drecip, positions.grad.cpu().numpy(), rtol=1e-4) + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +def test_charge_deriv(device): + """Test derivatives with respect to charge.""" + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + pos = [[0.7713206433, 0.02075194936, 0.6336482349], + [0.7488038825, 0.4985070123, 0.2247966455], + [0.1980628648, 0.7605307122, 0.1691108366], + [0.08833981417, 0.6853598184, 0.9533933462], + [0.003948266328, 0.5121922634, 0.8126209617], + [0.6125260668, 0.7217553174, 0.2918760682], + [0.9177741225, 0.7145757834, 0.542544368], + [0.1421700476, 0.3733407601, 0.6741336151], + [0.4418331744, 0.4340139933, 0.6177669785]] + excl = [[6, -1], + [-1, -1], + [-1, -1], + [6, -1], + [-1, -1], + [-1, -1], + [0, 3], + [-1, -1], + [-1, -1]] + positions = torch.tensor(pos, dtype=torch.float32, requires_grad=True, device=device) + exclusions = torch.tensor(excl, dtype=torch.int32, device=device) + charges = torch.tensor([(i-4)*0.1 for i in range(9)], dtype=torch.float32, requires_grad=True, device=device) + box_vectors = torch.tensor([[1, 0, 0], [0,1.1, 0], [0, 0, 1.2]], dtype=torch.float32, device=device) + pme = PME(14, 15, 16, 5, 4.985823141035867, 138.935, exclusions) + + # Compute derivatives of the energies with respect to charges. + + edir = pme.compute_direct(positions, charges, 0.5, box_vectors) + erecip = pme.compute_reciprocal(positions, charges, box_vectors) + edir.backward(retain_graph=True) + ddir = charges.grad.clone().detach().cpu().numpy() + charges.grad.zero_() + erecip.backward(retain_graph=True) + drecip = charges.grad.clone().detach().cpu().numpy() + + # Compute finite difference approximations from two displaced inputs. + + delta = 0.001 + for i in range(len(charges)): + c1 = charges.clone() + c1[i] += delta + edir1 = pme.compute_direct(positions, c1, 0.5, box_vectors).detach().cpu().numpy() + erecip1 = pme.compute_reciprocal(positions, c1, box_vectors).detach().cpu().numpy() + c2 = charges.clone() + c2[i] -= delta + edir2 = pme.compute_direct(positions, c2, 0.5, box_vectors).detach().cpu().numpy() + erecip2 = pme.compute_reciprocal(positions, c2, box_vectors).detach().cpu().numpy() + assert np.allclose(ddir[i], (edir1-edir2)/(2*delta), rtol=1e-3, atol=1e-3) + assert np.allclose(drecip[i], (erecip1-erecip2)/(2*delta), rtol=1e-3, atol=1e-3) + + # Make sure the chain rule is applied properly. + + charges.grad.zero_() + (2.5*edir).backward() + ddir2 = charges.grad.clone().detach().cpu().numpy() + charges.grad.zero_() + (2.5*erecip).backward() + drecip2 = charges.grad.clone().detach().cpu().numpy() + assert np.allclose(2.5*ddir, ddir2) + assert np.allclose(2.5*drecip, drecip2) + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +def test_jit(device): + """Test that the model can be JIT compiled.""" + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + m1 = PmeModule() + m2 = torch.jit.script(m1) + torch.manual_seed(10) + positions = 3*torch.rand((9, 3), dtype=torch.float32, device=device)-1 + positions.requires_grad_() + charges = torch.tensor([(i-4)*0.1 for i in range(9)], dtype=torch.float32, device=device) + box_vectors = torch.tensor([[1, 0, 0], [0,1.1, 0], [0, 0, 1.2]], dtype=torch.float32, device=device) + e1 = m1(positions, charges, box_vectors) + e2 = m2(positions, charges, box_vectors) + assert np.allclose(e1.detach().cpu().numpy(), e2.detach().cpu().numpy()) + e1.backward() + d1 = positions.grad.detach().cpu().numpy() + positions.grad.zero_() + e2.backward() + d2 = positions.grad.detach().cpu().numpy() + assert np.allclose(d1, d2) + +def test_cuda_graph(): + """Test that PME works with CUDA graphs.""" + if not torch.cuda.is_available(): + pytest.skip('No GPU') + device = 'cuda' + pme = PmeModule() + torch.manual_seed(10) + positions = 3*torch.rand((9, 3), dtype=torch.float32, device=device)-1 + positions.requires_grad_() + charges = torch.tensor([(i-4)*0.1 for i in range(9)], dtype=torch.float32, device=device) + box_vectors = torch.tensor([[1, 0, 0], [0,1.1, 0], [0, 0, 1.2]], dtype=torch.float32, device=device) + + # Warmup before capturing graph. + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for i in range(3): + e = pme(positions, charges, box_vectors) + e.backward() + torch.cuda.current_stream().wait_stream(s) + + # Capture the graph. + + g = torch.cuda.CUDAGraph() + positions.grad.zero_() + with torch.cuda.graph(g): + e = pme(positions, charges, box_vectors) + e.backward() + + # Replay the graph. + + g.replay() + torch.cuda.synchronize() + + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +def test_double_derivative(device): + """Test that asking for a second derivative throws an excepion.""" + if not torch.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + positions = 3*torch.rand((9, 3), dtype=torch.float32, device=device)-1 + positions.requires_grad_() + charges = torch.tensor([(i-4)*0.1 for i in range(9)], dtype=torch.float32, device=device) + charges.requires_grad_() + box_vectors = torch.tensor([[1, 0, 0], [0,1.1, 0], [0, 0, 1.2]], dtype=torch.float32, device=device) + pme = PME(14, 16, 15, 5, 5.0, 138.935, torch.zeros(9, 0, dtype=torch.int32)) + edir = pme.compute_direct(positions, charges, 0.5, box_vectors) + erecip = pme.compute_reciprocal(positions, charges, box_vectors) + ddir = torch.autograd.grad(edir, positions, retain_graph=True) + drecip = torch.autograd.grad(erecip, positions, retain_graph=True) + with pytest.raises(Exception): + torch.autograd.grad(ddir, positions, retain_graph=True) + with pytest.raises(Exception): + torch.autograd.grad(drecip, positions, retain_graph=True) + with pytest.raises(Exception): + torch.autograd.grad(ddir, charges, retain_graph=True) + with pytest.raises(Exception): + torch.autograd.grad(drecip, charges, retain_graph=True) diff --git a/src/pytorch/pme/__init__.py b/src/pytorch/pme/__init__.py new file mode 100644 index 0000000..0ffd500 --- /dev/null +++ b/src/pytorch/pme/__init__.py @@ -0,0 +1,5 @@ +""" +Particle Mesh Ewald +""" + +from NNPOps.pme.pme import PME \ No newline at end of file diff --git a/src/pytorch/pme/pme.cpp b/src/pytorch/pme/pme.cpp new file mode 100644 index 0000000..54c739d --- /dev/null +++ b/src/pytorch/pme/pme.cpp @@ -0,0 +1,6 @@ +#include + +TORCH_LIBRARY(pme, m) { + m.def("pme_direct(Tensor positions, Tensor charges, Tensor neighbors, Tensor deltas, Tensor distances, Tensor exclusions, Scalar alpha, Scalar coulomb) -> Tensor"); + m.def("pme_reciprocal(Tensor positions, Tensor charges, Tensor box_vectors, Scalar gridx, Scalar gridy, Scalar gridz, Scalar order, Scalar alpha, Scalar coulomb, Tensor xmoduli, Tensor ymoduli, Tensor zmoduli) -> Tensor"); +} diff --git a/src/pytorch/pme/pme.py b/src/pytorch/pme/pme.py new file mode 100644 index 0000000..890b0f6 --- /dev/null +++ b/src/pytorch/pme/pme.py @@ -0,0 +1,196 @@ +from ..neighbors import getNeighborPairs +import torch +import math + +class PME: + """This class implements the Particle Mesh Ewald algorithm (https://doi.org/10.1063/1.470117). + + This is a method of summing all the infinite pairwise electrostatic interactions in a periodic system. It divides + the energy into two parts: a short range term that can be computed efficiently in direct space, and a long range + term that can be computed efficiently in reciprocal space. The individual terms are not physical meaningful, only + their sum. + + This class supports periodic boundary conditions with arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` + must satisfy certain requirements: + + `a[1] = a[2] = b[2] = 0` + `a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff` + `a[0] >= 2*b[0]` + `a[0] >= 2*c[0]` + `b[1] >= 2*c[1]` + + These requirements correspond to a particular rotation of the system and reduced form of the vectors, as well as the + requirement that the cutoff be no larger than half the box width. + + You can optionally specify that certain interactions should be omitted when computing the energy. This is typically + used for nearby atoms within the same molecule. When two atoms are listed as an exclusion, only the interaction of + each with the same periodic copy of the other (that is, not applying periodic boundary conditions) is excluded. + Each atom still interacts with all the periodic copies of the other. + + Due to the way the reciprocal space term is calculated, it is impossible to prevent it from including excluded + interactions. The direct space term therefore compensates for it, subtracting off the energy that was incorrectly + included in reciprocal space. The sum of the two terms thus yields the correct energy with the interaction fully + excluded. + + When performing backpropagation, this class computes derivatives with respect to atomic positions and charges, but + not to any other parameters (box vectors, alpha, etc.). In addition, it only computes first derivatives. + Attempting to compute a second derivative will throw an exception. This means that if you use PME during training, + the loss function can only depend on energy, not forces. + + When you create an instance of this class, you must specify the value of Coulomb's constant 1/(4*pi*eps0). Its + value depends on the units used for energy and distance. The value you specify thus sets the unit system. Here are + the values for some common units. + + kJ/mol, nm: 138.935457 + kJ/mol, A: 1389.35457 + kcal/mol, nm: 33.2063713 + kcal/mol, A: 332.063713 + eV, nm: 1.43996454 + eV, A: 14.3996454 + hartree, bohr: 1.0 + """ + def __init__(self, gridx: int, gridy: int, gridz: int, order: int, alpha: float, coulomb: float, exclusions: torch.Tensor): + """Create an object for computing energies with PME. + + Parameters + ---------- + gridx: int + the size of the charge grid along the x axis + gridy: int + the size of the charge grid along the y axis + gridz: int + the size of the charge grid along the z axis + order: int + the B-spline order to use for charge spreading. With CUDA, only order 4 and 5 are supported. + alpha: float + the coefficient of the erf() function used to separate the energy into direct and reciprocal space terms + coulomb: float + Coulomb's constant 1/(4*pi*eps0). This sets the unit system. + exclusions: torch.Tensor + a tensor of shape `(atoms, max_exclusions)` containing excluded interactions, where `max_exclusions` is the + maximum number of exclusions for any atom. Row `i` lists the indices of all atoms with which atom `i` should + not interact. If an atom has less than `max_exclusions` excluded interactions, set the remaining elements + in the row to -1. The exclusions must be symmetric: if `j` appears in row `i`, then `i` must also appear in + row `j`. If you pass a tensor that does not satisfy that requirement, the results are undefined. + """ + if gridx < 1 or gridy < 1 or gridz < 1: + raise ValueError('The grid dimensions must be positive') + if order < 1: + raise ValueError('order must be positive') + if alpha <= 0: + raise ValueError('alpha must be positive') + if coulomb <= 0: + raise ValueError('coulomb must be positive') + if exclusions.dim() != 2: + raise ValueError('exclusions must be 2D') + self.gridx = gridx + self.gridy = gridy + self.gridz = gridz + self.order = order + self.alpha = alpha + self.coulomb = coulomb + self.exclusions, _ = torch.sort(exclusions.to(torch.int32), descending=True) + + # Initialize the bspline moduli. + + max_size = max(gridx, gridy, gridz) + data = torch.zeros(order, dtype=torch.float32) + ddata = torch.zeros(order, dtype=torch.float32) + bsplines_data = torch.zeros(max_size, dtype=torch.float32) + data[0] = 1 + for i in range(3, order): + data[i-1] = 0 + for j in range(1, i-1): + data[i-j-1] = (j*data[i-j-2]+(i-j)*data[i-j-1])/(i-1) + data[0] /= i-1 + + # Differentiate. + + ddata[0] = -data[0] + ddata[1:order] = data[0:order-1]-data[1:order] + for i in range(1, order-1): + data[order-i-1] = (i*data[order-i-2]+(order-i)*data[order-i-1])/(order-1) + data[0] /= order-1 + bsplines_data[1:order+1] = data + + # Evaluate the actual bspline moduli for X/Y/Z. + + self.moduli = [] + for ndata in (gridx, gridy, gridz): + m = torch.zeros(ndata, dtype=torch.float32) + for i in range(ndata): + arg = (2*torch.pi*i/ndata)*torch.arange(ndata) + sc = torch.sum(bsplines_data[:ndata]*torch.cos(arg)) + ss = torch.sum(bsplines_data[:ndata]*torch.sin(arg)) + m[i] = sc*sc + ss*ss + for i in range(ndata): + if m[i] < 1e-7: + m[i] = (m[(i-1+ndata)%ndata]+m[(i+1)%ndata])*0.5 + self.moduli.append(m) + + def compute_direct(self, positions: torch.Tensor, charges: torch.Tensor, cutoff: float, box_vectors: torch.Tensor, max_num_pairs: int = -1): + """Compute the direct space energy. + + Parameters + ---------- + positions: torch.Tensor + a 2D tensor of shape `(atoms, 3)` containing the positions of the atoms + charges: torch.Tensor + a 1D tensor of length `atoms` containing the charge of each atom + cutoff: float + the cutoff distance to use when computing the direct space term + box_vectors: torch.Tensor + The vectors defining the periodic box. This must have shape `(3, 3)`, + where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. + max_num_pairs: int, optional + Maximum number of pairs (total number of neighbors). If set to `-1` (default), + all possible combinations of atom pairs are included. + + Returns + ------- + the energy of the direct space term + """ + if positions.dim() != 2 or positions.shape[1] != 3: + raise ValueError('positions must have shape (atoms, 3)') + if charges.dim() != 1: + raise ValueError('charges must be 1D') + if positions.shape[0] != self.exclusions.shape[0] or charges.shape[0] != self.exclusions.shape[0]: + raise ValueError('positions, charges, and exclusions must all have the same length') + if box_vectors.dim() != 2 or box_vectors.shape[0] != 3 or box_vectors.shape[1] != 3: + raise ValueError('box_vectors must have shape (3, 3)') + if (cutoff <= 0): + raise ValueError('cutoff must be positive') + neighbors, deltas, distances, number_found_pairs = getNeighborPairs(positions, cutoff, max_num_pairs, box_vectors) + self.exclusions = self.exclusions.to(positions.device) + return torch.ops.pme.pme_direct(positions, charges, neighbors, deltas, distances, self.exclusions, self.alpha, self.coulomb) + + def compute_reciprocal(self, positions: torch.Tensor, charges: torch.Tensor, box_vectors: torch.Tensor): + """Compute the reciprocal space energy. + + Parameters + ---------- + positions: torch.Tensor + a 2D tensor of shape `(atoms, 3)` containing the positions of the atoms + charges: torch.Tensor + a 1D tensor of length `atoms` containing the charge of each atom + box_vectors: torch.Tensor + The vectors defining the periodic box. This must have shape `(3, 3)`, + where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. + + Returns + ------- + the energy of the reciprocal space term + """ + if positions.dim() != 2 or positions.shape[1] != 3: + raise ValueError('positions must have shape (atoms, 3)') + if charges.dim() != 1: + raise ValueError('charges must be 1D') + if positions.shape[0] != self.exclusions.shape[0] or charges.shape[0] != self.exclusions.shape[0]: + raise ValueError('positions, charges, and exclusions must all have the same length') + if box_vectors.dim() != 2 or box_vectors.shape[0] != 3 or box_vectors.shape[1] != 3: + raise ValueError('box_vectors must have shape (3, 3)') + for i in range(3): + self.moduli[i] = self.moduli[i].to(positions.device) + self_energy = -torch.sum(charges**2)*self.coulomb*self.alpha/math.sqrt(torch.pi) + return self_energy + torch.ops.pme.pme_reciprocal(positions, charges, box_vectors, self.gridx, self.gridy, self.gridz, + self.order, self.alpha, self.coulomb, self.moduli[0], self.moduli[1], self.moduli[2]) diff --git a/src/pytorch/pme/pmeCPU.cpp b/src/pytorch/pme/pmeCPU.cpp new file mode 100644 index 0000000..8ca035b --- /dev/null +++ b/src/pytorch/pme/pmeCPU.cpp @@ -0,0 +1,384 @@ +#include +#include +#include + +using namespace std; +using namespace torch::autograd; +using torch::Tensor; +using torch::TensorOptions; +using torch::Scalar; + +static void invertBoxVectors(const Tensor& box_vectors, float recipBoxVectors[3][3]) { + auto box = box_vectors.accessor(); + float determinant = box[0][0]*box[1][1]*box[2][2]; + float scale = 1.0/determinant; + recipBoxVectors[0][0] = box[1][1]*box[2][2]*scale; + recipBoxVectors[0][1] = 0; + recipBoxVectors[0][2] = 0; + recipBoxVectors[1][0] = -box[1][0]*box[2][2]*scale; + recipBoxVectors[1][1] = box[0][0]*box[2][2]*scale; + recipBoxVectors[1][2] = 0; + recipBoxVectors[2][0] = (box[1][0]*box[2][1]-box[1][1]*box[2][0])*scale; + recipBoxVectors[2][1] = -box[0][0]*box[2][1]*scale; + recipBoxVectors[2][2] = box[0][0]*box[1][1]*scale; +} + +static void computeSpline(int atom, const torch::TensorAccessor& pos, const torch::TensorAccessor& box, + const float recipBoxVectors[3][3], const int gridSize[3], int gridIndex[3], vector >& data, + vector >& ddata, int pmeOrder) { + // Find the position relative to the nearest grid point. + + float posInBox[3] = {pos[atom][0], pos[atom][1], pos[atom][2]}; + for (int i = 2; i >= 0; i--) { + float scale = floor(posInBox[i]*recipBoxVectors[i][i]); + for (int j = 0; j < 3; j++) + posInBox[j] -= scale*box[i][j]; + } + float t[3], dr[3]; + int ti[3]; + for (int i = 0; i < 3; i++) { + t[i] = posInBox[0]*recipBoxVectors[0][i] + posInBox[1]*recipBoxVectors[1][i] + posInBox[2]*recipBoxVectors[2][i]; + t[i] = (t[i]-floor(t[i]))*gridSize[i]; + ti[i] = (int) t[i]; + dr[i] = t[i]-ti[i]; + gridIndex[i] = ti[i]%gridSize[i]; + } + + // Compute the B-spline coefficients. + + float scale = 1.0f/(pmeOrder-1); + for (int i = 0; i < 3; i++) { + data[pmeOrder-1][i] = 0; + data[1][i] = dr[i]; + data[0][i] = 1-dr[i]; + for (int j = 3; j < pmeOrder; j++) { + float div = 1.0f/(j-1); + data[j-1][i] = div*dr[i]*data[j-2][i]; + for (int k = 1; k < j-1; k++) + data[j-k-1][i] = div*((dr[i]+k)*data[j-k-2][i]+(j-k-dr[i])*data[j-k-1][i]); + data[0][i] = div*(1-dr[i])*data[0][i]; + } + if (ddata.size() > 0) { + ddata[0][i] = -data[0][i]; + for (int j = 1; j < pmeOrder; j++) + ddata[j][i] = data[j-1][i]-data[j][i]; + } + data[pmeOrder-1][i] = scale*dr[i]*data[pmeOrder-2][i]; + for (int j = 1; j < pmeOrder-1; j++) + data[pmeOrder-j-1][i] = scale*((dr[i]+j)*data[pmeOrder-j-2][i]+(pmeOrder-j-dr[i])*data[pmeOrder-j-1][i]); + data[0][i] = scale*(1-dr[i])*data[0][i]; + } +} + + +class PmeDirectFunctionCpu : public Function { +public: + static Tensor forward(AutogradContext *ctx, + const Tensor& positions, + const Tensor& charges, + const Tensor& neighbors, + const Tensor& deltas, + const Tensor& distances, + const Tensor& exclusions, + const Scalar& alpha_s, + const Scalar& coulomb_s) { + int numAtoms = charges.size(0); + int numPairs = neighbors.size(1); + int maxExclusions = exclusions.size(1); + auto pos = positions.accessor(); + auto pair = neighbors.accessor(); + auto delta = deltas.accessor(); + auto r = distances.accessor(); + auto charge = charges.accessor(); + auto excl = exclusions.accessor(); + float alpha = alpha_s.toDouble(); + float coulomb = coulomb_s.toDouble(); + + // Loop over interactions to compute energy and derivatives. + + TensorOptions options = torch::TensorOptions().device(neighbors.device()); + Tensor posDeriv = torch::zeros({numAtoms, 3}, options); + Tensor chargeDeriv = torch::zeros({numAtoms}, options); + auto posDeriv_a = posDeriv.accessor(); + auto chargeDeriv_a = chargeDeriv.accessor(); + double energy = 0.0; + for (int i = 0; i < numPairs; i++) { + int atom1 = pair[0][i]; + int atom2 = pair[1][i]; + bool include = (atom1 > -1); + for (int j = 0; include && j < maxExclusions && excl[atom1][j] >= atom2; j++) + if (excl[atom1][j] == atom2) + include = false; + if (include) { + float invR = 1/r[i]; + float alphaR = alpha*r[i]; + float expAlphaRSqr = expf(-alphaR*alphaR); + float erfcAlphaR = erfcf(alphaR); + float prefactor = coulomb*invR; + float c1 = charge[atom1]; + float c2 = charge[atom2]; + energy += prefactor*erfcAlphaR*c1*c2; + chargeDeriv_a[atom1] += prefactor*erfcAlphaR*c2; + chargeDeriv_a[atom2] += prefactor*erfcAlphaR*c1; + float dEdR = prefactor*c1*c2*(erfcAlphaR+alphaR*expAlphaRSqr*M_2_SQRTPI)*invR*invR; + for (int j = 0; j < 3; j++) { + posDeriv_a[atom1][j] -= dEdR*delta[i][j]; + posDeriv_a[atom2][j] += dEdR*delta[i][j]; + } + } + } + + // Subtract excluded interactions to compensate for the part that is + // incorrectly added in reciprocal space. + + float dr[3]; + for (int atom1 = 0; atom1 < numAtoms; atom1++) { + for (int i = 0; i < maxExclusions && excl[atom1][i] > atom1; i++) { + int atom2 = excl[atom1][i]; + for (int j = 0; j < 3; j++) + dr[j] = pos[atom1][j]-pos[atom2][j]; + float rr = sqrt(dr[0]*dr[0] + dr[1]*dr[1] + dr[2]*dr[2]); + float invR = 1/rr; + float alphaR = alpha*rr; + float expAlphaRSqr = expf(-alphaR*alphaR); + float erfAlphaR = erff(alphaR); + float prefactor = coulomb*invR; + float c1 = charge[atom1]; + float c2 = charge[atom2]; + energy -= prefactor*erfAlphaR*c1*c2; + chargeDeriv_a[atom1] -= prefactor*erfAlphaR*c2; + chargeDeriv_a[atom2] -= prefactor*erfAlphaR*c1; + float dEdR = prefactor*c1*c2*(erfAlphaR-alphaR*expAlphaRSqr*M_2_SQRTPI)*invR*invR; + for (int j = 0; j < 3; j++) { + posDeriv_a[atom1][j] += dEdR*dr[j]; + posDeriv_a[atom2][j] -= dEdR*dr[j]; + } + } + } + + // Store data for later use. + + ctx->save_for_backward({posDeriv, chargeDeriv}); + return {torch::tensor(energy, options)}; + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + Tensor posDeriv = saved[0]; + Tensor chargeDeriv = saved[1]; + torch::Tensor ignore; + return {posDeriv*grad_outputs[0], chargeDeriv*grad_outputs[0], ignore, ignore, ignore, ignore, ignore, ignore}; + } +}; + +class PmeReciprocalFunctionCpu : public Function { +public: + static Tensor forward(AutogradContext *ctx, + const Tensor& positions, + const Tensor& charges, + const Tensor& box_vectors, + const Scalar& gridx, + const Scalar& gridy, + const Scalar& gridz, + const Scalar& order, + const Scalar& alpha, + const Scalar& coulomb, + const Tensor& xmoduli, + const Tensor& ymoduli, + const Tensor& zmoduli) { + int numAtoms = positions.size(0); + int pmeOrder = (int) order.toInt(); + auto box = box_vectors.accessor(); + auto pos = positions.accessor(); + auto charge = charges.accessor(); + int gridSize[3] = {(int) gridx.toInt(), (int) gridy.toInt(), (int) gridz.toInt()}; + float recipBoxVectors[3][3]; + invertBoxVectors(box_vectors, recipBoxVectors); + vector grid(gridSize[0]*gridSize[1]*gridSize[2], 0); + float sqrtCoulomb = sqrt(coulomb.toDouble()); + + // Spread the charge on the grid. + + for (int atom = 0; atom < numAtoms; atom++) { + // Compute the B-spline coefficients. + + int gridIndex[3]; + vector > data(pmeOrder), ddata; + computeSpline(atom, pos, box, recipBoxVectors, gridSize, gridIndex, data, ddata, pmeOrder); + + // Spread the charge from this atom onto each grid point. + + for (int ix = 0; ix < pmeOrder; ix++) { + int xindex = (gridIndex[0]+ix) % gridSize[0]; + float dx = charge[atom]*sqrtCoulomb*data[ix][0]; + for (int iy = 0; iy < pmeOrder; iy++) { + int yindex = (gridIndex[1]+iy) % gridSize[1]; + float dxdy = dx*data[iy][1]; + for (int iz = 0; iz < pmeOrder; iz++) { + int zindex = (gridIndex[2]+iz) % gridSize[2]; + int index = xindex*gridSize[1]*gridSize[2] + yindex*gridSize[2] + zindex; + grid[index] += dxdy*data[iz][2]; + } + } + } + } + + // Take the Fourier transform. + + TensorOptions options = torch::TensorOptions().device(positions.device()); // Data type of float by default + Tensor realGrid = torch::from_blob(grid.data(), {gridSize[0], gridSize[1], gridSize[2]}, options); + Tensor recipGrid = torch::fft::rfftn(realGrid); + auto recip = recipGrid.accessor,3>(); + + // Perform the convolution and calculate the energy. + + double energy = 0.0; + int zsize = gridSize[2]/2+1; + int yzsize = gridSize[1]*zsize; + float scaleFactor = (float) (M_PI*box[0][0]*box[1][1]*box[2][2]); + float recipExpFactor = (float) (M_PI*M_PI/(alpha.toDouble()*alpha.toDouble())); + auto xmod = xmoduli.accessor(); + auto ymod = ymoduli.accessor(); + auto zmod = zmoduli.accessor(); + for (int kx = 0; kx < gridSize[0]; kx++) { + int mx = (kx < (gridSize[0]+1)/2) ? kx : kx-gridSize[0]; + float mhx = mx*recipBoxVectors[0][0]; + float bx = scaleFactor*xmod[kx]; + for (int ky = 0; ky < gridSize[1]; ky++) { + int my = (ky < (gridSize[1]+1)/2) ? ky : ky-gridSize[1]; + float mhy = mx*recipBoxVectors[1][0] + my*recipBoxVectors[1][1]; + float mhx2y2 = mhx*mhx + mhy*mhy; + float bxby = bx*ymod[ky]; + for (int kz = 0; kz < zsize; kz++) { + int index = kx*yzsize + ky*zsize + kz; + int mz = (kz < (gridSize[2]+1)/2) ? kz : kz-gridSize[2]; + float mhz = mx*recipBoxVectors[2][0] + my*recipBoxVectors[2][1] + mz*recipBoxVectors[2][2]; + float bz = zmod[kz]; + float m2 = mhx2y2 + mhz*mhz; + float denom = m2*bxby*bz; + float eterm = (index == 0 ? 0 : expf(-recipExpFactor*m2)/denom); + float scale = (kz > 0 && kz <= (gridSize[2]-1)/2 ? 2 : 1); + c10::complex& g = recip[kx][ky][kz]; + energy += scale * eterm * (g.real()*g.real() + g.imag()*g.imag()); + g *= eterm; + } + } + } + + // Store data for later use. + + ctx->save_for_backward({positions, charges, box_vectors, xmoduli, ymoduli, zmoduli, recipGrid}); + ctx->saved_data["gridx"] = gridx; + ctx->saved_data["gridy"] = gridy; + ctx->saved_data["gridz"] = gridz; + ctx->saved_data["order"] = order; + ctx->saved_data["alpha"] = alpha; + ctx->saved_data["coulomb"] = coulomb; + return {torch::tensor(0.5*energy, options)}; + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + Tensor positions = saved[0]; + Tensor charges = saved[1]; + Tensor box_vectors = saved[2]; + Tensor xmoduli = saved[3]; + Tensor ymoduli = saved[4]; + Tensor zmoduli = saved[5]; + Tensor recipGrid = saved[6]; + int gridSize[3] = {(int) ctx->saved_data["gridx"].toInt(), (int) ctx->saved_data["gridy"].toInt(), (int) ctx->saved_data["gridz"].toInt()}; + int pmeOrder = (int) ctx->saved_data["order"].toInt(); + float alpha = (float) ctx->saved_data["alpha"].toDouble(); + float sqrtCoulomb = sqrt(ctx->saved_data["coulomb"].toDouble()); + int numAtoms = positions.size(0); + auto box = box_vectors.accessor(); + auto pos = positions.accessor(); + auto charge = charges.accessor(); + float recipBoxVectors[3][3]; + invertBoxVectors(box_vectors, recipBoxVectors); + + // Take the inverse Fourier transform. + + int64_t targetGridSize[3] = {gridSize[0], gridSize[1], gridSize[2]}; + Tensor realGrid = torch::fft::irfftn(recipGrid, targetGridSize, c10::nullopt, "forward"); + auto grid = realGrid.accessor(); + + // Compute the derivatives. + + TensorOptions options = torch::TensorOptions().device(positions.device()); // Data type of float by default + Tensor posDeriv = torch::empty({numAtoms, 3}, options); + Tensor chargeDeriv = torch::empty({numAtoms}, options); + auto posDeriv_a = posDeriv.accessor(); + auto chargeDeriv_a = chargeDeriv.accessor(); + for (int atom = 0; atom < numAtoms; atom++) { + // Compute the B-spline coefficients. + + int gridIndex[3]; + vector > data(pmeOrder), ddata(pmeOrder); + computeSpline(atom, pos, box, recipBoxVectors, gridSize, gridIndex, data, ddata, pmeOrder); + + // Compute the derivatives on this atom. + + float dpos[3] = {0, 0, 0}; + float dq = 0; + for (int ix = 0; ix < pmeOrder; ix++) { + int xindex = (gridIndex[0]+ix) % gridSize[0]; + float dx = data[ix][0]; + float ddx = ddata[ix][0]; + for (int iy = 0; iy < pmeOrder; iy++) { + int yindex = (gridIndex[1]+iy) % gridSize[1]; + float dy = data[iy][1]; + float ddy = ddata[iy][1]; + for (int iz = 0; iz < pmeOrder; iz++) { + int zindex = (gridIndex[2]+iz) % gridSize[2]; + float dz = data[iz][2]; + float ddz = ddata[iz][2]; + float g = grid[xindex][yindex][zindex]; + dpos[0] += ddx*dy*dz*g; + dpos[1] += dx*ddy*dz*g; + dpos[2] += dx*dy*ddz*g; + dq += dx*dy*dz*g; + } + } + } + float scale = charge[atom]*sqrtCoulomb; + posDeriv_a[atom][0] = scale*(dpos[0]*gridSize[0]*recipBoxVectors[0][0]); + posDeriv_a[atom][1] = scale*(dpos[0]*gridSize[0]*recipBoxVectors[1][0] + dpos[1]*gridSize[1]*recipBoxVectors[1][1]); + posDeriv_a[atom][2] = scale*(dpos[0]*gridSize[0]*recipBoxVectors[2][0] + dpos[1]*gridSize[1]*recipBoxVectors[2][1] + dpos[2]*gridSize[2]*recipBoxVectors[2][2]); + chargeDeriv_a[atom] = dq*sqrtCoulomb; + } + torch::Tensor ignore; + return {posDeriv*grad_outputs[0], chargeDeriv*grad_outputs[0], ignore, ignore, ignore, ignore, ignore, ignore, ignore, ignore, ignore, ignore}; + } +}; + +Tensor pme_direct_cpu(const Tensor& positions, + const Tensor& charges, + const Tensor& neighbors, + const Tensor& deltas, + const Tensor& distances, + const Tensor& exclusions, + const Scalar& alpha, + const Scalar& coulomb) { + return PmeDirectFunctionCpu::apply(positions, charges, neighbors, deltas, distances, exclusions, alpha, coulomb); +} + +Tensor pme_reciprocal_cpu(const Tensor& positions, + const Tensor& charges, + const Tensor& box_vectors, + const Scalar& gridx, + const Scalar& gridy, + const Scalar& gridz, + const Scalar& order, + const Scalar& alpha, + const Scalar& coulomb, + const Tensor& xmoduli, + const Tensor& ymoduli, + const Tensor& zmoduli) { + return PmeReciprocalFunctionCpu::apply(positions, charges, box_vectors, gridx, gridy, gridz, order, alpha, coulomb, xmoduli, ymoduli, zmoduli); +} + +TORCH_LIBRARY_IMPL(pme, CPU, m) { + m.impl("pme_direct", pme_direct_cpu); + m.impl("pme_reciprocal", pme_reciprocal_cpu); +} diff --git a/src/pytorch/pme/pmeCUDA.cu b/src/pytorch/pme/pmeCUDA.cu new file mode 100644 index 0000000..2d78091 --- /dev/null +++ b/src/pytorch/pme/pmeCUDA.cu @@ -0,0 +1,449 @@ +#include +#include +#include +#include +#include +#include + +#include "common/accessor.cuh" + +using namespace std; +using namespace torch::autograd; +using torch::Tensor; +using torch::TensorOptions; +using torch::Scalar; + +#define CHECK_RESULT(result) \ + if (result != cudaSuccess) { \ + throw runtime_error(string("Encountered error ")+cudaGetErrorName(result)+" at "+__FILE__+":"+to_string(__LINE__));\ + } + +static int getMaxBlocks() { + // Get an upper limit on how many thread blocks we try to launch based on the size of the GPU. + + int device, numMultiprocessors; + CHECK_RESULT(cudaGetDevice(&device)); + CHECK_RESULT(cudaDeviceGetAttribute(&numMultiprocessors, cudaDevAttrMultiProcessorCount, device)); + return numMultiprocessors*4; +} + +__global__ void computeDirect(const Accessor pos, const Accessor charge, Accessor neighbors, const Accessor deltas, + const Accessor distances, const Accessor exclusions, Accessor posDeriv, + Accessor chargeDeriv, Accessor energyBuffer, float alpha, float coulomb) { + int numAtoms = pos.size(0); + int numNeighbors = neighbors.size(1); + int maxExclusions = exclusions.size(1); + double energy = 0; + + for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < numNeighbors; index += blockDim.x*gridDim.x) { + int atom1 = neighbors[0][index]; + int atom2 = neighbors[1][index]; + float r = distances[index]; + bool include = (atom1 > -1); + for (int j = 0; include && j < maxExclusions && exclusions[atom1][j] >= atom2; j++) + if (exclusions[atom1][j] == atom2) + include = false; + if (include) { + float invR = 1/r; + float alphaR = alpha*r; + float expAlphaRSqr = expf(-alphaR*alphaR); + float erfcAlphaR = erfcf(alphaR); + float prefactor = coulomb*invR; + float c1 = charge[atom1]; + float c2 = charge[atom2]; + energy += prefactor*erfcAlphaR*c1*c2; + atomicAdd(&chargeDeriv[atom1], prefactor*erfcAlphaR*c2); + atomicAdd(&chargeDeriv[atom2], prefactor*erfcAlphaR*c1); + float dEdR = prefactor*c1*c2*(erfcAlphaR+alphaR*expAlphaRSqr*M_2_SQRTPI)*invR*invR; + for (int j = 0; j < 3; j++) { + atomicAdd(&posDeriv[atom1][j], -dEdR*deltas[index][j]); + atomicAdd(&posDeriv[atom2][j], dEdR*deltas[index][j]); + } + } + } + + // Subtract excluded interactions to compensate for the part that is + // incorrectly added in reciprocal space. + + float dr[3]; + for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < numAtoms*maxExclusions; index += blockDim.x*gridDim.x) { + int atom1 = index/maxExclusions; + int atom2 = exclusions[atom1][index-atom1*maxExclusions]; + if (atom2 > atom1) { + for (int j = 0; j < 3; j++) + dr[j] = pos[atom1][j]-pos[atom2][j]; + float r2 = dr[0]*dr[0] + dr[1]*dr[1] + dr[2]*dr[2]; + float invR = rsqrtf(r2); + float r = invR*r2; + float alphaR = alpha*r; + float expAlphaRSqr = expf(-alphaR*alphaR); + float erfAlphaR = erff(alphaR); + float prefactor = coulomb*invR; + float c1 = charge[atom1]; + float c2 = charge[atom2]; + energy -= prefactor*erfAlphaR*c1*c2; + atomicAdd(&chargeDeriv[atom1], -prefactor*erfAlphaR*c2); + atomicAdd(&chargeDeriv[atom2], -prefactor*erfAlphaR*c1); + float dEdR = prefactor*c1*c2*(erfAlphaR-alphaR*expAlphaRSqr*M_2_SQRTPI)*invR*invR; + for (int j = 0; j < 3; j++) { + atomicAdd(&posDeriv[atom1][j], dEdR*dr[j]); + atomicAdd(&posDeriv[atom2][j], -dEdR*dr[j]); + } + } + } + energyBuffer[blockIdx.x*blockDim.x+threadIdx.x] = energy; +} + +__device__ void invertBoxVectors(const Accessor& box, float recipBoxVectors[3][3]) { + float determinant = box[0][0]*box[1][1]*box[2][2]; + float scale = 1.0f/determinant; + recipBoxVectors[0][0] = box[1][1]*box[2][2]*scale; + recipBoxVectors[0][1] = 0; + recipBoxVectors[0][2] = 0; + recipBoxVectors[1][0] = -box[1][0]*box[2][2]*scale; + recipBoxVectors[1][1] = box[0][0]*box[2][2]*scale; + recipBoxVectors[1][2] = 0; + recipBoxVectors[2][0] = (box[1][0]*box[2][1]-box[1][1]*box[2][0])*scale; + recipBoxVectors[2][1] = -box[0][0]*box[2][1]*scale; + recipBoxVectors[2][2] = box[0][0]*box[1][1]*scale; +} + +__device__ void computeSpline(int atom, const Accessor pos, const Accessor box, + const float recipBoxVectors[3][3], const int gridSize[3], int gridIndex[3], float data[][3], + float ddata[][3], int pmeOrder) { + // Find the position relative to the nearest grid point. + + float posInBox[3] = {pos[atom][0], pos[atom][1], pos[atom][2]}; + for (int i = 2; i >= 0; i--) { + float scale = floor(posInBox[i]*recipBoxVectors[i][i]); + for (int j = 0; j < 3; j++) + posInBox[j] -= scale*box[i][j]; + } + float t[3], dr[3]; + int ti[3]; + for (int i = 0; i < 3; i++) { + t[i] = posInBox[0]*recipBoxVectors[0][i] + posInBox[1]*recipBoxVectors[1][i] + posInBox[2]*recipBoxVectors[2][i]; + t[i] = (t[i]-floor(t[i]))*gridSize[i]; + ti[i] = (int) t[i]; + dr[i] = t[i]-ti[i]; + gridIndex[i] = ti[i]%gridSize[i]; + } + + // Compute the B-spline coefficients. + + float scale = 1.0f/(pmeOrder-1); + for (int i = 0; i < 3; i++) { + data[pmeOrder-1][i] = 0; + data[1][i] = dr[i]; + data[0][i] = 1-dr[i]; + for (int j = 3; j < pmeOrder; j++) { + float div = 1.0f/(j-1); + data[j-1][i] = div*dr[i]*data[j-2][i]; + for (int k = 1; k < j-1; k++) + data[j-k-1][i] = div*((dr[i]+k)*data[j-k-2][i]+(j-k-dr[i])*data[j-k-1][i]); + data[0][i] = div*(1-dr[i])*data[0][i]; + } + if (ddata != NULL) { + ddata[0][i] = -data[0][i]; + for (int j = 1; j < pmeOrder; j++) + ddata[j][i] = data[j-1][i]-data[j][i]; + } + data[pmeOrder-1][i] = scale*dr[i]*data[pmeOrder-2][i]; + for (int j = 1; j < pmeOrder-1; j++) + data[pmeOrder-j-1][i] = scale*((dr[i]+j)*data[pmeOrder-j-2][i]+(pmeOrder-j-dr[i])*data[pmeOrder-j-1][i]); + data[0][i] = scale*(1-dr[i])*data[0][i]; + } +} + +template +__global__ void spreadCharge(const Accessor pos, const Accessor charge, const Accessor box, + Accessor grid, int gridx, int gridy, int gridz, float sqrtCoulomb) { + __shared__ float recipBoxVectors[3][3]; + if (threadIdx.x == 0) + invertBoxVectors(box, recipBoxVectors); + __syncthreads(); + float data[PME_ORDER][3]; + int numAtoms = pos.size(0); + for (int atom = blockIdx.x*blockDim.x+threadIdx.x; atom < numAtoms; atom += blockDim.x*gridDim.x) { + int gridIndex[3]; + int gridSize[3] = {gridx, gridy, gridz}; + computeSpline(atom, pos, box,recipBoxVectors, gridSize, gridIndex, data, NULL, PME_ORDER); + + // Spread the charge from this atom onto each grid point. + + for (int ix = 0; ix < PME_ORDER; ix++) { + int xindex = gridIndex[0]+ix; + xindex -= (xindex >= gridx ? gridx : 0); + float dx = charge[atom]*sqrtCoulomb*data[ix][0]; + for (int iy = 0; iy < PME_ORDER; iy++) { + int yindex = gridIndex[1]+iy; + yindex -= (yindex >= gridy ? gridy : 0); + float dxdy = dx*data[iy][1]; + for (int iz = 0; iz < PME_ORDER; iz++) { + int zindex = gridIndex[2]+iz; + zindex -= (zindex >= gridz ? gridz : 0); + atomicAdd(&grid[xindex][yindex][zindex], dxdy*data[iz][2]); + } + } + } + } +} + +__global__ void reciprocalConvolution(const Accessor box, Accessor, 3> grid, int gridx, int gridy, int gridz, + const Accessor xmoduli, const Accessor ymoduli, const Accessor zmoduli, + float recipExpFactor, Accessor energyBuffer) { + float recipBoxVectors[3][3]; + invertBoxVectors(box, recipBoxVectors); + const unsigned int gridSize = gridx*gridy*(gridz/2+1); + const float recipScaleFactor = recipBoxVectors[0][0]*recipBoxVectors[1][1]*recipBoxVectors[2][2]/M_PI; + double energy = 0; + + for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) { + int kx = index/(gridy*(gridz/2+1)); + int remainder = index-kx*gridy*(gridz/2+1); + int ky = remainder/(gridz/2+1); + int kz = remainder-ky*(gridz/2+1); + int mx = (kx < (gridx+1)/2) ? kx : (kx-gridx); + int my = (ky < (gridy+1)/2) ? ky : (ky-gridy); + int mz = (kz < (gridz+1)/2) ? kz : (kz-gridz); + float mhx = mx*recipBoxVectors[0][0]; + float mhy = mx*recipBoxVectors[1][0]+my*recipBoxVectors[1][1]; + float mhz = mx*recipBoxVectors[2][0]+my*recipBoxVectors[2][1]+mz*recipBoxVectors[2][2]; + float bx = xmoduli[kx]; + float by = ymoduli[ky]; + float bz = zmoduli[kz]; + c10::complex& g = grid[kx][ky][kz]; + float m2 = mhx*mhx+mhy*mhy+mhz*mhz; + float denom = m2*bx*by*bz; + float eterm = (index == 0 ? 0 : recipScaleFactor*exp(-recipExpFactor*m2)/denom); + float scale = (kz > 0 && kz <= (gridz-1)/2 ? 2 : 1); + energy += scale * eterm * (g.real()*g.real() + g.imag()*g.imag()); + g *= eterm; + } + energyBuffer[blockIdx.x*blockDim.x+threadIdx.x] = 0.5f*energy; +} + +template +__global__ void interpolateForce(const Accessor pos, const Accessor charge, const Accessor box, + const Accessor grid, int gridx, int gridy, int gridz, float sqrtCoulomb, + Accessor posDeriv, Accessor chargeDeriv) { + __shared__ float recipBoxVectors[3][3]; + if (threadIdx.x == 0) + invertBoxVectors(box, recipBoxVectors); + __syncthreads(); + float data[PME_ORDER][3]; + float ddata[PME_ORDER][3]; + int numAtoms = pos.size(0); + + for (int atom = blockIdx.x*blockDim.x+threadIdx.x; atom < numAtoms; atom += blockDim.x*gridDim.x) { + int gridIndex[3]; + int gridSize[3] = {gridx, gridy, gridz}; + computeSpline(atom, pos, box,recipBoxVectors, gridSize, gridIndex, data, ddata, PME_ORDER); + + // Compute the derivatives on this atom. + + float dpos[3] = {0, 0, 0}; + float dq = 0; + for (int ix = 0; ix < PME_ORDER; ix++) { + int xindex = gridIndex[0]+ix; + xindex -= (xindex >= gridx ? gridx : 0); + float dx = data[ix][0]; + float ddx = ddata[ix][0]; + for (int iy = 0; iy < PME_ORDER; iy++) { + int yindex = gridIndex[1]+iy; + yindex -= (yindex >= gridy ? gridy : 0); + float dy = data[iy][1]; + float ddy = ddata[iy][1]; + for (int iz = 0; iz < PME_ORDER; iz++) { + int zindex = gridIndex[2]+iz; + zindex -= (zindex >= gridz ? gridz : 0); + float dz = data[iz][2]; + float ddz = ddata[iz][2]; + float g = grid[xindex][yindex][zindex]; + dpos[0] += ddx*dy*dz*g; + dpos[1] += dx*ddy*dz*g; + dpos[2] += dx*dy*ddz*g; + dq += dx*dy*dz*g; + } + } + } + float scale = charge[atom]*sqrtCoulomb; + posDeriv[atom][0] = scale*(dpos[0]*gridSize[0]*recipBoxVectors[0][0]); + posDeriv[atom][1] = scale*(dpos[0]*gridSize[0]*recipBoxVectors[1][0] + dpos[1]*gridSize[1]*recipBoxVectors[1][1]); + posDeriv[atom][2] = scale*(dpos[0]*gridSize[0]*recipBoxVectors[2][0] + dpos[1]*gridSize[1]*recipBoxVectors[2][1] + dpos[2]*gridSize[2]*recipBoxVectors[2][2]); + chargeDeriv[atom] = dq*sqrtCoulomb; + } +} + +class PmeDirectFunctionCuda : public Function { +public: + static Tensor forward(AutogradContext *ctx, + const Tensor& positions, + const Tensor& charges, + const Tensor& neighbors, + const Tensor& deltas, + const Tensor& distances, + const Tensor& exclusions, + const Scalar& alpha, + const Scalar& coulomb) { + const auto stream = c10::cuda::getCurrentCUDAStream(positions.get_device()); + const c10::cuda::CUDAStreamGuard guard(stream); + int numAtoms = charges.size(0); + int numPairs = neighbors.size(1); + TensorOptions options = torch::TensorOptions().device(neighbors.device()); + Tensor posDeriv = torch::zeros({numAtoms, 3}, options); + Tensor chargeDeriv = torch::zeros({numAtoms}, options); + int blockSize = 128; + int numBlocks = max(1, min(getMaxBlocks(), (numPairs+blockSize-1)/blockSize)); + Tensor energy = torch::zeros(numBlocks*blockSize, options); + computeDirect<<>>(get_accessor(positions), get_accessor(charges), + get_accessor(neighbors), get_accessor(deltas), get_accessor(distances), + get_accessor(exclusions), get_accessor(posDeriv), get_accessor(chargeDeriv), + get_accessor(energy), alpha.toDouble(), coulomb.toDouble()); + + // Store data for later use. + + ctx->save_for_backward({posDeriv, chargeDeriv}); + return {torch::sum(energy)}; + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + Tensor posDeriv = saved[0]; + Tensor chargeDeriv = saved[1]; + torch::Tensor ignore; + return {posDeriv*grad_outputs[0], chargeDeriv*grad_outputs[0], ignore, ignore, ignore, ignore, ignore, ignore}; + } +}; + +class PmeFunctionCuda : public Function { +public: + static Tensor forward(AutogradContext *ctx, + const Tensor& positions, + const Tensor& charges, + const Tensor& box_vectors, + const Scalar& gridx, + const Scalar& gridy, + const Scalar& gridz, + const Scalar& order, + const Scalar& alpha, + const Scalar& coulomb, + const Tensor& xmoduli, + const Tensor& ymoduli, + const Tensor& zmoduli) { + const auto stream = c10::cuda::getCurrentCUDAStream(positions.get_device()); + const c10::cuda::CUDAStreamGuard guard(stream); + int numAtoms = positions.size(0); + int pmeOrder = (int) order.toInt(); + int gridSize[3] = {(int) gridx.toInt(), (int) gridy.toInt(), (int) gridz.toInt()}; + float sqrtCoulomb = sqrt(coulomb.toDouble()); + + // Spread the charge on the grid. + + TensorOptions options = torch::TensorOptions().device(positions.device()); + Tensor realGrid = torch::zeros({gridSize[0], gridSize[1], gridSize[2]}, options); + int blockSize = 128; + int numBlocks = max(1, min(getMaxBlocks(), (numAtoms+blockSize-1)/blockSize)); + TORCH_CHECK(pmeOrder == 4 || pmeOrder == 5, "Only pmeOrder 4 or 5 is supported with CUDA"); + auto spread = (pmeOrder == 4 ? spreadCharge<4> : spreadCharge<5>); + spread<<>>(get_accessor(positions), get_accessor(charges), + get_accessor(box_vectors), get_accessor(realGrid), gridSize[0], gridSize[1], gridSize[2], sqrtCoulomb); + + // Take the Fourier transform. + + Tensor recipGrid = torch::fft::rfftn(realGrid); + + // Perform the convolution and calculate the energy. + + Tensor energy = torch::zeros(numBlocks*blockSize, options); + reciprocalConvolution<<>>(get_accessor(box_vectors), get_accessor, 3>(recipGrid), + gridSize[0], gridSize[1], gridSize[2], get_accessor(xmoduli), get_accessor(ymoduli), get_accessor(zmoduli), + M_PI*M_PI/(alpha.toDouble()*alpha.toDouble()), get_accessor(energy)); + + // Store data for later use. + + ctx->save_for_backward({positions, charges, box_vectors, xmoduli, ymoduli, zmoduli, recipGrid}); + ctx->saved_data["gridx"] = gridx; + ctx->saved_data["gridy"] = gridy; + ctx->saved_data["gridz"] = gridz; + ctx->saved_data["order"] = order; + ctx->saved_data["alpha"] = alpha; + ctx->saved_data["coulomb"] = coulomb; + return {torch::sum(energy)}; + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + Tensor positions = saved[0]; + Tensor charges = saved[1]; + Tensor box_vectors = saved[2]; + Tensor xmoduli = saved[3]; + Tensor ymoduli = saved[4]; + Tensor zmoduli = saved[5]; + Tensor recipGrid = saved[6]; + int gridSize[3] = {(int) ctx->saved_data["gridx"].toInt(), (int) ctx->saved_data["gridy"].toInt(), (int) ctx->saved_data["gridz"].toInt()}; + int pmeOrder = (int) ctx->saved_data["order"].toInt(); + float alpha = (float) ctx->saved_data["alpha"].toDouble(); + float sqrtCoulomb = sqrt(ctx->saved_data["coulomb"].toDouble()); + const auto stream = c10::cuda::getCurrentCUDAStream(positions.get_device()); + const c10::cuda::CUDAStreamGuard guard(stream); + int numAtoms = positions.size(0); + + // Take the inverse Fourier transform. + + int64_t targetGridSize[3] = {gridSize[0], gridSize[1], gridSize[2]}; + Tensor realGrid = torch::fft::irfftn(recipGrid, targetGridSize, c10::nullopt, "forward"); + + // Compute the derivatives. + + TensorOptions options = torch::TensorOptions().device(positions.device()); + Tensor posDeriv = torch::empty({numAtoms, 3}, options); + Tensor chargeDeriv = torch::empty({numAtoms}, options); + int blockSize = 128; + int numBlocks = max(1, min(getMaxBlocks(), (numAtoms+blockSize-1)/blockSize)); + TORCH_CHECK(pmeOrder == 4 || pmeOrder == 5, "Only pmeOrder 4 or 5 is supported with CUDA"); + if (pmeOrder == 4) + interpolateForce<4><<>>(get_accessor(positions), get_accessor(charges), + get_accessor(box_vectors), get_accessor(realGrid), gridSize[0], gridSize[1], gridSize[2], sqrtCoulomb, + get_accessor(posDeriv), get_accessor(chargeDeriv)); + else + interpolateForce<5><<>>(get_accessor(positions), get_accessor(charges), + get_accessor(box_vectors), get_accessor(realGrid), gridSize[0], gridSize[1], gridSize[2], sqrtCoulomb, + get_accessor(posDeriv), get_accessor(chargeDeriv)); + posDeriv *= grad_outputs[0]; + chargeDeriv *= grad_outputs[0]; + torch::Tensor ignore; + return {posDeriv, chargeDeriv, ignore, ignore, ignore, ignore, ignore, ignore, ignore, ignore, ignore, ignore}; + } +}; + +Tensor pme_direct_cuda(const Tensor& positions, + const Tensor& charges, + const Tensor& neighbors, + const Tensor& deltas, + const Tensor& distances, + const Tensor& exclusions, + const Scalar& alpha, + const Scalar& coulomb) { + return PmeDirectFunctionCuda::apply(positions, charges, neighbors, deltas, distances, exclusions, alpha, coulomb); +} + +Tensor pme_reciprocal_cuda(const Tensor& positions, + const Tensor& charges, + const Tensor& box_vectors, + const Scalar& gridx, + const Scalar& gridy, + const Scalar& gridz, + const Scalar& order, + const Scalar& alpha, + const Scalar& coulomb, + const Tensor& xmoduli, + const Tensor& ymoduli, + const Tensor& zmoduli) { + return PmeFunctionCuda::apply(positions, charges, box_vectors, gridx, gridy, gridz, order, alpha, coulomb, xmoduli, ymoduli, zmoduli); +} + +TORCH_LIBRARY_IMPL(pme, AutogradCUDA, m) { + m.impl("pme_direct", pme_direct_cuda); + m.impl("pme_reciprocal", pme_reciprocal_cuda); +}