Skip to content
This repository was archived by the owner on Mar 20, 2020. It is now read-only.

Commit a4d58bd

Browse files
committed
#4 SlicedGraph now its own class, migrated methods. test_export_as_gfa now working with new Graph definition, no need for slices.
1 parent b32614f commit a4d58bd

File tree

3 files changed

+88
-64
lines changed

3 files changed

+88
-64
lines changed

src/gfa.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,13 @@ def save_as_gfa(self, file: str):
106106

107107
@classmethod
108108
def from_graph(cls, graph: Graph):
109+
"""Constructs the lines of a GFA file listing paths, then sequence nodes in arbitrary order."""
109110
gfa = gfapy.Gfa()
110-
path_list = defaultdict(list)
111-
segment_id = 0
112-
for slice in graph.slices:
113-
for node in slice.nodes:
114-
segment_id += 1
115-
gfa.add_line('\t'.join(['S', str(segment_id), node.seq]))
116-
for path in node.paths:
117-
path_list[path].append(segment_id)
118-
for path_key in path_list:
119-
path_values = [str(x) for x in path_list[path_key]]
120-
gfa.add_line('\t'.join(['P', path_key, "+,".join(path_values)+"+", ",".join(['*' for _ in path_values])]))
111+
for path in graph.paths.values():
112+
node_series = ",".join([traverse.node.id + traverse.strand for traverse in path.nodes])
113+
gfa.add_line('\t'.join(['P', path.accession, node_series, ",".join(['*' for _ in path.nodes])]))
114+
for node in graph.nodes.values(): # in no particular order
115+
gfa.add_line('\t'.join(['S', str(node.id), node.seq]))
121116
return cls(gfa)
122117

123118
@property

src/graph.py

Lines changed: 72 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -188,63 +188,24 @@ def __repr__(self):
188188

189189

190190
class Graph:
191-
def __init__(self, paths: List = None):
192-
"""Factory for generating graphs from a representation"""
193-
self.slices = [] # only get populated by compute_slices()
191+
def __init__(self, paths: Iterable = None):
194192
# This can create orphan Nodes with no traversals
195193
self.nodes = keydefaultdict(lambda key: Node(key, [])) # node id = Node object
196194
if all(isinstance(x, str) for x in paths):
197195
self.paths = {x: Path(x) for x in paths}
198196
elif all(isinstance(x, Path) for x in paths):
199-
self.paths = {path.name: path for path in paths}
197+
self.paths = {path.accession: path for path in paths}
200198
else:
201199
self.paths = {}
202-
#TODO: calculate slices?
203-
204-
@staticmethod
205-
def build(cmd):
206-
"""This factory uses existing slice declarations to build a graph with Paths populated in the order
207-
that they are mentioned in the slices. Currently, this is + only and does not support non-linear
208-
orderings. Use Path.append_node() to build non-linear graphs."""
209-
path_dict = keydefaultdict(lambda key: Path(key)) # construct blank path if new
210-
slices = []
211-
if isinstance(cmd, str):
212-
cmd = eval(cmd)
213-
for sl in cmd:
214-
current_slice = []
215-
if isinstance(sl, Slice):
216-
slices.append(sl)
217-
else:
218-
if isinstance(sl[0], Node): # already Nodes, don't need to build
219-
current_slice = sl
220-
else:
221-
try:
222-
for i in range(0, len(sl), 2):
223-
paths = [path_dict[key] for key in sl[i + 1]]
224-
current_slice.append(Node(sl[i], paths))
225-
except IndexError:
226-
raise IndexError("Expecting two terms: ", sl[0]) # sl[i:i+2])
227-
228-
slices.append(Slice(current_slice))
229-
return Graph.load_from_slices(slices)
230-
231-
@classmethod
232-
def load_from_slices(cls, slices):
233-
graph = cls([])
234-
graph.slices = slices
235-
return graph
236200

237201
def __repr__(self):
238202
"""Warning: the representation strings are very sensitive to whitespace"""
239-
return self.slices.__repr__()
240-
241-
def __getitem__(self, i):
242-
return self.slices[i]
203+
return self.paths.__repr__()
243204

244205
def __eq__(self, representation):
245206
if isinstance(representation, Graph):
246-
return all(slice_a == slice_b for slice_a, slice_b in zip_longest(self.slices, representation.slices))
247-
return self == Graph.build(representation) # build a graph then compare it
207+
return all(path_a == path_b for path_a, path_b in zip_longest(self.paths, representation.paths))
208+
raise TypeError("Graphs can only compare with other Graphs", type(representation))
248209

249210
def load_from_pickle(self, file: str):
250211
self = pickle.load(file)
@@ -274,19 +235,83 @@ def append_node_to_path(self, node_id, strand, path_name):
274235
raise ValueError("Provide the id of the node, not", node_id)
275236
self.paths[path_name].append_node(self.nodes[node_id], strand)
276237

238+
def compute_slices(self):
239+
"""Alias: Upgrades a Graph to a SlicedGraph"""
240+
return SlicedGraph.from_graph(self)
241+
242+
243+
class SlicedGraph(Graph):
244+
def __init__(self, paths):
245+
super(SlicedGraph, self).__init__(paths)
246+
"""Factory for generating graphs from a representation"""
247+
self.slices = [] # only get populated by compute_slices()
248+
249+
if not self.slices:
250+
self.compute_slices()
251+
252+
def __eq__(self, representation):
253+
if isinstance(representation, SlicedGraph):
254+
return all(slice_a == slice_b for slice_a, slice_b in zip_longest(self.slices, representation.slices))
255+
return self == SlicedGraph.build(representation) # build a graph then compare it
256+
257+
def __repr__(self):
258+
"""Warning: the representation strings are very sensitive to whitespace"""
259+
return self.slices.__repr__()
260+
261+
def __getitem__(self, i):
262+
return self.slices[i]
263+
264+
@staticmethod
265+
def from_graph(graph):
266+
g = SlicedGraph([])
267+
g.paths = graph.paths # shallow copy all relevant fields
268+
g.nodes = graph.nodes
269+
g.compute_slices()
270+
return g
271+
277272
def compute_slices(self):
278273
"""TODO: This is a mockup stand in for the real method."""
274+
if not self.paths: # nothing to do
275+
return self
279276
first_path = next(iter(self.paths.values()))
280277
for node_traversal in first_path:
281278
node = node_traversal.node
282279
self.slices.append(Slice([node]))
283280
return self
284281

282+
@staticmethod
283+
def build(cmd):
284+
"""This factory uses existing slice declarations to build a graph with Paths populated in the order
285+
that they are mentioned in the slices. Currently, this is + only and does not support non-linear
286+
orderings. Use Path.append_node() to build non-linear graphs."""
287+
if isinstance(cmd, str):
288+
cmd = eval(cmd)
289+
# preemptively grab all the path names from every odd list entry
290+
paths = {key for sl in cmd for i in range(0, len(sl), 2) for key in sl[i + 1]}
291+
graph = SlicedGraph(paths)
292+
for sl in cmd:
293+
current_slice = []
294+
if isinstance(sl, Slice):
295+
graph.slices.append(sl)
296+
else:
297+
if isinstance(sl[0], Node): # already Nodes, don't need to build
298+
current_slice = sl
299+
else:
300+
try:
301+
for i in range(0, len(sl), 2):
302+
paths = [graph.paths[key] for key in sl[i + 1]]
303+
current_slice.append(Node(sl[i], paths))
304+
except IndexError:
305+
raise IndexError("Expecting two terms: ", sl[0]) # sl[i:i+2])
285306

286-
# class SlicedGraph(Graph):
287-
# def __init__(self, paths):
288-
# super(SlicedGraph, self).__init__(paths)
307+
graph.slices.append(Slice(current_slice))
308+
return graph
289309

310+
@classmethod
311+
def load_from_slices(cls, slices, paths):
312+
graph = cls(paths)
313+
graph.slices = slices
314+
return graph
290315

291316

292317
if __name__ == "__main__":

src/test.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
from src.gfa import GFA
33
from src.graph import Graph, Slice, Node, NoAnchorError, PathOverlapError, NoOverlapError, NodeMissingError, \
4-
Path
4+
Path, SlicedGraph
55

66

77
def G(rep):
@@ -34,6 +34,7 @@ class GraphTest(unittest.TestCase):
3434
def example_graph(self):
3535
# IMPORTANT: Never reuse Paths: Paths must be created fresh for each graph
3636
a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e')
37+
paths = [a, b, c, d, e]
3738
factory_input = [Slice([Node('ACGT', {a,b,c,d})]),
3839
Slice([Node('C',{a,b,d}),Node('T', {c})]),
3940
Slice([Node('GGA',{a,b,c,d})]),
@@ -48,7 +49,7 @@ def example_graph(self):
4849
Slice([Node('TATA', {a, b, c, d})]) # anchor
4950
]
5051

51-
base_graph = Graph.load_from_slices(factory_input)
52+
base_graph = SlicedGraph.load_from_slices(factory_input, paths)
5253
return base_graph
5354

5455
def test_equalities(self):
@@ -57,17 +58,18 @@ def test_equalities(self):
5758
self.assertEqual(Node('A', {Path('x'),Path('y')}),Node('A', {Path('x'),Path('y')}))
5859
self.assertEqual(Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})]),
5960
Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})]))
60-
self.assertEqual(Graph.build([['ACGT', {a, b, c, d}]]), Graph.build([['ACGT', {a, b, c, d}]]))
61+
self.assertEqual(SlicedGraph.build([['ACGT', {a, b, c, d}]]), SlicedGraph.build([['ACGT', {a, b, c, d}]]))
6162

6263
def test_graph_factory(self):
6364
base_graph = self.example_graph()
64-
g1, g2 = Graph.build(self.factory_input), Graph.build(self.factory_input)
65+
g1, g2 = SlicedGraph.build(self.factory_input), SlicedGraph.build(self.factory_input)
6566
assert g1 == g2, \
6667
('\n' + repr(g1) + '\n' + repr(g2))
67-
g_double = Graph.build(eval(str(base_graph)))
68+
g_double = SlicedGraph.build(eval(str(base_graph)))
6869
# WARN: Never compare two string literals: could be order sensitive, one object must be Graph
6970
#str(g_double) == str(base_graph)
7071
assert g_double == base_graph, repr(g_double) + '\n' + repr(base_graph)
72+
assert g1 == base_graph, repr(g1) + '\n' + repr(base_graph)
7173
assert g_double == self.factory_input
7274
assert g_double == str(self.factory_input)
7375

@@ -96,8 +98,10 @@ def test_load_gfa_to_graph(self):
9698
self.assertEqual(len(graph.nodes), 15)
9799

98100
def test_gfa_to_sliced_graph(self):
101+
#TODO: this is currently close but not quite there.
102+
# Slices must be fully defined in SlicedGraph.compute_slices()
99103
graph, gfa = self.make_graph_from_gfa()
100-
slices = graph.compute_slices()
104+
slices = SlicedGraph.from_graph(graph)
101105
x = 'x'
102106
y = 'y'
103107
z = 'z'

0 commit comments

Comments
 (0)