diff --git a/contact_map/set_dict.py b/contact_map/set_dict.py new file mode 100644 index 0000000..0030b73 --- /dev/null +++ b/contact_map/set_dict.py @@ -0,0 +1,124 @@ +import collections + +try: + from collections import abc +except ImportError: + abc = collections # Py 2.7 + +from mdtraj.core.topology import Atom, Residue + +""" +Classes that use frozensets at keys, but allow access with by any iterable. + +Contact maps frequently require mappings of pairs of objects (representing +the contact pair) to some value. Since the order of the objects in the pair +is unimportant (the pair (A,B) is the same as (B,A)), we use a ``set``. +However, since these are keys, the pair must be immutable: a ``frozenset``. +It gets really annoying to have to type ``frozenset`` around each object, so +the classes in this module allow other iterables (tuples, lists) to be used +as keys in getting/setting items -- internally, they are converted to +``frozenset``. +""" + + +class FrozenSetDict(abc.MutableMapping): + """Dictionary-like object that uses frozensets internally. + + Note + ---- + This can take input like ``dict({key: value})`` or + ``dict([(key, value)])``, but not like ``dict(key=value)``, for the + simple reason that in the last case, you can't use an iterable as key. + """ + + hash_map = frozenset + def __init__(self, input_data=None): + self.dct = {} + if input_data is not None: + if isinstance(input_data, collections.Mapping): + # convert the mapping to key-value tuples + input_data = list(input_data.items()) + + for key, value in input_data: + self.dct[self._regularize_key(key)] = value + + def __len__(self): + return len(self.dct) + + def __iter__(self): + return iter(self.dct) + + def _regularize_key(self, key): + def all_isinstance(iterable, cls): + return all(isinstance(k, cls) for k in iterable) + + if all_isinstance(key, Atom) or all_isinstance(key, Residue): + key = self.hash_map(k.index for k in key) + else: + key = self.hash_map(key) + + return key + + def __getitem__(self, key): + return self.dct[self._regularize_key(key)] + + def __setitem__(self, key, value): + self.dct[self._regularize_key(key)] = value + + def __delitem__(self, key): + del self.dct[self._regularize_key(key)] + + +def _make_frozen_set_counter(other): + if not isinstance(other, FrozenSetCounter): + other = FrozenSetCounter(other) + return other + + +class FrozenSetCounter(FrozenSetDict): + """Counter-like object that uses frozensets internally. + """ + def __init__(self, input_data=None): + if input_data is None: + input_data = [] + + if not isinstance(input_data, collections.Mapping): + self.counter = collections.Counter([ + self._regularize_key(inp) + for inp in input_data + ]) + else: + self.counter = collections.Counter({ + self._regularize_key(key): value + for key, value in input_data.items() + }) + + def most_common(self, n=None): + return self.counter.most_common(n) + + def elements(self): + return self.counter.elements() + + def subtract(self, other): + other = _make_frozen_set_counter(other) + self.counter.subtract(other.counter) + + def update(self, other): + other = _make_frozen_set_counter(other) + self.counter.update(other.counter) + + def __add__(self, other): + other = _make_frozen_set_counter(other) + counter = self.counter + other.counter + return FrozenSetCounter(counter) + + def __sub__(self, other): + other = _make_frozen_set_counter(other) + counter = self.counter - other.counter + return FrozenSetCounter(counter) + + def __and__(self, other): + pass + + def __or__(self, other): + pass diff --git a/contact_map/tests/test_set_dict.py b/contact_map/tests/test_set_dict.py new file mode 100644 index 0000000..8c00be8 --- /dev/null +++ b/contact_map/tests/test_set_dict.py @@ -0,0 +1,87 @@ +import itertools + +# pylint: disable=wildcard-import, missing-docstring, protected-access +# pylint: disable=attribute-defined-outside-init, invalid-name, no-self-use +# pylint: disable=wrong-import-order, unused-wildcard-import + +# includes pytest +from .utils import * + +from contact_map.set_dict import * +from .test_contact_map import traj + +KEY_ITER_IDX = list(itertools.product(['list', 'tuple', 'fset'], + ['idx', 'obj'])) + + +def make_key(obj_type, iter_type, idx_to_type, idx_pair): + top = traj.topology + idx_to_type_f = { + 'idx': lambda idx: idx, + 'obj': {'atom': top.atom, + 'res': top.residue}[obj_type] + }[idx_to_type] + iter_type_f = {'list': list, + 'tuple': tuple, + 'fset': frozenset}[iter_type] + key = iter_type_f(idx_to_type_f(idx) for idx in idx_pair) + return key + + +@pytest.mark.parametrize("obj_type", ['atom', 'res']) +class TestFrozenSetDict(object): + def setup(self): + topology = traj.topology + self.expected_dct = { + frozenset([0, 1]): 10, + frozenset([1, 2]): 5 + } + self.atom_fsdict, self.residue_fsdct = [ + FrozenSetDict({(fcn(0), fcn(1)): 10, + (fcn(1), fcn(2)): 5}) + for fcn in [topology.atom, topology.residue] + ] + + def test_init(self, obj_type): + obj = {'atom': self.atom_fsdict, + 'res': self.residue_fsdct}[obj_type] + assert obj.dct == self.expected_dct + + def test_len(self, obj_type): + obj = {'atom': self.atom_fsdict, + 'res': self.residue_fsdct}[obj_type] + assert len(obj) == 2 + + def test_iter(self, obj_type): + obj = {'atom': self.atom_fsdict, + 'res': self.residue_fsdct}[obj_type] + for k in obj: + assert k in [frozenset([0, 1]), frozenset([1, 2])] + + @pytest.mark.parametrize("iter_type, idx_to_type", KEY_ITER_IDX) + def test_get(self, obj_type, iter_type, idx_to_type): + obj = {'atom': self.atom_fsdict, + 'res': self.residue_fsdct}[obj_type] + key = make_key(obj_type, iter_type, idx_to_type, [0, 1]) + assert obj[key] == 10 + + @pytest.mark.parametrize("iter_type, idx_to_type", KEY_ITER_IDX) + def test_set(self, obj_type, iter_type, idx_to_type): + obj = {'atom': self.atom_fsdict, + 'res': self.residue_fsdct}[obj_type] + key = make_key(obj_type, iter_type, idx_to_type, [1, 3]) + obj[key] = 20 + assert obj.dct[frozenset([1, 3])] == 20 + + @pytest.mark.parametrize("iter_type, idx_to_type", KEY_ITER_IDX) + def test_del(self, obj_type, iter_type, idx_to_type): + obj = {'atom': self.atom_fsdict, + 'res': self.residue_fsdct}[obj_type] + key = make_key(obj_type, iter_type, idx_to_type, [0, 1]) + del obj[key] + assert len(obj) == 1 + assert list(obj.dct.keys()) == [frozenset([1, 2])] + + +class TestFrozenSetCounter(object): + pass