Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] FrozenSetDict; FrozenSetCounter #42

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions contact_map/set_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import collections

try:
from collections import abc
except ImportError:
abc = collections # Py 2.7

from mdtraj.core.topology import Atom, Residue

"""
Classes that use frozensets at keys, but allow access with by any iterable.

Contact maps frequently require mappings of pairs of objects (representing
the contact pair) to some value. Since the order of the objects in the pair
is unimportant (the pair (A,B) is the same as (B,A)), we use a ``set``.
However, since these are keys, the pair must be immutable: a ``frozenset``.
It gets really annoying to have to type ``frozenset`` around each object, so
the classes in this module allow other iterables (tuples, lists) to be used
as keys in getting/setting items -- internally, they are converted to
``frozenset``.
"""


class FrozenSetDict(abc.MutableMapping):
"""Dictionary-like object that uses frozensets internally.

Note
----
This can take input like ``dict({key: value})`` or
``dict([(key, value)])``, but not like ``dict(key=value)``, for the
simple reason that in the last case, you can't use an iterable as key.
"""

hash_map = frozenset
def __init__(self, input_data=None):
self.dct = {}
if input_data is not None:
if isinstance(input_data, collections.Mapping):
# convert the mapping to key-value tuples
input_data = list(input_data.items())

for key, value in input_data:
self.dct[self._regularize_key(key)] = value

def __len__(self):
return len(self.dct)

def __iter__(self):
return iter(self.dct)

def _regularize_key(self, key):
def all_isinstance(iterable, cls):
return all(isinstance(k, cls) for k in iterable)

if all_isinstance(key, Atom) or all_isinstance(key, Residue):
key = self.hash_map(k.index for k in key)
else:
key = self.hash_map(key)

return key

def __getitem__(self, key):
return self.dct[self._regularize_key(key)]

def __setitem__(self, key, value):
self.dct[self._regularize_key(key)] = value

def __delitem__(self, key):
del self.dct[self._regularize_key(key)]


def _make_frozen_set_counter(other):
if not isinstance(other, FrozenSetCounter):
other = FrozenSetCounter(other)
return other


class FrozenSetCounter(FrozenSetDict):
"""Counter-like object that uses frozensets internally.
"""
def __init__(self, input_data=None):
if input_data is None:
input_data = []

if not isinstance(input_data, collections.Mapping):
self.counter = collections.Counter([
self._regularize_key(inp)
for inp in input_data
])
else:
self.counter = collections.Counter({
self._regularize_key(key): value
for key, value in input_data.items()
})

def most_common(self, n=None):
return self.counter.most_common(n)

def elements(self):
return self.counter.elements()

def subtract(self, other):
other = _make_frozen_set_counter(other)
self.counter.subtract(other.counter)

def update(self, other):
other = _make_frozen_set_counter(other)
self.counter.update(other.counter)

def __add__(self, other):
other = _make_frozen_set_counter(other)
counter = self.counter + other.counter
return FrozenSetCounter(counter)

def __sub__(self, other):
other = _make_frozen_set_counter(other)
counter = self.counter - other.counter
return FrozenSetCounter(counter)

def __and__(self, other):
pass

def __or__(self, other):
pass
87 changes: 87 additions & 0 deletions contact_map/tests/test_set_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import itertools

# pylint: disable=wildcard-import, missing-docstring, protected-access
# pylint: disable=attribute-defined-outside-init, invalid-name, no-self-use
# pylint: disable=wrong-import-order, unused-wildcard-import

# includes pytest
from .utils import *

from contact_map.set_dict import *
from .test_contact_map import traj

KEY_ITER_IDX = list(itertools.product(['list', 'tuple', 'fset'],
['idx', 'obj']))


def make_key(obj_type, iter_type, idx_to_type, idx_pair):
top = traj.topology
idx_to_type_f = {
'idx': lambda idx: idx,
'obj': {'atom': top.atom,
'res': top.residue}[obj_type]
}[idx_to_type]
iter_type_f = {'list': list,
'tuple': tuple,
'fset': frozenset}[iter_type]
key = iter_type_f(idx_to_type_f(idx) for idx in idx_pair)
return key


@pytest.mark.parametrize("obj_type", ['atom', 'res'])
class TestFrozenSetDict(object):
def setup(self):
topology = traj.topology
self.expected_dct = {
frozenset([0, 1]): 10,
frozenset([1, 2]): 5
}
self.atom_fsdict, self.residue_fsdct = [
FrozenSetDict({(fcn(0), fcn(1)): 10,
(fcn(1), fcn(2)): 5})
for fcn in [topology.atom, topology.residue]
]

def test_init(self, obj_type):
obj = {'atom': self.atom_fsdict,
'res': self.residue_fsdct}[obj_type]
assert obj.dct == self.expected_dct

def test_len(self, obj_type):
obj = {'atom': self.atom_fsdict,
'res': self.residue_fsdct}[obj_type]
assert len(obj) == 2

def test_iter(self, obj_type):
obj = {'atom': self.atom_fsdict,
'res': self.residue_fsdct}[obj_type]
for k in obj:
assert k in [frozenset([0, 1]), frozenset([1, 2])]

@pytest.mark.parametrize("iter_type, idx_to_type", KEY_ITER_IDX)
def test_get(self, obj_type, iter_type, idx_to_type):
obj = {'atom': self.atom_fsdict,
'res': self.residue_fsdct}[obj_type]
key = make_key(obj_type, iter_type, idx_to_type, [0, 1])
assert obj[key] == 10

@pytest.mark.parametrize("iter_type, idx_to_type", KEY_ITER_IDX)
def test_set(self, obj_type, iter_type, idx_to_type):
obj = {'atom': self.atom_fsdict,
'res': self.residue_fsdct}[obj_type]
key = make_key(obj_type, iter_type, idx_to_type, [1, 3])
obj[key] = 20
assert obj.dct[frozenset([1, 3])] == 20

@pytest.mark.parametrize("iter_type, idx_to_type", KEY_ITER_IDX)
def test_del(self, obj_type, iter_type, idx_to_type):
obj = {'atom': self.atom_fsdict,
'res': self.residue_fsdct}[obj_type]
key = make_key(obj_type, iter_type, idx_to_type, [0, 1])
del obj[key]
assert len(obj) == 1
assert list(obj.dct.keys()) == [frozenset([1, 2])]


class TestFrozenSetCounter(object):
pass