Skip to content
This repository has been archived by the owner on Mar 20, 2020. It is now read-only.

Commit

Permalink
#4 Fixed Graph equality checking
Browse files Browse the repository at this point in the history
  • Loading branch information
josiahseaman committed Jul 9, 2019
1 parent 0a7d73b commit f7d8b2b
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 15 deletions.
56 changes: 45 additions & 11 deletions src/graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Callable, Iterator, Union, Optional, List, Iterable, NamedTuple
from typing import List, Iterable
from itertools import zip_longest
import pickle
import sys

from src.utils import keydefaultdict


class NoAnchorError(ValueError):
    """ValueError subclass; presumably signals a missing anchor (raise sites not shown here -- confirm)."""
class PathOverlapError(ValueError):
Expand All @@ -27,7 +30,7 @@ def __len__(self):
def __repr__(self):
"""Paths representation is sorted because set ordering is not guaranteed."""
return repr(self.seq) + \
', {' + ', '.join(str(i) for i in list(self.paths)) + '}'
', {' + ', '.join(str(i) for i in sorted(list(self.paths))) + '}'

def __eq__(self, other):
if not isinstance(other, Node):
Expand All @@ -40,8 +43,7 @@ def __hash__(self):

def append_path(self, path):
    """Link this node and ``path`` to each other.

    Adds a PathIndex for ``path`` to ``self.paths`` and appends a NodeIndex
    for this node to ``path.nodes``.
    """
    assert isinstance(path, Path), path
    # The index is read from len(path.nodes) before path.nodes grows on the
    # next line, so this read-then-append pair is not parallelizable.
    self.paths.add(PathIndex(path, len(path.nodes)))
    path.nodes.append(NodeIndex(self, len(self.paths)))

def to_gfa(self, segment_id: int):
Expand Down Expand Up @@ -129,17 +131,39 @@ def __getitem__(self, path_index):

def __repr__(self):
"""Warning: the representation strings are very sensitive to whitespace"""
return self.nodes.__repr__()
return "'" + self.accession + "'"

def __eq__(self, other):
return self.accession == other.accession

def __hash__(self):
return hash(self.accession)

def to_gfa(self):
    """Return this path as a tab-separated GFA 'P' (path) line.

    Fields, in order: the record type 'P', the accession, the visited
    segments as comma-separated ``<node name><strand>`` entries, and one
    '*' overlap placeholder per node.
    """
    # Each step carries its own strand, so entries are joined with ',' only;
    # adding a literal '+' as well would double the orientation ('1++,2++').
    segments = ",".join(step.node.name + step.strand for step in self.nodes)
    overlaps = ",".join('*' for _ in self.nodes)
    return '\t'.join(['P', self.accession, segments, overlaps])


class PathIndex(NamedTuple):
class PathIndex:
    """Link from a Node to the place in the path where the Node is referenced.

    A Node can appear in a Path multiple times. Index indicates which
    instance it is.
    """
    path: "Path"
    index: int

    def __init__(self, path: "Path", index: int):
        self.path = path
        self.index = index

    def _key(self):
        """Identity of this link: (path accession, index)."""
        return self.path.accession, self.index

    def __repr__(self):
        """Show only the accession of the referenced path."""
        return repr(self.path.accession)

    def __eq__(self, other):
        """Equal when both the path accession and the index match."""
        if not isinstance(other, PathIndex):
            # Let Python fall back to its default instead of raising
            # AttributeError on unrelated objects.
            return NotImplemented
        return self._key() == other._key()

    def __lt__(self, other):
        """Order by path accession only (used when sorting for repr)."""
        if not isinstance(other, PathIndex):
            return NotImplemented
        return self.path.accession < other.path.accession

    def __hash__(self):
        """Hash the same (accession, index) pair that __eq__ compares."""
        # A tuple hash keeps index 0 and index 1 distinct, which a product
        # of hash(accession) and the index cannot do.
        return hash(self._key())


class NodeIndex:
Expand All @@ -149,16 +173,25 @@ def __init__(self, node: Node, index: int, strand: str = '+'):
self.index = index
self.strand = strand # TODO: make this required

def __repr__(self):
return self.node.seq


class Graph:
def __init__(self, paths: List[Path] = None):
def __init__(self, paths: List = None):
"""Factory for generating graphs from a representation"""
self.slices = []
self.paths = paths if paths else [] # can't be in the signature
if all(isinstance(x, str) for x in paths):
self.paths = [Path(x) for x in paths]
elif all(isinstance(x, Path) for x in paths):
self.paths = paths
else:
self.paths = []
#TODO: calculate slices?

@staticmethod
def build(cmd):
path_dict = keydefaultdict(lambda key: Path(key)) # construct blank path if new
slices = []
if isinstance(cmd, str):
cmd = eval(cmd)
Expand All @@ -172,7 +205,8 @@ def build(cmd):
else:
try:
for i in range(0, len(sl), 2):
current_slice.append(Node(sl[i], sl[i + 1]))
paths = [path_dict[key] for key in sl[i + 1]]
current_slice.append(Node(sl[i], paths))
except IndexError:
raise IndexError("Expecting two terms: ", sl[0]) # sl[i:i+2])

Expand Down
22 changes: 18 additions & 4 deletions src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ def G(rep):
return Graph.build(rep)[0]


a, b, c, d, e = 'a', 'b', 'c', 'd', 'e' # Paths must be created first
class GraphTest(unittest.TestCase):
a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e') # Paths must be created first
"""Constructing a node with an existing Path object will modify that Path object (doubly linked)
which means care must be taken when constructing Graphs. From factory_input we have an example of
pure python to Graph.build in one step. In example_graph, we must first declare the Paths,
then reference them in order in Node Constructors. Order matters for Graph identity!"""
# Path e is sometimes introduced as a tie breaker for Slice.secondary()
factory_input = [['ACGT', {a, b, c, d}],
['C', {a, b, d}, 'T', {c}], # SNP
Expand All @@ -28,7 +32,8 @@ class GraphTest(unittest.TestCase):
['TATA', {a, b, c, d}]] # [11] anchor

def example_graph(self):
a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e') # Paths must be created first
# IMPORTANT: Never reuse Paths: Paths must be created fresh for each graph
a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e')
factory_input = [Slice([Node('ACGT', {a,b,c,d})]),
Slice([Node('C',{a,b,d}),Node('T', {c})]),
Slice([Node('GGA',{a,b,c,d})]),
Expand All @@ -46,10 +51,19 @@ def example_graph(self):
base_graph = Graph.load_from_slices(factory_input)
return base_graph

def test_equalities(self):
    """Objects built twice from the same description must compare equal."""
    # Each factory is called once per side, so the two operands never share
    # Path objects; the cases run in the listed order.
    builders = (
        lambda: Node('A', {}),
        lambda: Node('A', {Path('x')}),
        lambda: Node('A', {Path('x'), Path('y')}),
        lambda: Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})]),
        lambda: Graph.build([['ACGT', {a, b, c, d}]]),
    )
    for build in builders:
        self.assertEqual(build(), build())

def test_graph_factory(self):
base_graph = self.example_graph()
assert base_graph == Graph.build(self.factory_input), \
('\n' + repr(base_graph) + '\n' + str(self.factory_input))
g1, g2 = Graph.build(self.factory_input), Graph.build(self.factory_input)
assert g1 == g2, \
('\n' + repr(g1) + '\n' + repr(g2))
g_double = Graph.build(eval(str(base_graph)))
# WARN: Never compare two string literals: could be order sensitive, one object must be Graph
#str(g_double) == str(base_graph)
Expand Down
10 changes: 10 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from collections import defaultdict


class keydefaultdict(defaultdict):
    """defaultdict variant whose factory is called with the missing key."""

    def __missing__(self, key):
        """Build, store and return the value for a key that is not present.

        Raises KeyError when no default_factory is set.
        """
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        value = factory(key)
        self[key] = value
        return value

0 comments on commit f7d8b2b

Please sign in to comment.