Make OptimizedTorchANI robust to changes to device between calls. #113
Conversation
Ensure position and species are on the same device in OptimizedTorchANI
I'm not sure this is the right fix. I think the real problem is in OpenMM-Torch, and this just masks the symptoms. TorchForce used to store the name of the file containing the model. When you created a Context (and therefore a TorchForceImpl), it would load the file to create a module. openmm/openmm-torch#97 changed it to make TorchForce directly store the module. I think the correct solution is to make TorchForceImpl clone the model. That will ensure that every Context again has its own independent copy.
I think it will require the torch model code to handle the devices as done by @RaulPPelaez here. How would it work for a simple pure PyTorch model, e.g.:

```python
import torch as pt
from openmm import Context, LocalEnergyMinimizer, Platform, System, VerletIntegrator
from openmmtorch import TorchForce

scale = 1.0e10
platform = "CUDA"
device = "cuda"

class Model(pt.nn.Module):
    def __init__(self, scale, device):
        super().__init__()
        self.device = device
        self.scale = scale
        self.r0 = pt.tensor([0.0, 0.0, 0.0], device=device)

    def forward(self, positions):
        positions = positions.to(self.device)  # <- without this line it will not work
        return self.scale * pt.sum(positions - self.r0)**2

model = pt.jit.script(Model(scale, device))
force = TorchForce(model)

system = System()
system.addForce(force)
for _ in range(2):
    system.addParticle(1)

platform = Platform.getPlatformByName(platform)
context = Context(system, VerletIntegrator(1), platform)
context.setPositions([[0, 0, 0], [1, 0, 0]])

LocalEnergyMinimizer.minimize(context)
```

Typically you have to define the device for some tensors in the constructor, and in the forward method you then expect the positions to be on the same device. If they instead arrive on CPU rather than CUDA, because LocalEnergyMinimizer has detected that the forces are large, the forward method needs code that copies all the tensors onto the same device. Or is there a different way to do this without the explicit copy in forward?
@RaulPPelaez, I think @peastman is right. Each context should have a separate copy of the Torch module, so it can be initialized once on a specific device and never changes. This will ensure the isolation of the contexts, and NNPOps won't need to handle device changes.
@sef43 the module shouldn't have explicit device assignments. Rather, you create parameters and/or buffers so that PyTorch can move them to the right device. OpenMM-Torch already uses that mechanism (https://github.com/openmm/openmm-torch/blob/e9f2ae24f00138740ee6683ea4ccd476c268c183/platforms/cuda/src/CudaTorchKernels.cpp#L78).
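As a minimal sketch of that buffer mechanism (illustrative only, not code from the PR), the toy model above can be written without any hard-coded device: `r0` is registered as a buffer, so moving the module also moves it.

```python
import torch as pt

class Model(pt.nn.Module):
    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale
        # Register r0 as a buffer instead of hard-coding a device:
        # Module.to(device) then moves it together with the rest of the module.
        self.register_buffer("r0", pt.tensor([0.0, 0.0, 0.0]))

    def forward(self, positions):
        # OpenMM-Torch moves the whole module (including r0) to the platform's
        # device, so r0 and the incoming positions normally live on the same device.
        return self.scale * pt.sum(positions - self.r0) ** 2

model = pt.jit.script(Model(1.0e10))
# OpenMM-Torch does the equivalent of this when it sets up the kernel:
# model.to("cuda")
```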
I believe I am missing something about the issue @peastman is describing. Would it not be enough to make TorchForce::getModule return a freshly loaded copy of the module?

```cpp
const torch::jit::Module TorchForce::getModule() const {
    std::stringstream output_stream;
    this->module.save(output_stream);
    return torch::jit::load(output_stream);
}
```

This way TorchForceImpl::initialize gets a just-loaded module each time:

```cpp
void TorchForceImpl::initialize(ContextImpl& context) {
    auto module = owner.getModule();
    // Create the kernel.
    kernel = context.getPlatform().createKernel(CalcTorchForceKernel::Name(), context);
    kernel.getAs<CalcTorchForceKernel>().initialize(context.getSystem(), owner, module);
}
```

As far as I understand, this is equivalent to the behavior of TorchForceImpl before openmm/openmm-torch#97.

EDIT: I made a mistake; the fix above does actually fix the error, and it makes sense to me why.
I opened openmm/openmm-torch#116 with the fix suggested by @peastman. Hence, while this PR does make SymmetryFunctions robust to changing devices, I am not sure it is worth merging.
Solves #112.
AFAIK, the error in #112 comes from OpenMM changing the location of just the positions, so that NNPOps ends up being fed tensors on two different devices further down the line.
The original reproducer by @raimis is solved by simply sending the positions to the same device as the tensor with the atomic numbers (which always stays on the same device). I did this by modifying OptimizedTorchANI:
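The gist of that change, as a rough sketch (the wrapper `DeviceAligningANI` and the `ToyModel` below are made up for illustration; the actual PR modifies OptimizedTorchANI.forward directly):

```python
import torch

class DeviceAligningANI(torch.nn.Module):
    """Hypothetical wrapper showing the idea: move the positions to the device of
    the species tensor, which keeps the device it was created on."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, species: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        positions = positions.to(species.device)  # the core of the fix
        return self.model(species, positions)

class ToyModel(torch.nn.Module):
    """Stand-in for the real ANI model; just sums the positions."""
    def forward(self, species: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        return positions.sum()

model = DeviceAligningANI(ToyModel())
species = torch.tensor([[1, 6]])                  # stays on one device
positions = torch.zeros((1, 2, 3), device="cpu")  # may arrive on a different device
print(model(species, positions))                  # no device-mismatch error
```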
In this PR, I also restructured SymmetryFunctions a bit. Instead of creating one implementation and storing it for the duration of the execution, the class now holds a map from devices to implementations, creating or fetching the necessary one according to where the positions are stored (see the sketch below).
I did this because the positions (and only the positions) suddenly changing device leaves us with an ambiguous decision. In other words, when the check that positions and species are on the same device fails, should OptimizedTorchANI:
1. Move the positions back to the device the model was initialized on, or
2. Follow the positions, requiring every component to handle inputs whose device can change?
In the first case, we simply move the positions back to the original device when required.
In the second case, we must ensure every component can handle inputs whose device changes between calls.
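A rough Python sketch of the per-device map mentioned above (the names `PerDeviceImpl` and `SymmetryFunctionsSketch` are invented for this illustration; the real code is NNPOps' SymmetryFunctions):

```python
import torch
from typing import Dict

class PerDeviceImpl:
    """Hypothetical stand-in for a device-specific symmetry-function implementation."""
    def __init__(self, device: torch.device):
        self.device = device

    def compute(self, positions: torch.Tensor) -> torch.Tensor:
        return positions.sum()  # placeholder computation

class SymmetryFunctionsSketch:
    """Keeps a map from device to implementation and creates or fetches the right
    one based on where the positions currently live."""

    def __init__(self):
        self._impls: Dict[torch.device, PerDeviceImpl] = {}

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        device = positions.device
        if device not in self._impls:   # lazily create one implementation per device
            self._impls[device] = PerDeviceImpl(device)
        return self._impls[device].compute(positions)

sf = SymmetryFunctionsSketch()
print(sf.forward(torch.zeros((2, 3))))  # a CPU call creates and uses a CPU implementation
```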
OTOH, this makes me think: OpenMM suddenly changing the device of the positions without correctly informing NNPOps (perhaps by calling model.to(device)?) sounds to me like a bug in either OpenMM or OpenMM-Torch.
Finally I also applied the formatter.