From f7d8b2b5a16bda93af3daa005dfc472da8d537d0 Mon Sep 17 00:00:00 2001 From: Josiah Seaman Date: Tue, 9 Jul 2019 15:45:20 +0100 Subject: [PATCH] #4 Fixed Graph equality checking --- src/graph.py | 56 +++++++++++++++++++++++++++++++++++++++++----------- src/test.py | 22 +++++++++++++++++---- src/utils.py | 10 ++++++++++ 3 files changed, 73 insertions(+), 15 deletions(-) create mode 100644 src/utils.py diff --git a/src/graph.py b/src/graph.py index 9869402..fab9f61 100644 --- a/src/graph.py +++ b/src/graph.py @@ -1,8 +1,11 @@ -from typing import Callable, Iterator, Union, Optional, List, Iterable, NamedTuple +from typing import List, Iterable from itertools import zip_longest import pickle import sys +from src.utils import keydefaultdict + + class NoAnchorError(ValueError): pass class PathOverlapError(ValueError): @@ -27,7 +30,7 @@ def __len__(self): def __repr__(self): """Paths representation is sorted because set ordering is not guaranteed.""" return repr(self.seq) + \ - ', {' + ', '.join(str(i) for i in list(self.paths)) + '}' + ', {' + ', '.join(str(i) for i in sorted(list(self.paths))) + '}' def __eq__(self, other): if not isinstance(other, Node): @@ -40,8 +43,7 @@ def __hash__(self): def append_path(self, path): assert isinstance(path, Path), path - pt = PathIndex(path, len(path.nodes)) # not parallelizable - self.paths.add(pt) + self.paths.add(PathIndex(path, len(path.nodes))) # not parallelizable path.nodes.append(NodeIndex(self, len(self.paths))) def to_gfa(self, segment_id: int): @@ -129,17 +131,39 @@ def __getitem__(self, path_index): def __repr__(self): """Warning: the representation strings are very sensitive to whitespace""" - return self.nodes.__repr__() + return "'" + self.accession + "'" + + def __eq__(self, other): + return self.accession == other.accession + + def __hash__(self): + return hash(self.accession) def to_gfa(self): return '\t'.join(['P', self.accession, "+,".join([x.node.name + x.strand for x in self.nodes]) + "+", ",".join(['*' for x in self.nodes])]) -class PathIndex(NamedTuple): +class PathIndex: """Link from a Node to the place in the path where the Node is referenced. A Node can appear in a Path multiple times. Index indicates which instance it is.""" - path: Path - index: int + def __init__(self, path: Path, index: int): + self.path = path + self.index = index + + def __repr__(self): + return repr(self.path.accession) + + def __eq__(self, other): + if self.path.accession == other.path.accession and self.index == other.index: + return True + else: + return False + + def __lt__(self, other): + return self.path.accession < other.path.accession + + def __hash__(self): + return hash(self.path.accession) * (self.index if self.index else 1) class NodeIndex: @@ -149,16 +173,25 @@ def __init__(self, node: Node, index: int, strand: str = '+'): self.index = index self.strand = strand # TODO: make this required + def __repr__(self): + return self.node.seq + class Graph: - def __init__(self, paths: List[Path] = None): + def __init__(self, paths: List = None): """Factory for generating graphs from a representation""" self.slices = [] - self.paths = paths if paths else [] # can't be in the signature + if all(isinstance(x, str) for x in paths): + self.paths = [Path(x) for x in paths] + elif all(isinstance(x, Path) for x in paths): + self.paths = paths + else: + self.paths = [] #TODO: calculate slices? @staticmethod def build(cmd): + path_dict = keydefaultdict(lambda key: Path(key)) # construct blank path if new slices = [] if isinstance(cmd, str): cmd = eval(cmd) @@ -172,7 +205,8 @@ def build(cmd): else: try: for i in range(0, len(sl), 2): - current_slice.append(Node(sl[i], sl[i + 1])) + paths = [path_dict[key] for key in sl[i + 1]] + current_slice.append(Node(sl[i], paths)) except IndexError: raise IndexError("Expecting two terms: ", sl[0]) # sl[i:i+2]) diff --git a/src/test.py b/src/test.py index d89cc91..29da757 100644 --- a/src/test.py +++ b/src/test.py @@ -11,8 +11,12 @@ def G(rep): return Graph.build(rep)[0] +a, b, c, d, e = 'a', 'b', 'c', 'd', 'e' # Paths must be created first class GraphTest(unittest.TestCase): - a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e') # Paths must be created first + """Constructing a node with an existing Path object will modify that Path object (doubly linked) + which means care must be taken when constructing Graphs. From factory_input we have an example of + pure python to Graph.build in one step. In example_graph, we must first declare the Paths, + then reference them in order in Node Constructors. Order matters for Graph identity!""" # Path e is sometimes introduced as a tie breaker for Slice.secondary() factory_input = [['ACGT', {a, b, c, d}], ['C', {a, b, d}, 'T', {c}], # SNP @@ -28,7 +32,8 @@ class GraphTest(unittest.TestCase): ['TATA', {a, b, c, d}]] # [11] anchor def example_graph(self): - a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e') # Paths must be created first + # IMPORTANT: Never reuse Paths: Paths must be created fresh for each graph + a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e') factory_input = [Slice([Node('ACGT', {a,b,c,d})]), Slice([Node('C',{a,b,d}),Node('T', {c})]), Slice([Node('GGA',{a,b,c,d})]), @@ -46,10 +51,19 @@ def example_graph(self): base_graph = Graph.load_from_slices(factory_input) return base_graph + def test_equalities(self): + self.assertEqual(Node('A', {}),Node('A', {})) + self.assertEqual(Node('A', {Path('x')}),Node('A', {Path('x')})) + self.assertEqual(Node('A', {Path('x'),Path('y')}),Node('A', {Path('x'),Path('y')})) + self.assertEqual(Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})]), + Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})])) + self.assertEqual(Graph.build([['ACGT', {a, b, c, d}]]), Graph.build([['ACGT', {a, b, c, d}]])) + def test_graph_factory(self): base_graph = self.example_graph() - assert base_graph == Graph.build(self.factory_input), \ - ('\n' + repr(base_graph) + '\n' + str(self.factory_input)) + g1, g2 = Graph.build(self.factory_input), Graph.build(self.factory_input) + assert g1 == g2, \ + ('\n' + repr(g1) + '\n' + repr(g2)) g_double = Graph.build(eval(str(base_graph))) # WARN: Never compare two string literals: could be order sensitive, one object must be Graph #str(g_double) == str(base_graph) diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..df3e172 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,10 @@ +from collections import defaultdict + + +class keydefaultdict(defaultdict): + def __missing__(self, key): + if self.default_factory is None: + raise KeyError( key ) + else: + ret = self[key] = self.default_factory(key) + return ret \ No newline at end of file