Skip to content

Commit

Permalink
Refactor Transformations from closure into class (#2859)
Browse files Browse the repository at this point in the history
* Fixes #2860
* refactor transformations from closure into class (necessary so that a Universe
   with transformations can be pickled).
* change universe pickling to `__reduce__` (instead of `__setstate__`/`__getstate__`, which
   did not work with transformations)
* add test for pickling a Universe with transformations
* update docs for transformations
   - examples (written as class)
   - deprecate transformations as closures (still works but cannot be pickled due
      to limitations in Python's pickle module)
* update changelog
  • Loading branch information
yuxuanzhuang authored Sep 8, 2020
1 parent b2b7bcb commit 2c5e385
Show file tree
Hide file tree
Showing 10 changed files with 590 additions and 256 deletions.
4 changes: 3 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ Changes
* Sets the minimal RDKit version for CI to 2020.03.1 (Issue #2827, PR #2831)
* Removes deprecated waterdynamics.HydrogenBondLifetimes (PR #2842)
* Make NeighborSearch return empty atomgroup, residue, segments instead of list (Issue #2892, PR #2907)
* Updated Universe creation function signatures to named arguments (Issue #2921)
* Updated Universe creation function signatures to named arguments (Issue #2921)
* The transformation was changed from a function/closure to a class with
`__call__` (Issue #2860, PR #2859)

Deprecations

Expand Down
20 changes: 13 additions & 7 deletions package/MDAnalysis/core/universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,15 +666,21 @@ def __repr__(self):
return "<Universe with {n_atoms} atoms>".format(
n_atoms=len(self.atoms))

def __getstate__(self):
# Universe's two "legs" of topology and traj both serialise themselves
return self._topology, self._trajectory
@classmethod
def _unpickle_U(cls, top, traj):
"""Special method used by __reduce__ to deserialise a Universe"""
# top is a Topology obj at this point, but Universe can handle that.
u = cls(top)
u.trajectory = traj

def __setstate__(self, args):
self._topology = args[0]
_generate_from_topology(self)
return u

self._trajectory = args[1]
def __reduce__(self):
# __setstate__/__getstate__ will raise an error when Universe has a
# transformation (that has AtomGroup inside). Use __reduce__ instead.
# Universe's two "legs" of top and traj both serialise themselves.
return (self._unpickle_U, (self._topology,
self._trajectory))

# Properties
@property
Expand Down
101 changes: 85 additions & 16 deletions package/MDAnalysis/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,107 @@
#


"""\
"""
Trajectory transformations --- :mod:`MDAnalysis.transformations`
================================================================
The transformations submodule contains a collection of functions to modify the
trajectory. Coordinate transformations, such as PBC corrections and molecule fitting
are often required for some analyses and visualization, and the functions in this
module allow transformations to be applied on-the-fly.
These transformation functions can be called by the user for any given
timestep of the trajectory, added as a workflow using :meth:`add_transformations`
of the :mod:`~MDAnalysis.coordinates.base` module, or upon Universe creation using
The transformations submodule contains a collection of function-like classes to
modify the trajectory.
Coordinate transformations, such as PBC corrections and molecule fitting
are often required for some analyses and visualization, and the functions in
this module allow transformations to be applied on-the-fly.
A typical transformation class looks like this (note that we keep its name
lowercase because we will treat it as a function, thanks to the ``__call__``
method):
.. code-blocks:: python
class transformation(object):
def __init__(self, *args, **kwargs):
# do some things
# save needed args as attributes.
self.needed_var = args[0]
def __call__(self, ts):
# apply changes to the Timestep,
# or modify an AtomGroup and return Timestep
return ts
As a concrete example we will write a transformation that rotates a group of
atoms around the z-axis through the center of geometry by a fixed increment
for every time step. We will use
:meth:`MDAnalysis.core.groups.AtomGroup.rotateby`
and simply increment the rotation angle every time the
transformation is called ::
class spin_atoms(object):
def __init__(self, atoms, dphi):
# Rotate atoms by dphi degrees for every ts around the z axis
self.atoms = atoms
self.dphi = dphi
self.axis = np.array([0, 0, 1])
def __call__(self, ts):
phi = self.dphi * ts.frame
self.atoms.rotateby(phi, self.axis)
return ts
This transformation can be used as ::
u = mda.Universe(PSF, DCD)
u.trajectory.add_transformations(spin_atoms(u.select_atoms("protein"), 1.0))
Also see :mod:`MDAnalysis.transformations.translate` for a simple example.
These transformation functions can be called by the user for any given timestep
of the trajectory, added as a workflow using :meth:`add_transformations`
of the :mod:`~MDAnalysis.coordinates.base`, or upon Universe creation using
the keyword argument `transformations`. Note that in the two latter cases, the
workflow cannot be changed after being defined.
workflow cannot be changed after being defined. for example:
.. code-block:: python
In addition to the specific arguments that each transformation can take, they also
contain a wrapped function that takes a `Timestep` object as argument.
So, a transformation can be roughly defined as follows:
u = mda.Universe(GRO, XTC)
trans = transformation(args)
u.trajectory.add_transformations(trans)
# it is equivalent to applying this transforamtion to each Timestep by
ts = u.trajectory[0]
ts_trans = trans(ts)
Transformations can also be created as a closure/nested function.
In addition to the specific arguments that each transformation can take, they
also contain a wrapped function that takes a `Timestep` object as argument.
So, a closure-style transformation can be roughly defined as follows:
.. code-block:: python
def transformations(*args,**kwargs):
def transformation(*args,**kwargs):
# do some things
def wrapped(ts):
# apply changes to the Timestep object
# apply changes to the Timestep,
# or modify an AtomGroup and return Timestep
return ts
return wrapped
See `MDAnalysis.transformations.translate` for a simple example.
.. Note::
Although functions (closures) work as transformations, they are not used in
in MDAnalysis from release 2.0.0 onwards because they cannot be reliably
serialized and thus a :class:`Universe` with such transformations cannot be
used with common parallelization schemes (e.g., ones based on
:mod:`multiprocessing`).
For detailed descriptions about how to write a closure-style transformation,
please refer to MDAnalysis 1.x documentation.
.. versionchanged:: 2.0.0
Transformations should now be created as classes with a :meth:`__call__`
method instead of being written as a function/closure.
"""

from .translate import translate, center_in_box
Expand Down
171 changes: 102 additions & 69 deletions package/MDAnalysis/transformations/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,27 @@
Translate and/or rotates the coordinates of a given trajectory to align
a given AtomGroup to a reference structure.
.. autofunction:: fit_translation
.. autoclass:: fit_translation
.. autofunction:: fit_rot_trans
.. autoclass:: fit_rot_trans
"""
import numpy as np
from functools import partial

from ..analysis import align
from ..lib.util import get_weights
from ..lib.transformations import euler_from_matrix, euler_matrix


def fit_translation(ag, reference, plane=None, weights=None):

class fit_translation(object):
"""Translates a given AtomGroup so that its center of geometry/mass matches
the respective center of the given reference. A plane can be given by the
user using the option `plane`, and will result in the removal of
the translation motions of the AtomGroup over that particular plane.
Example
-------
Removing the translations of a given AtomGroup `ag` on the XY plane by fitting
its center of mass to the center of mass of a reference `ref`:
Removing the translations of a given AtomGroup `ag` on the XY plane by
fitting its center of mass to the center of mass of a reference `ref`:
.. code-block:: python
Expand All @@ -67,11 +64,12 @@ def fit_translation(ag, reference, plane=None, weights=None):
:class:`~MDAnalysis.core.groups.AtomGroup` or a whole
:class:`~MDAnalysis.core.universe.Universe`
reference : Universe or AtomGroup
reference structure, a :class:`~MDAnalysis.core.groups.AtomGroup` or a whole
:class:`~MDAnalysis.core.universe.Universe`
reference structure, a :class:`~MDAnalysis.core.groups.AtomGroup` or a
whole :class:`~MDAnalysis.core.universe.Universe`
plane: str, optional
used to define the plane on which the translations will be removed. Defined as a
string of the plane. Suported planes are yz, xz and xy planes.
used to define the plane on which the translations will be removed.
Defined as a string of the plane.
Suported planes are yz, xz and xy planes.
weights : {"mass", ``None``} or array_like, optional
choose weights. With ``"mass"`` uses masses as weights; with ``None``
weigh each atom equally. If a float array of the same length as
Expand All @@ -81,39 +79,56 @@ def fit_translation(ag, reference, plane=None, weights=None):
Returns
-------
MDAnalysis.coordinates.base.Timestep
"""
if plane is not None:
axes = {'yz' : 0, 'xz' : 1, 'xy' : 2}
.. versionchanged:: 2.0.0
The transformation was changed from a function/closure to a class
with ``__call__``.
"""
def __init__(self, ag, reference, plane=None, weights=None):
self.ag = ag
self.reference = reference
self.plane = plane
self.weights = weights

if self.plane is not None:
axes = {'yz': 0, 'xz': 1, 'xy': 2}
try:
self.plane = axes[self.plane]
except (TypeError, KeyError):
raise ValueError(f'{self.plane} is not a valid plane') \
from None
try:
plane = axes[plane]
except (TypeError, KeyError):
raise ValueError(f'{plane} is not a valid plane') from None
try:
if ag.atoms.n_residues != reference.atoms.n_residues:
errmsg = f"{ag} and {reference} have mismatched number of residues"
raise ValueError(errmsg)
except AttributeError:
errmsg = f"{ag} or {reference} is not valid Universe/AtomGroup"
raise AttributeError(errmsg) from None
ref, mobile = align.get_matching_atoms(reference.atoms, ag.atoms)
weights = align.get_weights(ref.atoms, weights=weights)
ref_com = ref.center(weights)
ref_coordinates = ref.atoms.positions - ref_com

def wrapped(ts):
mobile_com = np.asarray(mobile.atoms.center(weights), np.float32)
vector = ref_com - mobile_com
if plane is not None:
vector[plane] = 0
if self.ag.atoms.n_residues != self.reference.atoms.n_residues:
errmsg = (
f"{self.ag} and {self.reference} have mismatched"
f"number of residues"
)

raise ValueError(errmsg)
except AttributeError:
errmsg = (
f"{self.ag} or {self.reference} is not valid"
f"Universe/AtomGroup"
)
raise AttributeError(errmsg) from None
self.ref, self.mobile = align.get_matching_atoms(self.reference.atoms,
self.ag.atoms)
self.weights = align.get_weights(self.ref.atoms, weights=self.weights)
self.ref_com = self.ref.center(self.weights)

def __call__(self, ts):
mobile_com = np.asarray(self.mobile.atoms.center(self.weights),
np.float32)
vector = self.ref_com - mobile_com
if self.plane is not None:
vector[self.plane] = 0
ts.positions += vector

return ts

return wrapped


def fit_rot_trans(ag, reference, plane=None, weights=None):
class fit_rot_trans(object):
"""Perform a spatial superposition by minimizing the RMSD.
Spatially align the group of atoms `ag` to `reference` by doing a RMSD
Expand Down Expand Up @@ -160,41 +175,59 @@ def fit_rot_trans(ag, reference, plane=None, weights=None):
-------
MDAnalysis.coordinates.base.Timestep
"""
if plane is not None:
axes = {'yz' : 0, 'xz' : 1, 'xy' : 2}
def __init__(self, ag, reference, plane=None, weights=None):
self.ag = ag
self.reference = reference
self.plane = plane
self.weights = weights

if self.plane is not None:
axes = {'yz': 0, 'xz': 1, 'xy': 2}
try:
self.plane = axes[self.plane]
except (TypeError, KeyError):
raise ValueError(f'{self.plane} is not a valid plane') \
from None
try:
plane = axes[plane]
except (TypeError, KeyError):
raise ValueError(f'{plane} is not a valid plane') from None
try:
if ag.atoms.n_residues != reference.atoms.n_residues:
errmsg = f"{ag} and {reference} have mismatched number of residues"
raise ValueError(errmsg)
except AttributeError:
errmsg = f"{ag} or {reference} is not valid Universe/AtomGroup"
raise AttributeError(errmsg) from None
ref, mobile = align.get_matching_atoms(reference.atoms, ag.atoms)
weights = align.get_weights(ref.atoms, weights=weights)
ref_com = ref.center(weights)
ref_coordinates = ref.atoms.positions - ref_com

def wrapped(ts):
mobile_com = mobile.atoms.center(weights)
mobile_coordinates = mobile.atoms.positions - mobile_com
rotation, dump = align.rotation_matrix(mobile_coordinates, ref_coordinates, weights=weights)
vector = ref_com
if plane is not None:
matrix = np.r_[rotation, np.zeros(3).reshape(1,3)]
if self.ag.atoms.n_residues != self.reference.atoms.n_residues:
errmsg = (
f"{self.ag} and {self.reference} have mismatched "
f"number of residues"
)
raise ValueError(errmsg)
except AttributeError:
errmsg = (
f"{self.ag} or {self.reference} is not valid "
f"Universe/AtomGroup"
)
raise AttributeError(errmsg) from None
self.ref, self.mobile = align.get_matching_atoms(self.reference.atoms,
self.ag.atoms)
self.weights = align.get_weights(self.ref.atoms, weights=self.weights)
self.ref_com = self.ref.center(self.weights)
self.ref_coordinates = self.ref.atoms.positions - self.ref_com

def __call__(self, ts):
mobile_com = self.mobile.atoms.center(self.weights)
mobile_coordinates = self.mobile.atoms.positions - mobile_com
rotation, dump = align.rotation_matrix(mobile_coordinates,
self.ref_coordinates,
weights=self.weights)
vector = self.ref_com
if self.plane is not None:
matrix = np.r_[rotation, np.zeros(3).reshape(1, 3)]
matrix = np.c_[matrix, np.zeros(4)]
euler_angs = np.asarray(euler_from_matrix(matrix, axes='sxyz'), np.float32)
euler_angs = np.asarray(euler_from_matrix(matrix, axes='sxyz'),
np.float32)
for i in range(0, euler_angs.size):
euler_angs[i] = ( euler_angs[plane] if i == plane else 0)
rotation = euler_matrix(euler_angs[0], euler_angs[1], euler_angs[2], axes='sxyz')[:3, :3]
vector[plane] = mobile_com[plane]
euler_angs[i] = (euler_angs[self.plane] if i == self.plane
else 0)
rotation = euler_matrix(euler_angs[0],
euler_angs[1],
euler_angs[2],
axes='sxyz')[:3, :3]
vector[self.plane] = mobile_com[self.plane]
ts.positions = ts.positions - mobile_com
ts.positions = np.dot(ts.positions, rotation.T)
ts.positions = ts.positions + vector

return ts

return wrapped
Loading

0 comments on commit 2c5e385

Please sign in to comment.