Skip to content
This repository has been archived by the owner on Mar 20, 2020. It is now read-only.

Commit

Permalink
#4 SlicedGraph now its own class, migrated methods. test_export_as_gf…
Browse files Browse the repository at this point in the history
…a now working with new Graph definition, no need for slices.
  • Loading branch information
josiahseaman committed Jul 10, 2019
1 parent b32614f commit a4d58bd
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 64 deletions.
17 changes: 6 additions & 11 deletions src/gfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,13 @@ def save_as_gfa(self, file: str):

@classmethod
def from_graph(cls, graph: Graph):
    """Constructs the lines of a GFA file listing paths, then sequence nodes in arbitrary order.

    One P line is emitted per entry in graph.paths (accession, the oriented node ids
    in traversal order, and one '*' overlap placeholder per traversal), followed by
    one S line per entry in graph.nodes.  The result is wrapped with cls(gfa).
    """
    gfa = gfapy.Gfa()
    for path in graph.paths.values():
        # str() keeps the concatenation valid when node ids are not strings and
        # matches how the same ids are rendered on the S lines below.
        node_series = ",".join(str(traverse.node.id) + traverse.strand for traverse in path.nodes)
        overlaps = ",".join('*' for _ in path.nodes)
        gfa.add_line('\t'.join(['P', path.accession, node_series, overlaps]))
    for node in graph.nodes.values():  # in no particular order
        gfa.add_line('\t'.join(['S', str(node.id), node.seq]))
    return cls(gfa)

@property
Expand Down
119 changes: 72 additions & 47 deletions src/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,63 +188,24 @@ def __repr__(self):


class Graph:
def __init__(self, paths: Iterable = None):
    """Create a graph from path names or from existing Path objects.

    :param paths: iterable of str (a new Path is created for each name) or of Path
        objects (indexed by their accession).  None or an empty iterable gives a
        graph with no paths; any other mixture also results in no paths.
    """
    # This can create orphan Nodes with no traversals
    self.nodes = keydefaultdict(lambda key: Node(key, []))  # node id = Node object
    # Materialize once: the None default is not iterable, and a generator would
    # already be exhausted by the first all() scan below.
    paths = [] if paths is None else list(paths)
    if all(isinstance(x, str) for x in paths):
        self.paths = {x: Path(x) for x in paths}
    elif all(isinstance(x, Path) for x in paths):
        self.paths = {path.accession: path for path in paths}
    else:
        self.paths = {}

@staticmethod
def build(cmd):
    """Factory that turns existing slice declarations into a graph.  Paths are populated in the
    order they are mentioned in the slices.  Currently + strand only, with no support for
    non-linear orderings; use Path.append_node() to build non-linear graphs.

    Each entry of cmd may be a Slice, a sequence of Nodes, or a flat sequence alternating
    a node sequence with the collection of path keys that visit it.
    """
    # NOTE(review): nesting was reconstructed from a diff view without indentation - confirm.
    # NOTE(review): eval() executes whatever the string contains; only pass trusted input.
    declarations = eval(cmd) if isinstance(cmd, str) else cmd
    known_paths = keydefaultdict(lambda key: Path(key))  # construct blank path if new
    built = []
    for entry in declarations:
        if isinstance(entry, Slice):
            built.append(entry)
            continue
        if isinstance(entry[0], Node):  # already Nodes, don't need to build
            members = entry
        else:
            members = []
            try:
                for offset in range(0, len(entry), 2):
                    owners = [known_paths[key] for key in entry[offset + 1]]
                    members.append(Node(entry[offset], owners))
            except IndexError:
                raise IndexError("Expecting two terms: ", entry[0])
        built.append(Slice(members))
    return Graph.load_from_slices(built)

@classmethod
def load_from_slices(cls, slices):
    """Create an instance with no paths and attach the given slices to it unchanged."""
    instance = cls([])
    instance.slices = slices
    return instance

def __repr__(self):
"""Warning: the representation strings are very sensitive to whitespace"""
return self.slices.__repr__()

def __getitem__(self, i):
return self.slices[i]
return self.paths.__repr__()

def __eq__(self, representation):
    """Two Graphs are equal when their paths match pairwise, in insertion order.

    :raises TypeError: if representation is not a Graph.
    """
    if isinstance(representation, Graph):
        # Iterating a dict yields only its keys, so pair up items() to compare the
        # Path objects as well as their accessions.  zip_longest pads the shorter
        # side with None, which never equals a (key, Path) pair.
        return all(entry_a == entry_b for entry_a, entry_b in
                   zip_longest(self.paths.items(), representation.paths.items()))
    raise TypeError("Graphs can only compare with other Graphs", type(representation))

def load_from_pickle(self, file: str):
self = pickle.load(file)
Expand Down Expand Up @@ -274,19 +235,83 @@ def append_node_to_path(self, node_id, strand, path_name):
raise ValueError("Provide the id of the node, not", node_id)
self.paths[path_name].append_node(self.nodes[node_id], strand)

def compute_slices(self):
    """Alias for SlicedGraph.from_graph(self): returns a SlicedGraph made from this Graph."""
    sliced = SlicedGraph.from_graph(self)
    return sliced


class SlicedGraph(Graph):
    """A Graph that also keeps an ordered list of Slices.

    Equality, repr and indexing all operate on self.slices.
    """

    def __init__(self, paths):
        """Initialise the underlying Graph from paths, then compute the slices.

        :param paths: forwarded unchanged to Graph.__init__.
        """
        super().__init__(paths)
        self.slices = []  # filled by compute_slices(), build() or load_from_slices()
        self.compute_slices()

    def __eq__(self, representation):
        """Compare slice by slice; anything that is not a SlicedGraph is first passed to build()."""
        if isinstance(representation, SlicedGraph):
            return all(slice_a == slice_b for slice_a, slice_b in zip_longest(self.slices, representation.slices))
        return self == SlicedGraph.build(representation)  # build a graph then compare it

    def __repr__(self):
        """Warning: the representation strings are very sensitive to whitespace"""
        return self.slices.__repr__()

    def __getitem__(self, i):
        """Return the i-th slice."""
        return self.slices[i]

    @staticmethod
    def from_graph(graph):
        """Upgrade a Graph to a SlicedGraph that shares its paths and nodes."""
        g = SlicedGraph([])
        g.paths = graph.paths  # shallow copy all relevant fields
        g.nodes = graph.nodes
        g.compute_slices()
        return g

    def compute_slices(self):
        """TODO: This is a mockup stand in for the real method.

        Makes one single-node Slice per traversal of the first path and returns self.
        With no paths the existing slices are left untouched.
        """
        if not self.paths:  # nothing to do
            return self
        first_path = next(iter(self.paths.values()))
        # Rebuild instead of appending so that repeated calls do not duplicate slices.
        self.slices = [Slice([node_traversal.node]) for node_traversal in first_path]
        return self

    @staticmethod
    def _is_raw_declaration(sl):
        """True for a flat [sequence, path keys, ...] entry; False for a Slice or a list of Nodes."""
        if isinstance(sl, Slice):
            return False
        return not (sl and isinstance(sl[0], Node))

    @staticmethod
    def build(cmd):
        """This factory uses existing slice declarations to build a graph with Paths populated in the order
        that they are mentioned in the slices. Currently, this is + only and does not support non-linear
        orderings. Use Path.append_node() to build non-linear graphs.

        Each entry of cmd may be a Slice, a sequence of Nodes, or a flat sequence alternating
        a node sequence with the collection of path keys that visit it.

        :raises IndexError: when a flat entry does not come in (sequence, path keys) pairs.
        """
        if isinstance(cmd, str):
            # NOTE(review): eval() executes whatever the string contains; only pass trusted input.
            cmd = eval(cmd)
        cmd = list(cmd)  # iterated twice below, so a one-shot iterable must be materialized

        # Preemptively grab all the path names from every odd list entry.  Only raw
        # declarations carry path keys at odd positions; Slices and Node lists are skipped.
        # NOTE(review): paths referenced only by prebuilt Slices/Nodes are therefore not
        # registered in graph.paths - confirm whether callers need that.
        path_names = set()
        for sl in cmd:
            if SlicedGraph._is_raw_declaration(sl):
                for i in range(1, len(sl), 2):
                    path_names.update(sl[i])

        graph = SlicedGraph(path_names)
        for sl in cmd:
            if isinstance(sl, Slice):
                graph.slices.append(sl)
                continue
            if not SlicedGraph._is_raw_declaration(sl):  # already Nodes, don't need to build
                current_slice = sl
            else:
                current_slice = []
                try:
                    for i in range(0, len(sl), 2):
                        node_paths = [graph.paths[key] for key in sl[i + 1]]
                        current_slice.append(Node(sl[i], node_paths))
                except IndexError as err:
                    raise IndexError("Expecting two terms: ", sl[0]) from err
            graph.slices.append(Slice(current_slice))
        return graph

    @classmethod
    def load_from_slices(cls, slices, paths):
        """Create an instance from paths, then replace its computed slices with the given ones."""
        graph = cls(paths)
        graph.slices = slices
        return graph


if __name__ == "__main__":
Expand Down
16 changes: 10 additions & 6 deletions src/test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from src.gfa import GFA
from src.graph import Graph, Slice, Node, NoAnchorError, PathOverlapError, NoOverlapError, NodeMissingError, \
Path
Path, SlicedGraph


def G(rep):
Expand Down Expand Up @@ -34,6 +34,7 @@ class GraphTest(unittest.TestCase):
def example_graph(self):
# IMPORTANT: Never reuse Paths: Paths must be created fresh for each graph
a, b, c, d, e = Path('a'), Path('b'), Path('c'), Path('d'), Path('e')
paths = [a, b, c, d, e]
factory_input = [Slice([Node('ACGT', {a,b,c,d})]),
Slice([Node('C',{a,b,d}),Node('T', {c})]),
Slice([Node('GGA',{a,b,c,d})]),
Expand All @@ -48,7 +49,7 @@ def example_graph(self):
Slice([Node('TATA', {a, b, c, d})]) # anchor
]

base_graph = Graph.load_from_slices(factory_input)
base_graph = SlicedGraph.load_from_slices(factory_input, paths)
return base_graph

def test_equalities(self):
Expand All @@ -57,17 +58,18 @@ def test_equalities(self):
self.assertEqual(Node('A', {Path('x'),Path('y')}),Node('A', {Path('x'),Path('y')}))
self.assertEqual(Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})]),
Slice([Node('ACGT', {Path('a'), Path('b'), Path('c'), Path('d')})]))
self.assertEqual(Graph.build([['ACGT', {a, b, c, d}]]), Graph.build([['ACGT', {a, b, c, d}]]))
self.assertEqual(SlicedGraph.build([['ACGT', {a, b, c, d}]]), SlicedGraph.build([['ACGT', {a, b, c, d}]]))

def test_graph_factory(self):
base_graph = self.example_graph()
g1, g2 = Graph.build(self.factory_input), Graph.build(self.factory_input)
g1, g2 = SlicedGraph.build(self.factory_input), SlicedGraph.build(self.factory_input)
assert g1 == g2, \
('\n' + repr(g1) + '\n' + repr(g2))
g_double = Graph.build(eval(str(base_graph)))
g_double = SlicedGraph.build(eval(str(base_graph)))
# WARN: Never compare two string literals: could be order sensitive, one object must be Graph
#str(g_double) == str(base_graph)
assert g_double == base_graph, repr(g_double) + '\n' + repr(base_graph)
assert g1 == base_graph, repr(g1) + '\n' + repr(base_graph)
assert g_double == self.factory_input
assert g_double == str(self.factory_input)

Expand Down Expand Up @@ -96,8 +98,10 @@ def test_load_gfa_to_graph(self):
self.assertEqual(len(graph.nodes), 15)

def test_gfa_to_sliced_graph(self):
#TODO: this is currently close but not quite there.
# Slices must be fully defined in SlicedGraph.compute_slices()
graph, gfa = self.make_graph_from_gfa()
slices = graph.compute_slices()
slices = SlicedGraph.from_graph(graph)
x = 'x'
y = 'y'
z = 'z'
Expand Down

0 comments on commit a4d58bd

Please sign in to comment.