diff --git a/contact_map/contact_map.py b/contact_map/contact_map.py index 98cccff..1e39c28 100644 --- a/contact_map/contact_map.py +++ b/contact_map/contact_map.py @@ -13,12 +13,78 @@ import numpy as np import pandas as pd import mdtraj as md +from mdtraj.core.topology import Residue from .contact_count import ContactCount from .atom_indexer import AtomSlicedIndexer, IdentityIndexer from .py_2_3 import inspect_method_arguments from .fix_parameters import ParameterFixer +try: + from cython import compiled +except ImportError: + compiled = False + +if compiled: + from cython import ( + Py_hash_t, + Py_ssize_t, + bint, + cast, + ccall, + cclass, + cfunc, + declare, + double, + exceptval, + final, + inline, + nogil, + ) +else: + from ctypes import c_double as double + from ctypes import c_ssize_t as Py_hash_t + from ctypes import c_ssize_t as Py_ssize_t + + bint = bool + + def cast(T, v, *a, **k): + return v + + def ccall(func): + return func + + def cclass(cls): + return cls + + def cfunc(func): + return func + + def declare(*a, **k): + if len(a) == 2: + return a[1] + else: + pass + + def exceptval(*a, **k): + def wrapper(func): + return func + + return wrapper + + def final(cls): + return cls + + def inline(func): + return func + + def nogil(func): + return func + +COMPILED = declare(bint, compiled) +globals()["COMPILED"] = COMPILED + + # TODO: # * switch to something where you can define the haystack -- the trick is to # replace the current mdtraj._compute_neighbors with something that @@ -59,18 +125,19 @@ def residue_neighborhood(residue, n=1): return [idx for idx in neighborhood if idx in chain] -def _residue_for_atom(topology, atom_list): +def _residue_for_atom(topology, atom_list) -> set: return set([topology.atom(a).residue for a in atom_list]) -def _residue_idx_for_atom(topology, atom_list): +def _residue_idx_for_atom(topology, atom_list) -> set: return set([topology.atom(a).residue.index for a in atom_list]) -def _range_from_iterable(iterable): +def _range_from_iterable(iterable) -> tuple: sort = sorted(iterable) return (sort[0], sort[-1]+1) +deserialize_atom_to_residue_dct = lambda d: {int(k): d[k] for k in d} class ContactsDict(object): """Dict-like object giving access to atom or residue contacts. @@ -104,7 +171,7 @@ def __getitem__(self, atom_or_res): str(atom_or_res)) return contacts - +@cclass class ContactObject(object): """ Generic object for contact map related analysis. Effectively abstract. @@ -112,11 +179,29 @@ class ContactObject(object): Much of what we need to do the contact map analysis is the same for all analyses. It's in here. 
""" - # Class default for use atom slice, None tries to be smart _class_use_atom_slice = None - def __init__(self, topology, query, haystack, cutoff, n_neighbors_ignored): + # cython typing stuff needed for dask integration + # (visibility=public is needed for from_dict()) + _topology = declare(object, visibility='public') + _cutoff = declare(double, visibility='public') + _query = declare(set, visibility='public') + _haystack = declare(set, visibility='public') + _query_res_idx = declare(set, visibility='public') + _haystack_res_idx = declare(set, visibility='public') + _all_atoms = declare(tuple, visibility='public') + _all_residues = declare(set, visibility='public') + _use_atom_slice = declare(bint, visibility='public') + indexer = declare(object, visibility="public") + _n_neighbors_ignored = declare(Py_ssize_t, visibility='public') + + def __init__(self, + topology: md.Topology, + query: object, + haystack: object, + cutoff: double, + n_neighbors_ignored: Py_ssize_t): # all inits required: no defaults for abstract class! self._topology = topology @@ -140,18 +225,24 @@ def __init__(self, topology, query, haystack, cutoff, n_neighbors_ignored): all_atoms_set) self._use_atom_slice = self._set_atom_slice(self._all_atoms) has_indexer = getattr(self, 'indexer', None) is not None + if not has_indexer: Indexer = {True: AtomSlicedIndexer, False: IdentityIndexer}[self.use_atom_slice] self.indexer = Indexer(topology, self._query, self._haystack, self._all_atoms) - self._n_neighbors_ignored = n_neighbors_ignored @classmethod - def from_contacts(cls, atom_contacts, residue_contacts, topology, - query=None, haystack=None, cutoff=0.45, - n_neighbors_ignored=2, indexer=None): + def from_contacts(cls, + atom_contacts: collections.Counter, + residue_contacts: collections.Container, + topology: md.Topology, + query: object = None, + haystack: object = None, + cutoff: double = 0.45, + n_neighbors_ignored: Py_ssize_t = 2, + indexer: object = None) -> object: obj = cls.__new__(cls) obj.indexer = indexer super(cls, obj).__init__(topology, query, haystack, cutoff, @@ -167,7 +258,7 @@ def get_contact_counter(contact): obj._residue_contacts = get_contact_counter(residue_contacts) return obj - def _set_atom_slice(self, all_atoms): + def _set_atom_slice(self, all_atoms: tuple) -> bint: """ Set atom slice logic """ if (self._class_use_atom_slice is None and not len(all_atoms) < self._topology.n_atoms): @@ -181,16 +272,16 @@ def _set_atom_slice(self, all_atoms): return self._class_use_atom_slice @property - def contacts(self): + def contacts(self) -> ContactsDict: """:class:`.ContactsDict` : contact dict for these contacts""" return ContactsDict(self) - def __hash__(self): + def __hash__(self) -> Py_hash_t: return hash((self.cutoff, self.n_neighbors_ignored, frozenset(self._query), frozenset(self._haystack), self.topology)) - def __eq__(self, other): + def __eq__(self, other: object) -> bint: is_equal = (self.cutoff == other.cutoff and self.n_neighbors_ignored == other.n_neighbors_ignored and self.query == other.query @@ -198,7 +289,7 @@ def __eq__(self, other): and self.topology == other.topology) return is_equal - def to_dict(self): + def to_dict(self) -> dict: """Convert object to a dict. Keys should be strings; values should be (JSON-) serializable. @@ -230,7 +321,7 @@ def to_dict(self): return dct @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> object: """Create object from dict. 
Parameters @@ -243,7 +334,7 @@ def from_dict(cls, dct): to_dict """ deserialize_set = set - deserialize_atom_to_residue_dct = lambda d: {int(k): d[k] for k in d} + #deserialize_atom_to_residue_dct = lambda d: {int(k): d[k] for k in d} deserialization_helpers = { 'topology': cls._deserialize_topology, 'atom_contacts': cls._deserialize_contact_counter, @@ -252,7 +343,7 @@ def from_dict(cls, dct): 'haystack': deserialize_set, 'query_res_idx': deserialize_set, 'haystack_res_idx': deserialize_set, - 'all_atoms': deserialize_set, + 'all_atoms': tuple, 'all_residues': deserialize_set, 'atom_idx_to_residue_idx': deserialize_atom_to_residue_dct } @@ -270,7 +361,7 @@ def from_dict(cls, dct): return instance @staticmethod - def _deserialize_topology(topology_json): + def _deserialize_topology(topology_json: str) -> md.Topology: """Create MDTraj topology from JSON-serialized version""" table, bonds = json.loads(topology_json) topology_df = pd.read_json(table) @@ -279,7 +370,7 @@ def _deserialize_topology(topology_json): return topology @staticmethod - def _serialize_topology(topology): + def _serialize_topology(topology: md.Topology) -> str: """Serialize MDTraj topology (to JSON)""" table, bonds = topology.to_dataframe() json_tuples = (table.to_json(), bonds.tolist()) @@ -288,7 +379,7 @@ def _serialize_topology(topology): # TODO: adding a separate object for these frozenset counters will be # useful for many things, and this serialization should be moved there @staticmethod - def _serialize_contact_counter(counter): + def _serialize_contact_counter(counter: collections.Container) -> str: """JSON string from contact counter""" # have to explicitly convert to int because json doesn't know how to # serialize np.int64 objects, which we get in Python 3 @@ -297,7 +388,7 @@ def _serialize_contact_counter(counter): return json.dumps(serializable) @staticmethod - def _deserialize_contact_counter(json_string): + def _deserialize_contact_counter(json_string: str) -> collections.Counter: """Contact counted from JSON string""" dct = json.loads(json_string) counter = collections.Counter({ @@ -305,7 +396,7 @@ def _deserialize_contact_counter(json_string): }) return counter - def to_json(self): + def to_json(self) -> str: """JSON-serialized version of this object. See also @@ -316,7 +407,7 @@ def to_json(self): return json.dumps(dct) @classmethod - def from_json(cls, json_string): + def from_json(cls, json_string: str) -> object: """Create object from JSON string Parameters @@ -352,7 +443,7 @@ def _check_compatibility(self, other, err=AssertionError): else: return failed_attr - def save_to_file(self, filename, mode="w"): + def save_to_file(self, filename: str, mode: str = "w"): """Save this object to the given file. 
Parameters @@ -371,7 +462,7 @@ def save_to_file(self, filename, mode="w"): pickle.dump(self, f) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str) -> object: """Load this object from a given file Parameters @@ -392,36 +483,36 @@ def from_file(cls, filename): reloaded = pickle.load(f) return reloaded - def __sub__(self, other): + def __sub__(self, other) -> "ContactDifference": return ContactDifference(positive=self, negative=other) @property - def cutoff(self): + def cutoff(self) -> double: """float : cutoff distance for contacts, in nanometers""" return self._cutoff @property - def n_neighbors_ignored(self): + def n_neighbors_ignored(self) -> Py_ssize_t: """int : number of neighbor residues (in same chain) to ignore""" return self._n_neighbors_ignored @property - def query(self): + def query(self) -> list: """list of int : indices of atoms to include as query""" return list(self._query) @property - def haystack(self): + def haystack(self) -> list: """list of int : indices of atoms to include as haystack""" return list(self._haystack) @property - def all_atoms(self): + def all_atoms(self) -> list: """list of int: all atom indices used in the contact map""" return list(self._all_atoms) @property - def topology(self): + def topology(self) -> md.Topology: """ :class:`mdtraj.Topology` : topology object for this system @@ -433,13 +524,13 @@ def topology(self): return self._topology @property - def use_atom_slice(self): + def use_atom_slice(self) -> bint: """bool : Indicates if `mdtraj.atom_slice()` is used before calculating the contact map""" return self._use_atom_slice @property - def _residue_ignore_atom_idxs(self): + def _residue_ignore_atom_idxs(self) -> dict: """dict : maps query residue index to atom indices to ignore""" all_atoms_set = set(self._all_atoms) result = {} @@ -460,36 +551,37 @@ def _residue_ignore_atom_idxs(self): return result @property - def haystack_residues(self): + def haystack_residues(self) -> list: """list : residues for atoms in the haystack""" return _residue_for_atom(self.topology, self.haystack) @property - def query_residues(self): + def query_residues(self) -> list: """list : residues for atoms in the query""" return _residue_for_atom(self.topology, self.query) @property - def query_range(self): + def query_range(self) -> tuple: """return an tuple with the (min, max+1) of query""" return _range_from_iterable(self.query) @property - def haystack_range(self): + def haystack_range(self) -> tuple: """return an tuple with the (min, max+1) of haystack""" return _range_from_iterable(self.haystack) @property - def haystack_residue_range(self): + def haystack_residue_range(self) -> tuple: """(int, int): min and (max + 1) of haystack residue indices""" return _range_from_iterable(self._haystack_res_idx) @property - def query_residue_range(self): + def query_residue_range(self) -> tuple: """(int, int): min and (max + 1) of query residue indices""" return _range_from_iterable(self._query_res_idx) - def most_common_atoms_for_residue(self, residue): + def most_common_atoms_for_residue(self, + residue: Residue) -> list: """ Most common atom contact pairs for contacts with the given residue @@ -519,7 +611,9 @@ def most_common_atoms_for_residue(self, residue): return results - def most_common_atoms_for_contact(self, contact_pair): + def most_common_atoms_for_contact(self, + contact_pair: "list[Residue, Residue]" + ) -> list: """ Most common atom contacts for a given residue contact pair @@ -552,8 +646,12 @@ def most_common_atoms_for_contact(self, 
contact_pair): if frozenset(contact[0]) in all_atom_pairs] return result - def _contact_map(self, trajectory, frame_number, residue_query_atom_idxs, - residue_ignore_atom_idxs): + def _contact_map(self, + trajectory: md.Trajectory, + frame_number: Py_ssize_t, + residue_query_atom_idxs: collections.defaultdict, + residue_ignore_atom_idxs: dict + ) -> "tuple[collections.Counter, collections.Counter]": """ Returns atom and residue contact maps for the given frame. @@ -571,46 +669,50 @@ def _contact_map(self, trajectory, frame_number, residue_query_atom_idxs, """ used_trajectory = self.indexer.slice_trajectory(trajectory) - neighborlist = md.compute_neighborlist(used_trajectory, self.cutoff, - frame_number) + neighborlist: list = md.compute_neighborlist(used_trajectory, + self.cutoff, + frame_number) - contact_pairs = set([]) - residue_pairs = set([]) - haystack = self.indexer.haystack - atom_idx_to_residue_idx = self.indexer.atom_idx_to_residue_idx + contact_pairs: set = set() + residue_pairs: set = set() + haystack: set = self.indexer.haystack + atom_idx_to_residue_idx: dict = self.indexer.atom_idx_to_residue_idx + residue_idx: Py_ssize_t for residue_idx in residue_query_atom_idxs: - ignore_atom_idxs = set(residue_ignore_atom_idxs[residue_idx]) - query_idxs = residue_query_atom_idxs[residue_idx] + ignore_atom_idxs: set = set(residue_ignore_atom_idxs[residue_idx]) + query_idxs: list = residue_query_atom_idxs[residue_idx] + atom_idx: Py_ssize_t for atom_idx in query_idxs: # sets should make this fast, esp since neighbor_idxs # should be small and s-t is avg cost len(s) - neighbor_idxs = set(neighborlist[atom_idx]) - contact_neighbors = neighbor_idxs - ignore_atom_idxs - contact_neighbors = contact_neighbors & haystack + neighbor_idxs: set = set(neighborlist[atom_idx]) + contact_neighbors: set = neighbor_idxs - ignore_atom_idxs + contact_neighbors: set = contact_neighbors & haystack # frozenset is unique key independent of order # local_pairs = set(frozenset((atom_idx, neighb)) # for neighb in contact_neighbors) - local_pairs = set(map( + local_pairs: set = set(map( frozenset, itertools.product([atom_idx], contact_neighbors) )) contact_pairs |= local_pairs # contact_pairs |= set(frozenset((atom_idx, neighb)) # for neighb in contact_neighbors) - local_residue_partners = set(atom_idx_to_residue_idx[a] - for a in contact_neighbors) + local_residue_partners: set = set(atom_idx_to_residue_idx[a] + for a in contact_neighbors) local_res_pairs = set(map( frozenset, itertools.product([residue_idx], local_residue_partners) )) residue_pairs |= local_res_pairs - atom_contacts = collections.Counter(contact_pairs) + atom_contacts: collections.Counter = collections.Counter(contact_pairs) # residue_pairs = set( # frozenset(self._atom_idx_to_residue_idx[aa] for aa in pair) # for pair in contact_pairs # ) - residue_contacts = collections.Counter(residue_pairs) + residue_contacts: collections.Counter = collections.Counter( + residue_pairs) return (atom_contacts, residue_contacts) @property @@ -663,6 +765,7 @@ class ContactFrequency(ContactObject): "more, see https://github.com/dwhswenson/contact_map/issues/82" ) + def __init__(self, trajectory, query=None, haystack=None, cutoff=0.45, n_neighbors_ignored=2): warnings.warn(self._pending_dep_msg, PendingDeprecationWarning) diff --git a/contact_map/contact_trajectory.py b/contact_map/contact_trajectory.py index 046206d..0fce69b 100644 --- a/contact_map/contact_trajectory.py +++ b/contact_map/contact_trajectory.py @@ -104,6 +104,9 @@ def __hash__(self): def 
__eq__(self, other):
         return hash(self) == hash(other)
 
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
     @classmethod
     def from_contacts(cls, atom_contacts, residue_contacts, topology,
                       query=None, haystack=None, cutoff=0.45,
diff --git a/contact_map/fix_parameters.py b/contact_map/fix_parameters.py
index 3c1cac2..bbfe9ed 100644
--- a/contact_map/fix_parameters.py
+++ b/contact_map/fix_parameters.py
@@ -32,8 +32,8 @@ def _fix_parameters(self, map0, map1, failed):
                 map1_set = set(getattr(map1, fail))
                 fixed = getattr(map0_set, self._set_mixing)(map1_set)
             elif fail in {'cutoff', 'n_neighbors_ignored'}:
-                # We just set them to None
-                fixed = None
+                # We just set them to -1
+                fixed = -1
             elif fail == 'topology':
                 # This requires quite a bit of logic
                 fixed = self._check_topology(map0, map1)
diff --git a/contact_map/tests/test_contact_map.py b/contact_map/tests/test_contact_map.py
index 60f4f3d..3b34117 100644
--- a/contact_map/tests/test_contact_map.py
+++ b/contact_map/tests/test_contact_map.py
@@ -752,7 +752,7 @@ def test_non_important_attributes(self, attr):
         if attr[0] in {'query','haystack'}:
             assert getattr(diff, attr[0]) == attr[1]
         else:
-            assert getattr(diff, attr[0]) is None
+            assert getattr(diff, attr[0]) == -1
         # Make sure we can still do the maps
         assert diff.atom_contacts is not None
         assert diff.residue_contacts is not None
diff --git a/contact_map/tests/test_cython.py b/contact_map/tests/test_cython.py
new file mode 100644
index 0000000..243f601
--- /dev/null
+++ b/contact_map/tests/test_cython.py
@@ -0,0 +1,9 @@
+import pytest
+
+cython = pytest.importorskip("cython")
+from contact_map.contact_map import COMPILED
+
+
+@pytest.mark.skipif(not COMPILED, reason="no compiled code")
+def test_cythonization():
+    assert COMPILED
diff --git a/contact_map/tests/test_dask_runner.py b/contact_map/tests/test_dask_runner.py
index d704db5..4a8bedd 100644
--- a/contact_map/tests/test_dask_runner.py
+++ b/contact_map/tests/test_dask_runner.py
@@ -76,8 +76,8 @@ def test_dask_atom_slice(self):
                                  (DaskContactTrajectory, ContactTrajectory)])
     def test_answer_equal(self, dask_cls, norm_cls):
         trj = mdtraj.load(self.filename)
-        dask_result = dask_cls(self.client, self.filename)
         norm_result = norm_cls(trj)
+        dask_result = dask_cls(self.client, self.filename)
         if isinstance(dask_result, Iterable):
             for i, j in zip(dask_result, norm_result):
                 assert i.atom_contacts._counter == j.atom_contacts._counter
diff --git a/setup.py b/setup.py
index a246dfb..bd7ac33 100644
--- a/setup.py
+++ b/setup.py
@@ -6,6 +6,7 @@ import fnmatch  # Py 2
 
 from setuptools import setup
+from setuptools.extension import Extension
 
 
 def _glob_glob_recursive(directory, pattern):
     # python 2 glob.glob doesn't have a recursive keyword
@@ -117,9 +118,44 @@ def write_installed_version_py(filename="_installed_version.py",
     with open (os.path.join(src_dir, filename), 'w') as f:
         f.write(content.format(vers=version, git=git_rev, depth=depth))
 
+def cythonize():
+    # Copied from dask/distributed
+    cython_arg = None
+    for i in range(len(sys.argv)):
+        if sys.argv[i].startswith("--with-cython"):
+            cython_arg = sys.argv[i]
+            del sys.argv[i]
+            break
+    if not cython_arg:
+        return []
+    try:
+        import cython
+    except ImportError:
+        return []
+
+    ext_modules = []
+    profile = False
+    cyext_modules = [
+        Extension(
+            "contact_map.contact_map",
+            sources=["contact_map/contact_map.py"],
+        ),
+    ]
+    for e in cyext_modules:
+        e.cython_directives = {
+            "annotation_typing": True,
+            "binding": False,
+            "embedsignature": True,
+            "language_level": 3,
+            "profile": profile,
+        }
+    ext_modules.extend(cyext_modules)
+    return ext_modules
+
 
 if __name__ == "__main__":
     # TODO: only write version.py under special circumstances
     write_installed_version_py()
     # write_version_py(os.path.join('autorelease', 'version.py'))
-    setup()
+    ext_modules = cythonize()
+    setup(ext_modules=ext_modules)
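
Usage note (not part of the patch above): a minimal sketch of how the new COMPILED flag from contact_map/contact_map.py can be checked from user code once the package is built. The --with-cython option is consumed by the cythonize() helper added to setup.py; the exact build invocation shown in the comment and the trajectory file name are assumptions, and whether the module is actually compiled also depends on Cython being available at build time.

    # Assumed build step (the flag is stripped from sys.argv by cythonize()):
    #   python setup.py build_ext --inplace --with-cython
    import mdtraj as md

    from contact_map import ContactFrequency
    from contact_map.contact_map import COMPILED  # flag introduced in this patch

    # False under the pure-Python / ctypes fallback, True for the built extension.
    print("cython-compiled backend:", COMPILED)

    # The public API is identical either way; only the backing module changes.
    traj = md.load("protein.pdb")  # placeholder trajectory file
    freq = ContactFrequency(traj, cutoff=0.45, n_neighbors_ignored=2)
    print(freq.residue_contacts.most_common()[:5])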