forked from noegroup/torchmd-autodiff
-
Notifications
You must be signed in to change notification settings - Fork 5
/
waterbox.py
73 lines (60 loc) · 2.18 KB
/
waterbox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
Water box from the torchmd test suite.
"""
import os
import torch
import numpy as np
from moleculekit.molecule import Molecule
import parmed
import torchmd
from torchmd.forcefields.forcefield import ForceField
from torchmd.parameters import Parameters
from torchmd.forces import Forces
from torchmd.systems import System
from torchmd.integrator import maxwell_boltzmann
class WaterBox:
    """
    Water box for use in torchmd.

    Loads the CHARMM water-box test system that ships with torchmd (the
    ``test-data/waterbox`` directory next to the installed ``torchmd``
    package). It then builds the torchmd force field, forces and system
    objects needed to simulate it.

    Attributes:
        dtype: torch dtype used for the force-field parameters and the System.
        device: torch device used for the force-field parameters and the System.
        mol: moleculekit ``Molecule`` read from ``psfFile``. Coordinates come
            from ``xtcFile``, and only its first frame is kept.
        forces: torchmd ``Forces`` built from the CHARMM parameter set.
        ff: torchmd ``ForceField`` created from ``mol`` and the parameter set.
        system: torchmd ``System`` holding positions, box and initial
            velocities for ``nreplicas`` replicas.
    """

    # Bundled test data, resolved relative to the installed torchmd package:
    # <torchmd package dir>/../test-data/waterbox
    TESTDIR = os.path.normpath(
        os.path.join(
            os.path.dirname(torchmd.__file__),
            os.pardir,
            "test-data/waterbox",
        )
    )
    psfFile = os.path.join(TESTDIR, "structure.psf")  # CHARMM PSF topology
    xtcFile = os.path.join(TESTDIR, "output.xtc")  # trajectory with coords/box
    prmFiles = [os.path.join(TESTDIR, "parameters.prm")]  # CHARMM parameters

    def __init__(
        self,
        nreplicas=1,
        T=300.0,
        dtype=torch.double,
        device="cpu",
        *,
        cutoff=9.0,
        switch_dist=7.5,
        rfa=True,
    ):
        """
        Build the molecule, forces and system for the water box.

        Args:
            nreplicas: number of replicas allocated in the torchmd System.
            T: temperature passed to ``maxwell_boltzmann`` to draw the
                initial velocities.
            dtype: torch dtype for the parameters and the System.
            device: torch device for the parameters and the System.
            cutoff: nonbonded cutoff forwarded to ``Forces`` (presumably in
                Angstrom, torchmd's length unit; confirm).
            switch_dist: switching distance forwarded to ``Forces``.
            rfa: forwarded to ``Forces`` as its ``rfa`` flag.
        """
        self.dtype = dtype
        self.device = device
        self.mol = self._init_mol()
        self.forces, self.ff = self._init_forces(
            self.mol, cutoff=cutoff, switch_dist=switch_dist, rfa=rfa
        )
        # Forward the caller's temperature so that ``T`` actually controls
        # the initial velocity distribution.
        self.system = self._init_system(
            self.mol, self.forces, nreplicas=nreplicas, T=T
        )

    def _init_mol(self):
        """
        Read the topology and attach the first trajectory frame.

        Returns:
            A moleculekit ``Molecule`` built from ``psfFile``. Its coordinates
            and box come from ``xtcFile``, reduced to frame 0.
        """
        mol = Molecule(self.psfFile)
        mol.read([self.xtcFile])
        # Keep only the first frame; it is used as the starting configuration.
        mol.dropFrames(keep=0)
        return mol

    def _init_system(self, mol, forces, nreplicas, T):
        """
        Create a torchmd System initialised from ``mol``.

        Args:
            mol: molecule providing ``numAtoms``, ``coords`` and ``box``.
            forces: torchmd ``Forces``. Its ``par.masses`` are used to sample
                velocities.
            nreplicas: number of replicas.
            T: temperature for the Maxwell-Boltzmann velocity draw.

        Returns:
            The populated torchmd ``System``.
        """
        system = System(mol.numAtoms, nreplicas, self.dtype, self.device)
        system.set_positions(mol.coords)
        system.set_box(mol.box)
        system.set_velocities(
            maxwell_boltzmann(forces.par.masses, T=T, replicas=nreplicas)
        )
        return system

    def _init_forces(self, mol, *, cutoff=9.0, switch_dist=7.5, rfa=True):
        """
        Build the force field and the torchmd ``Forces`` for ``mol``.

        Args:
            mol: molecule the force field is created for.
            cutoff: nonbonded cutoff passed to ``Forces``. For a periodic box
                it is commonly kept just below half the smallest box edge,
                e.g. ``np.min(mol.box) / 2 - 0.01``.
            switch_dist: switching distance passed to ``Forces``.
            rfa: ``rfa`` flag passed to ``Forces``.

        Returns:
            Tuple ``(forces, ff)``: the ``Forces`` object and the
            ``ForceField`` it was built from.
        """
        prm = parmed.charmm.CharmmParameterSet(*self.prmFiles)
        ff = ForceField.create(mol, prm)
        parameters = Parameters(ff, mol, precision=self.dtype, device=self.device)
        forces = Forces(
            parameters,
            cutoff=cutoff,
            switch_dist=switch_dist,
            rfa=rfa,
        )
        return forces, ff