
Support parameter derivatives #142

Closed · wants to merge 32 commits

@RaulPPelaez (Contributor) commented May 15, 2024

This PR allows getting energy derivatives with respect to global parameters in TorchForce, as in the following example:

import numpy as np
import openmm as mm
import openmmtorch as ot
import torch as pt
from torch import Tensor


class ForceWithParameters(pt.nn.Module):

    def __init__(self):
        super(ForceWithParameters, self).__init__()

    def forward(
        self, positions: Tensor, parameter1: Tensor, parameter2: Tensor
    ) -> Tensor:
        x2 = positions.pow(2).sum(dim=1)
        u_harmonic = ((parameter1 + parameter2**2) * x2).sum()
        return u_harmonic


def example():
    # Create a random cloud of particles.
    numParticles = 10
    system = mm.System()
    positions = np.random.rand(numParticles, 3)
    for _ in range(numParticles):
        system.addParticle(1.0)

    # Create a force
    pt_force = ForceWithParameters()
    model = pt.jit.script(pt_force)
    tforce = ot.TorchForce(model)
    tforce.setOutputsForces(False)
    # Add the global parameters and enable energy derivatives for them
    parameter1 = 1.0
    parameter2 = 1.0
    tforce.addGlobalParameter("parameter1", parameter1)
    tforce.addEnergyParameterDerivative("parameter1")
    tforce.addGlobalParameter("parameter2", parameter2)
    tforce.addEnergyParameterDerivative("parameter2")
    system.addForce(tforce)
    # Compute the energy and parameter derivatives.
    integ = mm.VerletIntegrator(1.0)
    platform = mm.Platform.getPlatformByName("Reference")
    context = mm.Context(system, integ, platform)
    context.setPositions(positions)
    state = context.getState(
        getEnergy=True, getForces=True, getParameterDerivatives=True
    )
    # Check that the energy and the parameter derivatives are correct.
    # The network defines a potential of the form E(r) = (parameter1 + parameter2**2)*|r|^2,
    # so dE/dparameter1 = |r|^2 and dE/dparameter2 = 2*parameter2*|r|^2.
    r2 = np.sum(positions * positions)
    expectedEnergy = (parameter1 + parameter2**2) * r2
    assert np.allclose(
        expectedEnergy,
        state.getPotentialEnergy().value_in_unit(mm.unit.kilojoules_per_mole),
    )
    assert np.allclose(
        r2,
        state.getEnergyParameterDerivatives()["parameter1"],
    )
    assert np.allclose(
        2 * parameter2 * r2,
        state.getEnergyParameterDerivatives()["parameter2"],
    )
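The analytic derivatives asserted above can be sanity-checked without OpenMM at all. A minimal sketch, assuming only the potential E = (parameter1 + parameter2**2)*|r|^2 defined by ForceWithParameters, compares them against central finite differences:

```python
import numpy as np

# Standalone check of the analytic parameter derivatives used in the test,
# assuming E(p1, p2) = (p1 + p2**2) * r2 as in ForceWithParameters.
def energy(p1, p2, r2):
    return (p1 + p2**2) * r2

r2 = 3.7
p1, p2, h = 1.0, 1.0, 1e-6
dE_dp1 = (energy(p1 + h, p2, r2) - energy(p1 - h, p2, r2)) / (2 * h)
dE_dp2 = (energy(p1, p2 + h, r2) - energy(p1, p2 - h, r2)) / (2 * h)
assert np.isclose(dE_dp1, r2)           # matches dE/dparameter1 = |r|^2
assert np.isclose(dE_dp2, 2 * p2 * r2)  # matches dE/dparameter2 = 2*p2*|r|^2
```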

Closes #140
Closes #141

@RaulPPelaez commented May 15, 2024

@peastman, I do not know how to connect CustomCVForce with the parameter derivative functionality in TorchForce.
I am writing this in the new test:

    tforce = ot.TorchForce(model, {"useCUDAGraphs": "false"})
    # Add a parameter
    parameter1 = 1.0
    parameter2 = 1.0
    tforce.setOutputsForces(return_forces)
    tforce.addGlobalParameter("parameter1", parameter1)
    tforce.addGlobalParameter("parameter2", parameter2)
    # Enable energy derivatives for the parameters
    tforce.addEnergyParameterDerivative("parameter1")
    tforce.addEnergyParameterDerivative("parameter2")
    if use_cv_force:
        # Wrap TorchForce into CustomCVForce
        force = mm.CustomCVForce("force")
        force.addCollectiveVariable("force", tforce)
    else:
        force = tforce

For use_cv_force=True I get:

>       assert np.allclose(
            r2,
            state.getEnergyParameterDerivatives()["parameter1"],
        )

TestParameterDerivatives.py:116:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <openmm.openmm.mapstringdouble; proxy of <Swig Object of type 'std::map< std::string,double > *' at 0x7f5dc4807ea0> >, key = 'parameter1'

    def __getitem__(self, key):
>       return _openmm.mapstringdouble___getitem__(self, key)
E       IndexError: key not found

The CUDA and OpenCL platforms raise the error above, while the Reference platform silently returns 0, which is incorrect.

@RaulPPelaez commented
There is a corner case I do not know how to handle.
The model may compute and return the forces itself by calling autograd on the energy.
If it does this by calling grad directly (instead of energy.backward()) with retain_graph=False, torch will not allow differentiating the energy a second time.
In other words, a model like this is inherently incompatible with parameter derivatives:

from typing import List, Optional, Tuple

import torch as pt
from torch import Tensor


class EnergyForceWithParameters(pt.nn.Module):

    def __init__(self):
        super(EnergyForceWithParameters, self).__init__()

    def forward(
        self, positions: Tensor, parameter1: Tensor, parameter2: Tensor
    ) -> Tuple[Tensor, Tensor]:
        positions.requires_grad_(True)
        x2 = positions.pow(2).sum(dim=1)
        u_harmonic = ((parameter1 + parameter2**2) * x2).sum()
        # Computing the forces this way leaves the parameter derivatives out of reach
        grad_outputs: List[Optional[Tensor]] = [pt.ones_like(u_harmonic)]
        dy = pt.autograd.grad(
            [u_harmonic],
            [positions],
            grad_outputs=grad_outputs,
            create_graph=False,
            retain_graph=False,  # This cannot be False if parameter derivatives are needed
        )[0]
        assert dy is not None
        forces = -dy
        return u_harmonic, forces

TorchMD-Net does exactly this. Not sure what to do about it.
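One possible way out, sketched here in plain PyTorch (not tied to TorchForce), is to pass retain_graph=True when computing the forces: the graph stays alive, so the same energy can be differentiated again with respect to the parameters afterwards.

```python
import torch as pt

# Sketch only: same toy potential as above, but with retain_graph=True so a
# second autograd.grad call can still recover the parameter derivatives.
positions = pt.rand(10, 3, requires_grad=True)
parameter1 = pt.tensor(1.0, requires_grad=True)
parameter2 = pt.tensor(1.0, requires_grad=True)
x2 = positions.pow(2).sum(dim=1)
u_harmonic = ((parameter1 + parameter2**2) * x2).sum()
# Forces first; retaining the graph keeps it usable for a second pass.
forces = -pt.autograd.grad(u_harmonic, positions, retain_graph=True)[0]
# Second pass: derivatives of the same energy w.r.t. the global parameters.
dp1, dp2 = pt.autograd.grad(u_harmonic, [parameter1, parameter2])
assert pt.allclose(dp1, x2.sum().detach())        # dE/dp1 = sum(|r|^2)
assert pt.allclose(dp2, 2.0 * x2.sum().detach())  # dE/dp2 = 2*p2*sum(|r|^2), p2 = 1
```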

// The derivative is stored in the gradient of the parameter tensor.
double derivative = gradientTensors[i].item<double>();
auto name = energyParameterDerivatives[i];
derivs[name] = derivative;
@RaulPPelaez commented

Should I be summing here instead of overwriting?
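For illustration (toy Python, hypothetical data, not the actual C++ in the diff): if two entries in the loop carried the same parameter name, plain assignment would keep only the last contribution, whereas accumulation keeps both.

```python
# Two hypothetical contributions to the derivative of the same parameter.
contributions = [("parameter1", 2.0), ("parameter1", 3.0)]

# Overwriting, as in `derivs[name] = derivative`: last write wins.
overwrite = {}
for name, d in contributions:
    overwrite[name] = d

# Summing: every contribution is accumulated.
accumulate = {}
for name, d in contributions:
    accumulate[name] = accumulate.get(name, 0.0) + d

assert overwrite["parameter1"] == 3.0
assert accumulate["parameter1"] == 5.0
```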

@peastman (Member) commented
It looks like we both started working on this. I'm almost finished with implementing it.

@RaulPPelaez commented
welp... Take what you want from here if you find anything useful.

@peastman commented
Really sorry for the confusion! My implementation is at #141. If you can review it, that would be great.

@peastman commented
Sorry, that should have been #143.

Successfully merging this pull request may close these issues:

- Support parameter derivatives
- How to use global parameter and PBC at the same time?