diff --git a/src/gfa.py b/src/gfa.py index b59d140..661e626 100644 --- a/src/gfa.py +++ b/src/gfa.py @@ -106,18 +106,13 @@ def save_as_gfa(self, file: str): @classmethod def from_graph(cls, graph: Graph): + """Constructs the lines of a GFA file listing paths, then sequence nodes in arbitrary order.""" gfa = gfapy.Gfa() - path_list = defaultdict(list) - segment_id = 0 - for slice in graph.slices: - for node in slice.nodes: - segment_id += 1 - gfa.add_line('\t'.join(['S', str(segment_id), node.seq])) - for path in node.paths: - path_list[path].append(segment_id) - for path_key in path_list: - path_values = [str(x) for x in path_list[path_key]] - gfa.add_line('\t'.join(['P', path_key, "+,".join(path_values)+"+", ",".join(['*' for _ in path_values])])) + for path in graph.paths.values(): + node_series = ",".join([traverse.node.id + traverse.strand for traverse in path.nodes]) + gfa.add_line('\t'.join(['P', path.accession, node_series, ",".join(['*' for _ in path.nodes])])) + for node in graph.nodes.values(): # in no particular order + gfa.add_line('\t'.join(['S', str(node.id), node.seq])) return cls(gfa) @property diff --git a/src/graph.py b/src/graph.py index b3e1021..02e3948 100644 --- a/src/graph.py +++ b/src/graph.py @@ -188,63 +188,24 @@ def __repr__(self): class Graph: - def __init__(self, paths: List = None): - """Factory for generating graphs from a representation""" - self.slices = [] # only get populated by compute_slices() + def __init__(self, paths: Iterable = None): # This can create orphan Nodes with no traversals self.nodes = keydefaultdict(lambda key: Node(key, [])) # node id = Node object if all(isinstance(x, str) for x in paths): self.paths = {x: Path(x) for x in paths} elif all(isinstance(x, Path) for x in paths): - self.paths = {path.name: path for path in paths} + self.paths = {path.accession: path for path in paths} else: self.paths = {} - #TODO: calculate slices? - - @staticmethod - def build(cmd): - """This factory uses existing slice declarations to build a graph with Paths populated in the order - that they are mentioned in the slices. Currently, this is + only and does not support non-linear - orderings. Use Path.append_node() to build non-linear graphs.""" - path_dict = keydefaultdict(lambda key: Path(key)) # construct blank path if new - slices = [] - if isinstance(cmd, str): - cmd = eval(cmd) - for sl in cmd: - current_slice = [] - if isinstance(sl, Slice): - slices.append(sl) - else: - if isinstance(sl[0], Node): # already Nodes, don't need to build - current_slice = sl - else: - try: - for i in range(0, len(sl), 2): - paths = [path_dict[key] for key in sl[i + 1]] - current_slice.append(Node(sl[i], paths)) - except IndexError: - raise IndexError("Expecting two terms: ", sl[0]) # sl[i:i+2]) - - slices.append(Slice(current_slice)) - return Graph.load_from_slices(slices) - - @classmethod - def load_from_slices(cls, slices): - graph = cls([]) - graph.slices = slices - return graph def __repr__(self): """Warning: the representation strings are very sensitive to whitespace""" - return self.slices.__repr__() - - def __getitem__(self, i): - return self.slices[i] + return self.paths.__repr__() def __eq__(self, representation): if isinstance(representation, Graph): - return all(slice_a == slice_b for slice_a, slice_b in zip_longest(self.slices, representation.slices)) - return self == Graph.build(representation) # build a graph then compare it + return all(path_a == path_b for path_a, path_b in zip_longest(self.paths, representation.paths)) + raise TypeError("Graphs can only compare with other Graphs", type(representation)) def load_from_pickle(self, file: str): self = pickle.load(file) @@ -274,19 +235,83 @@ def append_node_to_path(self, node_id, strand, path_name): raise ValueError("Provide the id of the node, not", node_id) self.paths[path_name].append_node(self.nodes[node_id], strand) + def compute_slices(self): + """Alias: Upgrades a Graph to a SlicedGraph""" + return SlicedGraph.from_graph(self) + + +class SlicedGraph(Graph): + def __init__(self, paths): + super(SlicedGraph, self).__init__(paths) + """Factory for generating graphs from a representation""" + self.slices = [] # only get populated by compute_slices() + + if not self.slices: + self.compute_slices() + + def __eq__(self, representation): + if isinstance(representation, SlicedGraph): + return all(slice_a == slice_b for slice_a, slice_b in zip_longest(self.slices, representation.slices)) + return self == SlicedGraph.build(representation) # build a graph then compare it + + def __repr__(self): + """Warning: the representation strings are very sensitive to whitespace""" + return self.slices.__repr__() + + def __getitem__(self, i): + return self.slices[i] + + @staticmethod + def from_graph(graph): + g = SlicedGraph([]) + g.paths = graph.paths # shallow copy all relevant fields + g.nodes = graph.nodes + g.compute_slices() + return g + def compute_slices(self): """TODO: This is a mockup stand in for the real method.""" + if not self.paths: # nothing to do + return self first_path = next(iter(self.paths.values())) for node_traversal in first_path: node = node_traversal.node self.slices.append(Slice([node])) return self + @staticmethod + def build(cmd): + """This factory uses existing slice declarations to build a graph with Paths populated in the order + that they are mentioned in the slices. Currently, this is + only and does not support non-linear + orderings. Use Path.append_node() to build non-linear graphs.""" + if isinstance(cmd, str): + cmd = eval(cmd) + # preemptively grab all the path names from every odd list entry + paths = {key for sl in cmd for i in range(0, len(sl), 2) for key in sl[i + 1]} + graph = SlicedGraph(paths) + for sl in cmd: + current_slice = [] + if isinstance(sl, Slice): + graph.slices.append(sl) + else: + if isinstance(sl[0], Node): # already Nodes, don't need to build + current_slice = sl + else: + try: + for i in range(0, len(sl), 2): + paths = [graph.paths[key] for key in sl[i + 1]] + current_slice.append(Node(sl[i], paths)) + except IndexError: + raise IndexError("Expecting two terms: ", sl[0]) # sl[i:i+2]) -# class SlicedGraph(Graph): -# def __init__(self, paths): -# super(SlicedGraph, self).__init__(paths) + graph.slices.append(Slice(current_slice)) + return graph + @classmethod + def load_from_slices(cls, slices, paths): + graph = cls(paths) + graph.slices = slices + return graph if __name__ == "__main__": diff --git a/src/test.py b/src/test.py index baa8f0e..7c5f820 100644 --- a/src/test.py +++ b/src/test.py @@ -1,7 +1,7 @@ import unittest from src.gfa import GFA from src.graph import Graph, Slice, Node, NoAnchorError, PathOverlapError, NoOverlapError, NodeMissingError, \ - Path + Path, SlicedGraph def G(rep): @@ -34,6 +34,7 @@ class GraphTest(unittest.TestCase): def example_graph(self): # IMPORTANT: Never reuse Paths: Paths must be created fresh for each graph a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e') + paths = [a, b, c, d, e] factory_input = [Slice([Node('ACGT', {a,b,c,d})]), Slice([Node('C',{a,b,d}),Node('T', {c})]), Slice([Node('GGA',{a,b,c,d})]), @@ -48,7 +49,7 @@ def example_graph(self): Slice([Node('TATA', {a, b, c, d})]) # anchor ] - base_graph = Graph.load_from_slices(factory_input) + base_graph = SlicedGraph.load_from_slices(factory_input, paths) return base_graph def test_equalities(self): @@ -57,17 +58,18 @@ def test_equalities(self): self.assertEqual(Node('A', {Path('x'),Path('y')}),Node('A', {Path('x'),Path('y')})) self.assertEqual(Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})]), Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})])) - self.assertEqual(Graph.build([['ACGT', {a, b, c, d}]]), Graph.build([['ACGT', {a, b, c, d}]])) + self.assertEqual(SlicedGraph.build([['ACGT', {a, b, c, d}]]), SlicedGraph.build([['ACGT', {a, b, c, d}]])) def test_graph_factory(self): base_graph = self.example_graph() - g1, g2 = Graph.build(self.factory_input), Graph.build(self.factory_input) + g1, g2 = SlicedGraph.build(self.factory_input), SlicedGraph.build(self.factory_input) assert g1 == g2, \ ('\n' + repr(g1) + '\n' + repr(g2)) - g_double = Graph.build(eval(str(base_graph))) + g_double = SlicedGraph.build(eval(str(base_graph))) # WARN: Never compare two string literals: could be order sensitive, one object must be Graph #str(g_double) == str(base_graph) assert g_double == base_graph, repr(g_double) + '\n' + repr(base_graph) + assert g1 == base_graph, repr(g1) + '\n' + repr(base_graph) assert g_double == self.factory_input assert g_double == str(self.factory_input) @@ -96,8 +98,10 @@ def test_load_gfa_to_graph(self): self.assertEqual(len(graph.nodes), 15) def test_gfa_to_sliced_graph(self): + #TODO: this is currently close but not quite there. + # Slices must be fully defined in SlicedGraph.compute_slices() graph, gfa = self.make_graph_from_gfa() - slices = graph.compute_slices() + slices = SlicedGraph.from_graph(graph) x = 'x' y = 'y' z = 'z'