diff --git a/modules/pmi/examples/atomistic.py b/modules/pmi/examples/atomistic.py index e10c7d6e54..6b102a0296 100644 --- a/modules/pmi/examples/atomistic.py +++ b/modules/pmi/examples/atomistic.py @@ -3,6 +3,7 @@ with a secondary structure elastic network to speed things up. """ +from __future__ import print_function import IMP import RMF import IMP.atom @@ -15,6 +16,10 @@ import sys IMP.setup_from_argv(sys.argv, "Simulation of an atomic system") +if IMP.get_is_quick_test(): + print("This example is too slow to test in debug mode - run without") + print("internal tests enabled, or without the --run-quick-test flag") + sys.exit(0) # Setup System and add a State mdl = IMP.Model() diff --git a/modules/pmi/examples/automatic.py b/modules/pmi/examples/automatic.py index 31b24e2076..d392fb3b4a 100644 --- a/modules/pmi/examples/automatic.py +++ b/modules/pmi/examples/automatic.py @@ -2,6 +2,7 @@ """This script shows how to use the BuildSystem macro to construct large systems with minimal code """ +from __future__ import print_function import IMP import RMF import IMP.atom @@ -16,6 +17,10 @@ import sys IMP.setup_from_argv(sys.argv, "Automatic setup of a large system") +if IMP.get_is_quick_test(): + print("This example is too slow to test in debug mode - run without") + print("internal tests enabled, or without the --run-quick-test flag") + sys.exit(0) # This is the topology table format. # It allows you to create many components in a simple way diff --git a/modules/pmi/examples/multiscale.py b/modules/pmi/examples/multiscale.py index 84d106bf27..6dd27af78b 100644 --- a/modules/pmi/examples/multiscale.py +++ b/modules/pmi/examples/multiscale.py @@ -2,6 +2,7 @@ """This script shows how to represent a system at multiple scales and do basic sampling. """ +from __future__ import print_function import IMP import RMF import IMP.atom @@ -15,6 +16,10 @@ import sys IMP.setup_from_argv(sys.argv, "Representation at multiple scales") +if IMP.get_is_quick_test(): + print("This example is too slow to test in debug mode - run without") + print("internal tests enabled, or without the --run-quick-test flag") + sys.exit(0) ###################### SYSTEM SETUP ##################### # Read sequences etc diff --git a/modules/pmi/pyext/src/restraints/__init__.py b/modules/pmi/pyext/src/restraints/__init__.py index 9889e4dfd4..be016e5306 100644 --- a/modules/pmi/pyext/src/restraints/__init__.py +++ b/modules/pmi/pyext/src/restraints/__init__.py @@ -16,42 +16,71 @@ class RestraintBase(object): """Base class for PMI restraints, which wrap `IMP.Restraint`(s).""" - def __init__(self, m, rname=None): + def __init__(self, m, name=None, label=None, weight=1.): """Constructor. @param m The model object - @param rname The name of the primary restraint set that is wrapped. + @param name The name of the primary restraint set that is wrapped. + This is used for outputs and particle/restraint names + and should be set by the child class. + @param label A unique label to be used in outputs and + particle/restraint names. + @param weight The weight to apply to all internal restraints. """ self.m = m - self.label = None - self.weight = 1. - if rname is None: - rname = self.__class__.__name__ - self.rs = IMP.RestraintSet(self.m, rname) - self.restraint_sets = [self.rs] - - def set_weight(self, weight): - """Set the weight of the restraint. - @param weight Restraint weight - """ + self.restraint_sets = [] + self._label_is_set = False self.weight = weight - self.rs.set_weight(weight) + self._label = None + self._label_suffix = "" + self.set_label(label) + + if not name: + self.name = self.__class__.__name__ + else: + self.name = str(name) + + self.rs = self._create_restraint_set(name=None) def set_label(self, label): - """Set the label used in outputs. + """Set the unique label used in outputs and particle/restraint names. @param label Label """ - self.label = label + if self._label_is_set: + raise ValueError("Label has already been set.") + if not label: + self._label = "" + self._label_suffix = "" + else: + self._label = str(label) + self._label_suffix = "_" + self._label + self._label_is_set = True + + @property + def label(self): + return self._label + + def set_weight(self, weight): + """Set the weight to apply to all internal restraints. + @param weight Weight + """ + self.weight = weight + for rs in self.restraint_sets: + rs.set_weight(self.weight) def add_to_model(self): """Add the restraint to the model.""" - IMP.pmi.tools.add_restraint_to_model(self.m, self.rs) + self._label_is_set = True + for rs in self.restraint_sets: + IMP.pmi.tools.add_restraint_to_model(self.m, rs) def evaluate(self): """Evaluate the score of the restraint.""" + self._label_is_set = True return self.weight * self.rs.unprotected_evaluate(None) def get_restraint_set(self): """Get the primary restraint set.""" + self._label_is_set = True return self.rs def get_restraint(self): @@ -60,39 +89,98 @@ def get_restraint(self): def get_restraint_for_rmf(self): """Get the restraint for visualization in an RMF file.""" + self._label_is_set = True return self.rs def get_particles_to_sample(self): """Get any created particles which should be sampled.""" + self._label_is_set = True return {} def get_output(self): """Get outputs to write to stat files.""" - self.m.update() output = {} + self.m.update() score = self.evaluate() output["_TotalScore"] = str(score) - suffix = "_Score" - if self.label is not None: - suffix += "_" + str(self.label) - + suffix = "_Score" + self._label_suffix for rs in self.restraint_sets: out_name = rs.get_name() + suffix output[out_name] = str( self.weight * rs.unprotected_evaluate(None)) + return output + + def _create_restraint_set(self, name=None): + """Create ``IMP.RestraintSet``.""" + if not name: + name = self.name + else: + name = self.name + "_" + str(name) + rs = IMP.RestraintSet(self.m, name) + rs.set_weight(self.weight) + self.restraint_sets.append(rs) + return rs + + +class _RestraintNuisanceMixin(object): + """Mix-in to add nuisance particle creation functionality to restraint. + + This class must only be inherited if also inheriting + IMP.pmi.restraints.RestraintBase. + """ + + def __init__(self, *args, **kwargs): + super(_RestraintNuisanceMixin, self).__init__(*args, **kwargs) + self.sampled_nuisances = {} + self.nuisances = {} + + def _create_nuisance(self, init_val, min_val, max_val, max_trans, name, + is_sampled=False): + """Create nuisance particle. + @param init_val Initial value of nuisance + @param min_val Minimum value of nuisance + @param max_val Maximum value of nuisance + @param max_trans Maximum move to apply to nuisance + @param is_sampled Nuisance is a sampled particle + \see IMP.pmi.tools.SetupNuisance + """ + nuis = IMP.pmi.tools.SetupNuisance( + self.m, init_val, min_val, max_val, + isoptimized=is_sampled).get_particle() + nuis_name = self.name + "_" + name + nuis.set_name(nuis_name) + self.nuisances[nuis_name] = nuis + if is_sampled: + self.sampled_nuisances[nuis_name] = (nuis, max_trans) + return nuis + + def get_particles_to_sample(self): + """Get any created particles which should be sampled.""" + ps = super(_RestraintNuisanceMixin, self).get_particles_to_sample() + for name, (nuis, max_trans) in self.sampled_nuisances.iteritems(): + ps["Nuisances_" + name + self._label_suffix] = ([nuis], max_trans) + return ps + + def get_output(self): + """Get outputs to write to stat files.""" + output = super(_RestraintNuisanceMixin, self).get_output() + for nuis_name, nuis in self.nuisances.iteritems(): + output[nuis_name + self._label_suffix] = str(nuis.get_scale()) return output class _NuisancesBase(object): - ''' This base class is used to provide nuisance setup and interface - for the ISD cross-link restraints ''' - sigma_dictionary={} - psi_dictionary={} + + """This base class is used to provide nuisance setup and interface + for the ISD cross-link restraints""" + + sigma_dictionary = {} + psi_dictionary = {} def create_length(self): - ''' a nuisance on the length of the cross-link ''' + """Create a nuisance on the length of the cross-link.""" lengthinit = 10.0 self.lengthissampled = True lengthminnuis = 0.0000001 @@ -101,9 +189,9 @@ def create_length(self): lengthmax = 30.0 lengthtrans = 0.2 length = IMP.pmi.tools.SetupNuisance(self.m, lengthinit, - lengthminnuis, - lengthmaxnuis, - lengthissampled).get_particle() + lengthminnuis, lengthmaxnuis, + self.lengthissampled + ).get_particle() self.rslen.add_restraint( IMP.isd.UniformPrior( self.m, @@ -113,8 +201,8 @@ def create_length(self): lengthmin)) def create_sigma(self, resolution): - ''' a nuisance on the structural uncertainty ''' - if isinstance(resolution,str): + """Create a nuisance on the structural uncertainty.""" + if isinstance(resolution, str): sigmainit = 2.0 else: sigmainit = resolution + 2.0 @@ -124,8 +212,9 @@ def create_sigma(self, resolution): sigmamin = 0.01 sigmamax = 100.0 sigmatrans = 0.5 - sigma = IMP.pmi.tools.SetupNuisance(self.m, sigmainit, - sigmaminnuis, sigmamaxnuis, self.sigmaissampled).get_particle() + sigma = IMP.pmi.tools.SetupNuisance(self.m, sigmainit, sigmaminnuis, + sigmamaxnuis, self.sigmaissampled + ).get_particle() self.sigma_dictionary[resolution] = ( sigma, sigmatrans, @@ -140,13 +229,14 @@ def create_sigma(self, resolution): # self.rssig.add_restraint(IMP.isd.JeffreysRestraint(self.sigma)) def get_sigma(self, resolution): - if not resolution in self.sigma_dictionary: + """Get the nuisance on structural uncertainty.""" + if resolution not in self.sigma_dictionary: self.create_sigma(resolution) return self.sigma_dictionary[resolution] def create_psi(self, value): - ''' a nuisance on the inconsistency ''' - if isinstance(value,str): + """Create a nuisance on the inconsistency.""" + if isinstance(value, str): psiinit = 0.5 else: psiinit = value @@ -157,8 +247,8 @@ def create_psi(self, value): psimax = 0.49 psitrans = 0.1 psi = IMP.pmi.tools.SetupNuisance(self.m, psiinit, - psiminnuis, psimaxnuis, - self.psiissampled).get_particle() + psiminnuis, psimaxnuis, + self.psiissampled).get_particle() self.psi_dictionary[value] = ( psi, psitrans, @@ -173,6 +263,7 @@ def create_psi(self, value): self.rspsi.add_restraint(IMP.isd.JeffreysRestraint(self.m, psi)) def get_psi(self, value): - if not value in self.psi_dictionary: + """Get the nuisance on the inconsistency.""" + if value not in self.psi_dictionary: self.create_psi(value) return self.psi_dictionary[value] diff --git a/modules/pmi/pyext/src/restraints/basic.py b/modules/pmi/pyext/src/restraints/basic.py index 7f15a3e1d5..bc6930d13d 100644 --- a/modules/pmi/pyext/src/restraints/basic.py +++ b/modules/pmi/pyext/src/restraints/basic.py @@ -9,8 +9,12 @@ import IMP.atom import IMP.container import IMP.pmi.tools +import IMP.pmi.restraints -class ExternalBarrier(object): + +class ExternalBarrier(IMP.pmi.restraints.RestraintBase): + + """Restraint to keep all structures inside sphere.""" def __init__(self, representation=None, @@ -18,40 +22,46 @@ def __init__(self, hierarchies=None, resolution=10, weight=1.0, - center=None): - """Setup external barrier to keep all your structures inside sphere + center=None, + label=None): + """Setup external barrier restraint. @param representation DEPRECATED - @param center - Center of the external barrier restraint (IMP.algebra.Vector3D object) @param radius Size of external barrier - @param hierarchies Can be one of the following inputs: - IMP Hierarchy, PMI System/State/Molecule/TempResidue, or a list/set of them + @param hierarchies Can be one of the following inputs: IMP Hierarchy, + PMI System/State/Molecule/TempResidue, or a list/set of them @param resolution Select which resolutions to act upon + @param weight Weight of restraint + @param center Center of the external barrier restraint + (IMP.algebra.Vector3D object) + @param label A unique label to be used in outputs and + particle/restraint names. """ - self.radius = radius - self.label = "None" - self.weight = weight - if representation: - self.m = representation.prot.get_model() + m = representation.prot.get_model() particles = IMP.pmi.tools.select( representation, resolution=resolution, hierarchies=hierarchies) elif hierarchies: - hiers = IMP.pmi.tools.input_adaptor(hierarchies,resolution,flatten=True) - self.m = hiers[0].get_model() + hiers = IMP.pmi.tools.input_adaptor(hierarchies, resolution, + flatten=True) + m = hiers[0].get_model() particles = [h.get_particle() for h in hiers] else: - raise Exception("ExternalBarrier: must pass representation or hierarchies") + raise Exception("%s: must pass representation or hierarchies" % ( + self.name)) - self.rs = IMP.RestraintSet(self.m, 'barrier') + super(ExternalBarrier, self).__init__(m, label=label, weight=weight) + self.radius = radius if center is None: c3 = IMP.algebra.Vector3D(0, 0, 0) elif type(center) is IMP.algebra.Vector3D: c3 = center else: - raise Exception("ExternalBarrier: @param center must be an algebra::Vector3D object") + raise Exception( + "%s: @param center must be an IMP.algebra.Vector3D object" % ( + self.name)) ub3 = IMP.core.HarmonicUpperBound(radius, 10.0) ss3 = IMP.core.DistanceToSingletonScore(ub3, c3) @@ -60,35 +70,12 @@ def __init__(self, lsc.add(particles) r3 = IMP.container.SingletonsRestraint(ss3, lsc) self.rs.add_restraint(r3) - self.set_weight(self.weight) - - def set_label(self, label): - self.label = label - - def add_to_model(self): - IMP.pmi.tools.add_restraint_to_model(self.m, self.rs) - - def get_restraint(self): - return self.rs - - def get_output(self): - self.m.update() - output = {} - score = self.evaluate() - output["_TotalScore"] = str(score) - output["ExternalBarrier_" + self.label] = str(score) - return output - def set_weight(self, weight): - self.weight = weight - self.rs.set_weight(weight) - def evaluate(self): - return self.weight * self.rs.unprotected_evaluate(None) +class DistanceRestraint(IMP.pmi.restraints.RestraintBase): - -class DistanceRestraint(object): """A simple distance restraint""" + def __init__(self, representation=None, tuple_selection1=None, @@ -97,27 +84,34 @@ def __init__(self, distancemax=100, resolution=1.0, kappa=1.0, - root_hier = None): + root_hier=None, + label=None, + weight=1.): """Setup distance restraint. @param representation DEPRECATED - @param tuple_selection1 (resnum,resnum,molecule name, copy number (=0)) - @param tuple_selection2 (resnum,resnum,molecule name, copy number (=0)) + @param tuple_selection1 (resnum, resnum, molecule name, copy + number (=0)) + @param tuple_selection2 (resnum, resnum, molecule name, copy + number (=0)) @param distancemin The minimum dist @param distancemax The maximum dist @param resolution For selecting particles @param kappa The harmonic parameter - @param root_hier The hierarchy to select from (use this instead of representation) - \note Pass the same resnum twice to each tuple_selection. Optionally add a copy number (PMI2 only) + @param root_hier The hierarchy to select from (use this instead of + representation) + @param label A unique label to be used in outputs and + particle/restraint names + @param weight Weight of restraint + \note Pass the same resnum twice to each tuple_selection. Optionally + add a copy number (PMI2 only) """ - self.weight=1.0 - self.label="None" if tuple_selection1 is None or tuple_selection2 is None: raise Exception("You must pass tuple_selection1/2") ts1 = IMP.core.HarmonicUpperBound(distancemax, kappa) ts2 = IMP.core.HarmonicLowerBound(distancemin, kappa) if representation and not root_hier: - self.m = representation.prot.get_model() + m = representation.prot.get_model() particles1 = IMP.pmi.tools.select(representation, resolution=resolution, name=tuple_selection1[2], @@ -127,12 +121,12 @@ def __init__(self, name=tuple_selection2[2], residue=tuple_selection2[0]) elif root_hier and not representation: - self.m = root_hier.get_model() + m = root_hier.get_model() copy_num1 = 0 - if len(tuple_selection1)>3: + if len(tuple_selection1) > 3: copy_num1 = tuple_selection1[3] copy_num2 = 0 - if len(tuple_selection2)>3: + if len(tuple_selection2) > 3: copy_num2 = tuple_selection2[3] sel1 = IMP.atom.Selection(root_hier, @@ -150,10 +144,12 @@ def __init__(self, else: raise Exception("Pass representation or root_hier, not both") - self.rs = IMP.RestraintSet(self.m, 'distance') + super(DistanceRestraint, self).__init__(m, label=label, weight=weight) + print(self.name) print("Created distance restraint between " - "%s and %s" % (particles1[0].get_name(),particles2[0].get_name())) + "%s and %s" % (particles1[0].get_name(), + particles2[0].get_name())) if len(particles1) > 1 or len(particles2) > 1: raise ValueError("more than one particle selected") @@ -167,33 +163,6 @@ def __init__(self, particles1[0], particles2[0])) - def set_weight(self,weight): - self.weight = weight - self.rs.set_weight(weight) - - def set_label(self, label): - self.label = label - - def add_to_model(self): - IMP.pmi.tools.add_restraint_to_model(self.m, self.rs) - - def get_restraint(self): - return self.rs - - def get_restraint_for_rmf(self): - return self.rs - - def get_output(self): - self.m.update() - output = {} - score = self.weight * self.rs.unprotected_evaluate(None) - output["_TotalScore"] = str(score) - output["DistanceRestraint_" + self.label] = str(score) - return output - - def evaluate(self): - return self.weight * self.rs.unprotected_evaluate(None) - class TorqueRestraint(IMP.Restraint): import math @@ -364,8 +333,10 @@ def do_get_inputs(self): return self.particle_list -class DistanceToPointRestraint(object): - """DistanceToPointRestraint for anchoring a particle to a specific coordinate""" +class DistanceToPointRestraint(IMP.pmi.restraints.RestraintBase): + + """Restraint for anchoring a particle to a specific coordinate.""" + def __init__(self, representation=None, tuple_selection=None, @@ -374,33 +345,38 @@ def __init__(self, kappa=10.0, resolution=1.0, weight=1.0, - root_hier = None): + root_hier=None, + label=None): """Setup distance restraint. @param representation DEPRECATED - @param tuple_selection (resnum,resnum,molecule name, copy number (=0)) - @param anchor_point - Center of the DistanceToPointRestraint (IMP.algebra.Vector3D object) - @param radius Size of the tolerance length in DistanceToPointRestraint + @param tuple_selection (resnum, resnum, molecule name, + copy number (=0)) + @param anchor_point Point to which to restrain particle + (IMP.algebra.Vector3D object) + @param radius Size of the tolerance length + @param kappa The harmonic parameter @param resolution For selecting a particle - @param root_hier The hierarchy to select from (use this instead of representation) - \note Pass the same resnum twice to each tuple_selection. Optionally add a copy number (PMI2 only) + @param weight Weight of restraint + @param root_hier The hierarchy to select from (use this instead of + representation) + @param label A unique label to be used in outputs and + particle/restraint names + \note Pass the same resnum twice to each tuple_selection. Optionally + add a copy number (PMI2 only) """ - self.radius = radius - self.label = "None" - self.weight = weight - if tuple_selection is None: raise Exception("You must pass a tuple_selection") if representation and not root_hier: - self.m = representation.prot.get_model() + m = representation.prot.get_model() ps = IMP.pmi.tools.select(representation, - resolution=resolution, - name=tuple_selection[2], - residue=tuple_selection[0]) + resolution=resolution, + name=tuple_selection[2], + residue=tuple_selection[0]) elif root_hier and not representation: - self.m = root_hier.get_model() + m = root_hier.get_model() copy_num1 = 0 - if len(tuple_selection)>3: + if len(tuple_selection) > 3: copy_num1 = tuple_selection[3] sel1 = IMP.atom.Selection(root_hier, @@ -410,18 +386,25 @@ def __init__(self, copy_index=copy_num1) ps = sel1.get_selected_particles() else: - raise Exception("DistanceToPointRestraint: Pass representation or root_hier, not both") + raise Exception("%s: Pass representation or root_hier, not both" % + self.name) if len(ps) > 1: - raise ValueError("DistanceToPointRestraint: more than one particle selected") + raise ValueError("%s: more than one particle selected" % + self.name) + + super(DistanceToPointRestraint, self).__init__(m, label=label, + weight=weight) + self.radius = radius - self.rs = IMP.RestraintSet(self.m, 'distance_to_point') ub3 = IMP.core.HarmonicUpperBound(self.radius, kappa) if anchor_point is None: c3 = IMP.algebra.Vector3D(0, 0, 0) elif type(anchor_point) is IMP.algebra.Vector3D: c3 = anchor_point else: - raise Exception("DistanceToPointRestraint: @param anchor_point must be an algebra::Vector3D object") + raise Exception( + "%s: @param anchor_point must be an algebra.Vector3D object" % + self.name) ss3 = IMP.core.DistanceToSingletonScore(ub3, c3) lsc = IMP.container.ListSingletonContainer(self.m) @@ -429,34 +412,6 @@ def __init__(self, r3 = IMP.container.SingletonsRestraint(ss3, lsc) self.rs.add_restraint(r3) - self.set_weight(self.weight) - - print("\nDistanceToPointRestraint: Created distance_to_point_restraint between " - "%s and %s" % (ps[0].get_name(), c3)) - - def set_weight(self,weight): - self.weight = weight - self.rs.set_weight(weight) - - def set_label(self, label): - self.label = label - - def add_to_model(self): - IMP.pmi.tools.add_restraint_to_model(self.m, self.rs) - - def get_restraint(self): - return self.rs - - def get_restraint_for_rmf(self): - return self.rs - - def get_output(self): - self.m.update() - output = {} - score = self.evaluate() - output["_TotalScore"] = str(score) - output["DistanceToPointRestraint_" + self.label] = str(score) - return output - def evaluate(self): - return self.weight * self.rs.unprotected_evaluate(None) + print("\n%s: Created distance_to_point_restraint between " + "%s and %s" % (self.name, ps[0].get_name(), c3)) diff --git a/modules/pmi/pyext/src/restraints/crosslinking.py b/modules/pmi/pyext/src/restraints/crosslinking.py index 5902a5a573..80cd904020 100644 --- a/modules/pmi/pyext/src/restraints/crosslinking.py +++ b/modules/pmi/pyext/src/restraints/crosslinking.py @@ -13,13 +13,14 @@ import IMP.pmi.metadata import IMP.pmi.output import IMP.pmi.io.crosslink +import IMP.pmi.restraints from math import log from collections import defaultdict import itertools import operator import os -class CrossLinkingMassSpectrometryRestraint(object): +class CrossLinkingMassSpectrometryRestraint(IMP.pmi.restraints.RestraintBase): """Setup cross-link distance restraints from mass spectrometry data. The noise in the data and the structural uncertainty of cross-linked amino-acids is inferred using Bayes theory of probability @@ -31,9 +32,10 @@ def __init__(self, representation=None, length=10.0, resolution=None, slope=0.02, - label="None", + label=None, filelabel="None", - attributes_for_label=None): + attributes_for_label=None, + weight=1.): """Constructor. @param representation DEPRECATED The IMP.pmi.representation.Representation object that contain the molecular system @@ -51,6 +53,7 @@ def __init__(self, representation=None, @param filelabel automatically generated file containing missing/included/excluded cross-links will be labeled using this text @param attributes_for_label + @param weight Weight of restraint """ use_pmi2 = True @@ -60,12 +63,15 @@ def __init__(self, representation=None, representations = [representation] else: representations = representation - self.m = representations[0].prot.get_model() + m = representations[0].prot.get_model() elif root_hier is not None: - self.m = root_hier.get_model() + m = root_hier.get_model() else: raise Exception("You must pass either representation or root_hier") + super(CrossLinkingMassSpectrometryRestraint, self).__init__( + m, weight=weight, label=label) + if CrossLinkDataBase is None: raise Exception("You must pass a CrossLinkDataBase") if not isinstance(CrossLinkDataBase,IMP.pmi.io.crosslink.CrossLinkDataBase): @@ -79,18 +85,16 @@ def __init__(self, representation=None, exdb = open("excluded." + filelabel + ".xl.db", "w") midb = open("missing." + filelabel + ".xl.db", "w") - - self.rs = IMP.RestraintSet(self.m, 'likelihood') - self.rspsi = IMP.RestraintSet(self.m, 'prior_psi') - self.rssig = IMP.RestraintSet(self.m, 'prior_sigmas') - self.rslin = IMP.RestraintSet(self.m, 'linear_dummy_restraints') + self.rs.set_name(self.rs.get_name() + "_Data") + self.rspsi = self._create_restraint_set("PriorPsi") + self.rssig = self._create_restraint_set("PriorSig") + self.rslin = self._create_restraint_set("Linear") # dummy linear restraint used for Chimera display self.linear = IMP.core.Linear(0, 0.0) self.linear.set_slope(0.0) dps2 = IMP.core.DistancePairScore(self.linear) - self.label = label self.psi_is_sampled = True self.sigma_is_sampled = True self.psi_dictionary={} @@ -291,13 +295,6 @@ def __init__(self, representation=None, lw = IMP.isd.LogWrapper(restraints,1.0) self.rs.add_restraint(lw) - def add_to_model(self): - """ Add the restraint to the model so that it is evaluated """ - IMP.pmi.tools.add_restraint_to_model(self.m, self.rs) - IMP.pmi.tools.add_restraint_to_model(self.m, self.rspsi) - IMP.pmi.tools.add_restraint_to_model(self.m, self.rssig) - IMP.pmi.tools.add_restraint_to_model(self.m, self.rslin) - def get_hierarchies(self): """ get the hierarchy """ return self.prot @@ -306,14 +303,6 @@ def get_restraint_sets(self): """ get the restraint set """ return self.rs - def get_restraint(self): - """ get the restraint set (redundant with get_restraint_sets)""" - return self.rs - - def get_restraint_for_rmf(self): - """ get the dummy restraints to be displayed in the rmf file """ - return self.rslin - def get_restraints(self): """ get the restraints in a list """ rlist = [] @@ -321,6 +310,10 @@ def get_restraints(self): rlist.append(IMP.core.PairRestraint.get_from(r)) return rlist + def get_restraint_for_rmf(self): + """ get the dummy restraints to be displayed in the rmf file """ + return self.rslin + def get_particle_pairs(self): """ Get a list of tuples containing the particle pairs """ ppairs = [] @@ -401,24 +394,10 @@ def create_psi(self, name): self.rspsi.add_restraint(IMP.isd.JeffreysRestraint(self.m, psi)) return psi - def set_label(self, s): - """ Set the restraint output label """ - self.label=s - def get_output(self): """ Get the output of the restraint to be used by the IMP.pmi.output object""" - self.m.update() + output = super(CrossLinkingMassSpectrometryRestraint, self).get_output() - output = {} - score = self.rs.unprotected_evaluate(None) - output["_TotalScore"] = str(score) - output["CrossLinkingMassSpectrometryRestraint_Data_Score_" + self.label] = str(score) - output["CrossLinkingMassSpectrometryRestraint_PriorSig_Score_" + - self.label] = self.rssig.unprotected_evaluate(None) - output["CrossLinkingMassSpectrometryRestraint_PriorPsi_Score_" + - self.label] = self.rspsi.unprotected_evaluate(None) - output["CrossLinkingMassSpectrometryRestraint_Linear_Score_" + - self.label] = self.rslin.unprotected_evaluate(None) for xl in self.xl_list: xl_label=xl["ShortLabel"] @@ -436,11 +415,13 @@ def get_output(self): for psiname in self.psi_dictionary: output["CrossLinkingMassSpectrometryRestraint_Psi_" + - str(psiname) + "_" + self.label] = str(self.psi_dictionary[psiname][0].get_scale()) + str(psiname) + self._label_suffix] = str( + self.psi_dictionary[psiname][0].get_scale()) for sigmaname in self.sigma_dictionary: output["CrossLinkingMassSpectrometryRestraint_Sigma_" + - str(sigmaname) + "_" + self.label] = str(self.sigma_dictionary[sigmaname][0].get_scale()) + str(sigmaname) + self._label_suffix] = str( + self.sigma_dictionary[sigmaname][0].get_scale()) return output @@ -450,20 +431,21 @@ def get_particles_to_sample(self): ps = {} if self.sigma_is_sampled: for sigmaname in self.sigma_dictionary: - ps["Nuisances_CrossLinkingMassSpectrometryRestraint_Sigma_" + str(sigmaname) + "_" + self.label] =\ + ps["Nuisances_CrossLinkingMassSpectrometryRestraint_Sigma_" + + str(sigmaname) + self._label_suffix] =\ ([self.sigma_dictionary[sigmaname][0]], self.sigma_dictionary[sigmaname][1]) if self.psi_is_sampled: for psiname in self.psi_dictionary: ps["Nuisances_CrossLinkingMassSpectrometryRestraint_Psi_" + - str(psiname) + "_" + self.label] =\ + str(psiname) + self._label_suffix] =\ ([self.psi_dictionary[psiname][0]], self.psi_dictionary[psiname][1]) return ps -class AtomicCrossLinkMSRestraint(object): +class AtomicCrossLinkMSRestraint(IMP.pmi.restraints.RestraintBase): """Setup cross-link distance restraints at atomic level The "atomic" aspect is that it models the particle uncertainty with a Gaussian The noise in the data and the structural uncertainty of cross-linked amino-acids @@ -477,12 +459,13 @@ def __init__(self, length=10.0, slope=0.01, nstates=None, - label='', + label=None, nuisances_are_optimized=True, sigma_init=5.0, psi_init = 0.01, one_psi=True, - filelabel=None): + filelabel=None, + weight=1.): """Constructor. Automatically creates one "sigma" per crosslinked residue and one "psis" per pair. Other nuisance options are available. @@ -500,14 +483,17 @@ def __init__(self, @param one_psi Use a single psi for all restraints (if False, creates one per XL) @param filelabel automatically generated file containing missing/included/excluded cross-links will be labeled using this text + @param weight Weight of restraint + """ # basic params self.root = root_hier - self.mdl = self.root.get_model() + rname = "AtomicXLRestraint" + super(AtomicCrossLinkMSRestraint, self).__init__( + self.root.get_model(), name="AtomicXLRestraint", label=label, + weight=weight) self.xldb = xldb - self.weight = 1.0 - self.label = label self.length = length self.sigma_is_sampled = nuisances_are_optimized self.psi_is_sampled = nuisances_are_optimized @@ -521,11 +507,9 @@ def __init__(self, elif nstates!=len(IMP.atom.get_by_type(self.root,IMP.atom.STATE_TYPE)): print("Warning: nstates is not the same as the number of states in root") - self.rs = IMP.RestraintSet(self.mdl, 'likelihood') - self.rs_psi = IMP.RestraintSet(self.mdl, 'prior_psi') - self.rs_sig = IMP.RestraintSet(self.mdl, 'prior_sigmas') - self.rs_lin = IMP.RestraintSet(self.mdl, 'linear_dummy_restraints') - + self.rs_psi = self._create_restraint_set("psi") + self.rs_sig = self._create_restraint_set("sigma") + self.rs_lin = self._create_restraint_set("linear") self.psi_dictionary = {} self.sigma_dictionary = {} @@ -558,9 +542,9 @@ def __init__(self, # Will setup two sigmas based on promiscuity of the residue sig_threshold=4 - self.sig_low = setup_nuisance(self.mdl,self.rs_nuis,init_val=sigma_init,min_val=1.0, + self.sig_low = setup_nuisance(self.m,self.rs_nuis,init_val=sigma_init,min_val=1.0, max_val=100.0,is_opt=self.nuis_opt) - self.sig_high = setup_nuisance(self.mdl,self.rs_nuis,init_val=sigma_init,min_val=1.0, + self.sig_high = setup_nuisance(self.m,self.rs_nuis,init_val=sigma_init,min_val=1.0, max_val=100.0,is_opt=self.nuis_opt) ''' self._create_sigma('sigma',sigma_init) @@ -578,7 +562,7 @@ def __init__(self, psip = self.psi_dictionary['psi'][0].get_particle_index() else: psip = self.psi_dictionary[unique_id][0].get_particle_index() - r = IMP.isd.AtomicCrossLinkMSRestraint(self.mdl, + r = IMP.isd.AtomicCrossLinkMSRestraint(self.m, self.length, psip, slope, @@ -648,26 +632,15 @@ def __init__(self, if len(xlrs)==0: raise Exception("You didn't create any XL restraints") print('created',len(xlrs),'XL restraints') - self.rs=IMP.isd.LogWrapper(xlrs,self.weight) - - def set_weight(self,weight): - self.weight = weight - self.rs.set_weight(weight) - - def set_label(self, label): - self.label = label - - def add_to_model(self): - IMP.pmi.tools.add_restraint_to_model(self.mdl, self.rs) - IMP.pmi.tools.add_restraint_to_model(self.mdl, self.rs_sig) - IMP.pmi.tools.add_restraint_to_model(self.mdl, self.rs_psi) + rname = self.rs.get_name() + self.rs=IMP.isd.LogWrapper(xlrs, self.weight) + self.rs.set_name(rname) + self.rs.set_weight(self.weight) + self.restraint_sets = [self.rs] + self.restraint_sets[1:] def get_hierarchy(self): return self.prot - def get_restraint_set(self): - return self.rs - def _create_sigma(self, name,sigmainit): """ This is called internally. Creates a nuisance on the structural uncertainty """ @@ -679,7 +652,7 @@ def _create_sigma(self, name,sigmainit): sigmamin = 0.01 sigmamax = 100.0 sigmatrans = 0.5 - sigma = IMP.pmi.tools.SetupNuisance(self.mdl, + sigma = IMP.pmi.tools.SetupNuisance(self.m, sigmainit, sigmaminnuis, sigmamaxnuis, @@ -690,7 +663,7 @@ def _create_sigma(self, name,sigmainit): self.sigma_is_sampled) self.rs_sig.add_restraint( IMP.isd.UniformPrior( - self.mdl, + self.m, sigma, 1000000000.0, sigmamax, @@ -708,7 +681,7 @@ def _create_psi(self, name,psiinit): psimin = 0.01 psimax = 0.49 psitrans = 0.1 - psi = IMP.pmi.tools.SetupNuisance(self.mdl, + psi = IMP.pmi.tools.SetupNuisance(self.m, psiinit, psiminnuis, psimaxnuis, @@ -720,13 +693,13 @@ def _create_psi(self, name,psiinit): self.rs_psi.add_restraint( IMP.isd.UniformPrior( - self.mdl, + self.m, psi, 1000000000.0, psimax, psimin)) - self.rs_psi.add_restraint(IMP.isd.JeffreysRestraint(self.mdl, psi)) + self.rs_psi.add_restraint(IMP.isd.JeffreysRestraint(self.m, psi)) return psi def create_restraints_for_rmf(self): @@ -747,7 +720,7 @@ def get_restraint_for_rmf(self): rs = IMP.RestraintSet(dummy_mdl, 'atomic_xl_'+str(nxl)) for ncontr in range(xl.get_number_of_contributions()): ps=xl.get_contribution(ncontr) - dr = IMP.core.PairRestraint(hps,[self.mdl.get_particle(p) for p in ps], + dr = IMP.core.PairRestraint(hps,[self.m.get_particle(p) for p in ps], 'xl%i_contr%i'%(nxl,ncontr)) rs.add_restraint(dr) dummy_rs.append(MyGetRestraint(rs)) @@ -759,13 +732,14 @@ def get_particles_to_sample(self): ps = {} if self.sigma_is_sampled: for sigmaname in self.sigma_dictionary: - ps["Nuisances_AtomicCrossLinkingMSRestraint_Sigma_" + str(sigmaname) + "_" + self.label] = \ + ps["Nuisances_AtomicCrossLinkingMSRestraint_Sigma_" + + str(sigmaname) + self._label_suffix] = \ ([self.sigma_dictionary[sigmaname][0]], self.sigma_dictionary[sigmaname][1]) if self.psi_is_sampled: for psiname in self.psi_dictionary: ps["Nuisances_CrossLinkingMassSpectrometryRestraint_Psi_" + - str(psiname) + "_" + self.label] =\ + str(psiname) + self._label_suffix] =\ ([self.psi_dictionary[psiname][0]], self.psi_dictionary[psiname][1]) return ps @@ -793,10 +767,10 @@ def load_nuisances_from_stat_file(self,in_fn,nframe): for nxl in range(self.rs.get_number_of_restraints()): xl=IMP.isd.AtomicCrossLinkMSRestraint.get_from(self.rs.get_restraint(nxl)) psip = xl.get_psi() - IMP.isd.Scale(self.mdl,psip).set_scale(psi_val) + IMP.isd.Scale(self.m,psip).set_scale(psi_val) for contr in range(xl.get_number_of_contributions()): sig1,sig2=xl.get_contribution_sigmas(contr) - IMP.isd.Scale(self.mdl,sig1).set_scale(sig_val) + IMP.isd.Scale(self.m,sig1).set_scale(sig_val) print('loaded nuisances from file') @@ -877,13 +851,13 @@ def plot_violations(self,out_prefix, r=0.365; g=0.933; b=0.365; # now only showing if UNIQUELY PASSING in this state pp = state_info[nstate][nxl]["low_pp"] - c1=IMP.core.XYZ(self.mdl,pp[0]).get_coordinates() - c2=IMP.core.XYZ(self.mdl,pp[1]).get_coordinates() + c1=IMP.core.XYZ(self.m,pp[0]).get_coordinates() + c2=IMP.core.XYZ(self.m,pp[1]).get_coordinates() - r1 = IMP.atom.get_residue(IMP.atom.Atom(self.mdl,pp[0])).get_index() - ch1 = IMP.atom.get_chain_id(IMP.atom.Atom(self.mdl,pp[0])) - r2 = IMP.atom.get_residue(IMP.atom.Atom(self.mdl,pp[0])).get_index() - ch2 = IMP.atom.get_chain_id(IMP.atom.Atom(self.mdl,pp[0])) + r1 = IMP.atom.get_residue(IMP.atom.Atom(self.m,pp[0])).get_index() + ch1 = IMP.atom.get_chain_id(IMP.atom.Atom(self.m,pp[0])) + r2 = IMP.atom.get_residue(IMP.atom.Atom(self.m,pp[0])).get_index() + ch2 = IMP.atom.get_chain_id(IMP.atom.Atom(self.m,pp[0])) cmds[nstate].add((ch1,r1)) cmds[nstate].add((ch2,r2)) @@ -912,12 +886,12 @@ def _get_contribution_info(self,xl,ncontr,use_CA=False): idx1=xl.get_contribution(ncontr)[0] idx2=xl.get_contribution(ncontr)[1] if use_CA: - idx1 = IMP.atom.Selection(IMP.atom.get_residue(IMP.atom.Atom(self.mdl,idx1)), + idx1 = IMP.atom.Selection(IMP.atom.get_residue(IMP.atom.Atom(self.m,idx1)), atom_type=IMP.atom.AtomType("CA")).get_selected_particle_indexes()[0] - idx2 = IMP.atom.Selection(IMP.atom.get_residue(IMP.atom.Atom(self.mdl,idx2)), + idx2 = IMP.atom.Selection(IMP.atom.get_residue(IMP.atom.Atom(self.m,idx2)), atom_type=IMP.atom.AtomType("CA")).get_selected_particle_indexes()[0] - dist = IMP.algebra.get_distance(IMP.core.XYZ(self.mdl,idx1).get_coordinates(), - IMP.core.XYZ(self.mdl,idx2).get_coordinates()) + dist = IMP.algebra.get_distance(IMP.core.XYZ(self.m,idx1).get_coordinates(), + IMP.core.XYZ(self.m,idx2).get_coordinates()) return idx1,idx2,dist def get_best_stats(self,limit_to_state=None,limit_to_chains=None,exclude_chains='',use_CA=False): @@ -941,21 +915,21 @@ def get_best_stats(self,limit_to_state=None,limit_to_chains=None,exclude_chains= for contr in range(xl.get_number_of_contributions()): pp = xl.get_contribution(contr) if use_CA: - idx1 = IMP.atom.Selection(IMP.atom.get_residue(IMP.atom.Atom(self.mdl,pp[0])), + idx1 = IMP.atom.Selection(IMP.atom.get_residue(IMP.atom.Atom(self.m,pp[0])), atom_type=IMP.atom.AtomType("CA")).get_selected_particle_indexes()[0] - idx2 = IMP.atom.Selection(IMP.atom.get_residue(IMP.atom.Atom(self.mdl,pp[1])), + idx2 = IMP.atom.Selection(IMP.atom.get_residue(IMP.atom.Atom(self.m,pp[1])), atom_type=IMP.atom.AtomType("CA")).get_selected_particle_indexes()[0] pp = [idx1,idx2] if limit_to_state is not None: - nstate = IMP.atom.get_state_index(IMP.atom.Atom(self.mdl,pp[0])) + nstate = IMP.atom.get_state_index(IMP.atom.Atom(self.m,pp[0])) if nstate!=limit_to_state: continue state_contrs.append(contr) - dist = IMP.core.get_distance(IMP.core.XYZ(self.mdl,pp[0]), - IMP.core.XYZ(self.mdl,pp[1])) + dist = IMP.core.get_distance(IMP.core.XYZ(self.m,pp[0]), + IMP.core.XYZ(self.m,pp[1])) if limit_to_chains is not None: - c1 = IMP.atom.get_chain_id(IMP.atom.Atom(self.mdl,pp[0])) - c2 = IMP.atom.get_chain_id(IMP.atom.Atom(self.mdl,pp[1])) + c1 = IMP.atom.get_chain_id(IMP.atom.Atom(self.m,pp[0])) + c2 = IMP.atom.get_chain_id(IMP.atom.Atom(self.m,pp[1])) if (c1 in limit_to_chains or c2 in limit_to_chains) and ( c1 not in exclude_chains and c2 not in exclude_chains): if dist