Skip to content
This repository has been archived by the owner on Mar 20, 2020. It is now read-only.

Commit

Permalink
Merge pull request #9 from graph-genome/dagify
Browse files Browse the repository at this point in the history
#4: Add DAGify method for linearizing the order of nodes and creating SlicedGraph using recursive longest common strings.
  • Loading branch information
josiahseaman authored Jul 19, 2019
2 parents 51bb237 + 0dc0528 commit 8ea35d9
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 64 deletions.
74 changes: 24 additions & 50 deletions src/gfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,35 @@ def save_as_gfa(self, file: str):
def from_graph(cls, graph: Graph):
    """Constructs the lines of a GFA file listing paths, then sequence nodes in arbitrary order.

    :param graph: object exposing ``paths`` (iterated directly; each item
        provides ``accession`` and ``nodes``) and ``nodes`` (a mapping whose
        values provide ``id`` and ``seq``).
    :return: ``cls`` wrapped around the populated ``gfapy.Gfa`` object.
    """
    gfa = gfapy.Gfa()
    # graph.paths is iterated directly; the earlier ``graph.paths.values()``
    # loop header was superseded and must not be kept alongside this one.
    for path in graph.paths:
        # str() keeps the P lines consistent with the S lines below, which
        # already stringify node.id.
        node_series = ",".join(str(traverse.node.id) + traverse.strand for traverse in path.nodes)
        overlaps = ",".join('*' for _ in path.nodes)  # one '*' placeholder per step
        gfa.add_line('\t'.join(['P', path.accession, node_series, overlaps]))
    for node in graph.nodes.values():  # in no particular order
        gfa.add_line('\t'.join(['S', str(node.id), node.seq]))
    return cls(gfa)

@property
def to_paths(self) -> List[Path]:
    """Build one Path per GFA path line.

    Every GFA segment is first registered under both orientations
    ("name+" and "name-"). Each step of each GFA path is then turned into a
    NodeTraversal wrapping a new Node that carries the segment sequence and
    uses the segment name as its id.

    :return: list of Path objects, in the order of ``self.gfa.paths``.
    """
    oriented_nodes = {}
    for segment in self.gfa.segments:
        # Same sequence is stored for the forward and the reverse key.
        for orient in ("+", "-"):
            oriented_nodes[segment.name + orient] = Node(segment.sequence, [])

    result = []
    for gfa_path in self.gfa.paths:
        traversals = []
        for step in gfa_path.segment_names:
            sequence = oriented_nodes[step.name + step.orient].seq
            traversals.append(NodeTraversal(Node(sequence, [], step.name), step.orient))
        result.append(Path(gfa_path.name, traversals))

    return result

@property
def to_graph(self):
# Extract all paths into graph
Expand All @@ -125,58 +147,10 @@ def to_graph(self):
graph.append_node_to_path(node.name, node.orient, path.name)
for segment in self.gfa.segments:
graph.nodes[segment.name].seq = segment.sequence
graph.paths = self.to_paths
return graph
# IMPORTANT: It's not clear to Josiah how much of the below is necessary, so it's being left unmodified.

topological_sort_helper = TopologicalSort()
path_dict = defaultdict(list)
node_hash = {}

# Extract all paths into graph
for path in self.gfa.paths:
for node in path.segment_names:
path_dict[node.name + node.orient].append(path.name)
for node_pair in pairwise(path.segment_names):
topological_sort_helper.add_edge(
node_pair[0].name + node_pair[0].orient,
node_pair[1].name + node_pair[1].orient)

# Extract all nodes in the graph.
for segment in self.gfa.segments:
node_id = segment.name + "+"
node = Node(segment.sequence, path_dict[node_id])
node_hash[node_id] = node

node_id = segment.name + "-"
node = Node(segment.sequence, path_dict[node_id])
node_hash[node_id] = node

node_stack = topological_sort_helper.topologicalSort()

# Cluster nodes as multiple slices according to the result of the topological sort.
factory_input = []
current_slice = Slice([])
for node in node_stack:
if len(path_dict[node]) == len(self.gfa.paths):
if len(current_slice.nodes) > 0:
factory_input.append(current_slice)
factory_input.append(Slice([node_hash[node]]))
current_slice = Slice([])
else:
all_set = set()
for items in [x.paths for x in current_slice.nodes]:
all_set = all_set | items
if set(path_dict[node]) & all_set != set():
if len(current_slice.nodes) > 0:
current_slice.add_node(Node("", set([x.name for x in self.gfa.paths]) - all_set))
factory_input.append(current_slice)
current_slice = Slice([node_hash[node]])
else:
current_slice.add_node(node_hash[node])

base_graph = Graph.load_from_slices(factory_input)
return base_graph


'''
class XGWrapper:
Expand Down
36 changes: 30 additions & 6 deletions src/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ class Path:
was sequenced. A path visits a series of nodes and the ordered concatenation of the node
sequences is the accession's genome. Create Paths first from accession names, then append
them to Nodes to link together."""
def __init__(self, accession: str, nodes=None):
    """Create a path for one accession.

    :param accession: name of the accession this path represents.
    :param nodes: optional list of NodeTraversal to start from. The list is
        stored as given (not copied). When omitted, each Path gets its own
        new empty list; a ``[]`` default would be shared by every Path
        created without the argument.
    """
    self.accession = accession  # one path per accessions
    self.nodes = [] if nodes is None else nodes  # List[NodeTraversal]
    self.position_checkpoints = {}  # TODO: currently not used

def __getitem__(self, path_index):
Expand All @@ -150,6 +150,9 @@ def append_node(self, node: Node, strand: str):
node.paths.add(PathIndex(self, len(self.nodes)-1)) # already appended node
return node

def name(self):
    """Return the accession string, which doubles as this path's name."""
    accession = self.accession
    return accession

def to_gfa(self):
    """Return this path as a tab-separated GFA 'P' line.

    Fields: 'P', the accession, the comma-separated steps (node name followed
    by its strand), and one '*' overlap placeholder per step.

    Each step already carries its own strand, so steps are joined with a
    plain comma; joining with "+," and appending "+" wrote the orientation
    twice (e.g. "1++,2++").
    """
    # NOTE(review): this reads ``node.name`` while other code in this change
    # reads ``node.id`` -- confirm Node exposes ``name``.
    steps = ",".join(x.node.name + x.strand for x in self.nodes)
    overlaps = ",".join('*' for _ in self.nodes)
    return '\t'.join(['P', self.accession, steps, overlaps])

Expand All @@ -165,7 +168,7 @@ def __repr__(self):
return repr(self.path.accession)

def __eq__(self, other):
    """Two entries are equal when they refer to paths with the same accession.

    ``index`` is not part of the comparison, matching the accession-only
    hash of this class.
    """
    return self.path.accession == other.path.accession
Expand All @@ -174,7 +177,7 @@ def __lt__(self, other):
return self.path.accession < other.path.accession

def __hash__(self):
    """Hash on the path accession only, so entries that compare equal
    (same accession, any index) also hash equal."""
    return hash(self.path.accession)


class NodeTraversal:
Expand All @@ -184,7 +187,15 @@ def __init__(self, node: Node, strand: str = '+'):
self.strand = strand # TODO: make this required

def __repr__(self):
    """Return the node sequence as read along this traversal.

    On the '+' strand the node's sequence is returned unchanged. On any other
    strand the sequence is reversed and A/T and C/G are swapped; characters
    outside ACGT are kept as they are.
    """
    if self.strand == '+':
        return self.node.seq
    # One translate pass replaces the per-base dict lookup.
    complement = str.maketrans('ACGT', 'TGCA')
    return self.node.seq[::-1].translate(complement)


def __eq__(self, other):
    """Traversals are equal when they visit the same node id on the same strand."""
    if self.node.id != other.node.id:
        return False
    return self.strand == other.strand


class Graph:
Expand Down Expand Up @@ -266,7 +277,7 @@ def from_graph(graph):
g = SlicedGraph([])
g.paths = graph.paths # shallow copy all relevant fields
g.nodes = graph.nodes
g.compute_slices()
g.compute_slices_by_dagify()
return g

def compute_slices(self):
Expand All @@ -279,6 +290,18 @@ def compute_slices(self):
self.slices.append(Slice([node]))
return self

def compute_slices_by_dagify(self):
    """Populate ``self.slices`` using the DAGify algorithm.

    The first path (index 0) is used as the primary path for the merge.
    When there are no paths, ``self.slices`` is left untouched.

    :return: ``self``, so calls can be chained.
    """
    from src.sort import DAGify  # help avoid circular import

    if self.paths:
        sorter = DAGify(self.paths)
        self.slices = sorter.to_slices(sorter.recursive_merge(0))
    return self

@staticmethod
def build(cmd):
"""This factory uses existing slice declarations to build a graph with Paths populated in the order
Expand All @@ -289,6 +312,7 @@ def build(cmd):
# preemptively grab all the path names from every odd list entry
paths = {key for sl in cmd for i in range(0, len(sl), 2) for key in sl[i + 1]}
graph = SlicedGraph(paths)
graph.slices = []
for sl in cmd:
current_slice = []
if isinstance(sl, Slice):
Expand Down
153 changes: 153 additions & 0 deletions src/sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import sys
import dataclasses
from typing import List

from src.graph import NodeTraversal, Path, Slice, Node, SlicedGraph


@dataclasses.dataclass
class Profile:
    """One position in the merged node ordering produced by DAGify."""
    # Traversal placed at this position of the ordering.
    node: NodeTraversal
    # Paths recorded as visiting this node; DAGify.lcs appends a path here
    # when its node matches the one already in the profile.
    paths: List[Path]
    # Set of Path objects accumulated by DAGify.lcs; DAGify.to_slices compares
    # it with the paths already in a slice to decide whether an empty filler
    # node is needed. Annotated with the type ``set`` (was the instance ``set()``).
    candidate_paths: set
    # True when DAGify.lcs had already seen this node id during its traceback.
    duplicate: bool = False

    def __repr__(self):
        return f"[{self.node.node!s}{self.paths!s}:{self.candidate_paths!s}]"

class DAGify:
    """Linearize the node order of a set of paths and cut it into slices.

    One path is taken as the primary ordering and every other path is merged
    into it with a longest-common-subsequence alignment (``lcs``). The merged
    list of Profile entries can then be converted into Slice objects
    (``to_slices``) or a SlicedGraph (``to_graph``).
    """

    def __init__(self, paths: List[Path], nodes=None):
        """
        :type paths: List[Path]
        :param nodes: optional object stored as ``self.nodes``; a new dict is
            created when omitted. No method of this class reads it.
        """
        if nodes is None:
            nodes = {}
        self.paths = paths
        self.nodes = nodes

    def search_for_minimizing_replications(self) -> (List[Profile], int):
        """Try every path as the primary path and keep the best merge.

        :return: ``(profile, count)`` where ``profile`` is the merge with the
            fewest entries flagged ``duplicate`` and ``count`` is that number.
            With no paths this is ``([], sys.maxsize)``.
        """
        min_rep = sys.maxsize
        profile = []
        for i, _ in enumerate(self.paths):
            profile_candidate = self.recursive_merge(i)
            replications = sum(1 for x in profile_candidate if x.duplicate)
            # Strict comparison: on a tie the earliest primary path wins.
            if min_rep > replications:
                min_rep = replications
                profile = profile_candidate
        return profile, min_rep

    def recursive_merge(self, primary_path_index: int = 0) -> List[Profile]:
        """Merge all paths into the ordering of one primary path.

        :param primary_path_index: index into ``self.paths`` of the path whose
            node order seeds the profile.
        :return: merged list of Profile entries; ``[]`` when there are no paths.
        :raises IndexError: if paths exist but the index is out of range.
        """
        if not self.paths:
            return []
        primary = self.paths[primary_path_index]
        # Every entry gets its own list and set because lcs() mutates them.
        profile = [Profile(node_index, [primary], {primary}, False)
                   for node_index in primary.nodes]
        for i, path in enumerate(self.paths):
            if i == primary_path_index:
                continue
            profile = self.lcs(profile, path)
        return profile

    def lcs(self, s1: List[Profile], s2: Path) -> List[Profile]:
        """Align path ``s2`` against profile ``s1`` and return the merged profile.

        A longest-common-subsequence table is filled over the traversals of
        both inputs, then walked backwards. Matching traversals are merged
        into one entry; unmatched traversals of either side are kept as
        their own entries.

        Note: the ``paths`` lists and ``candidate_paths`` sets of the entries
        in ``s1`` are modified in place and shared with the returned entries,
        so ``s1`` should not be reused afterwards.
        """
        n, m = len(s1), len(s2.nodes)
        dp = [[0] * (m + 1) for _ in range(n + 1)]

        for i in range(1, n + 1):
            for j in range(1, m + 1):
                if s1[i - 1].node == s2.nodes[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

        i, j = n, m
        index = []  # built back-to-front, reversed before returning
        prev = set()  # node ids already emitted during the traceback
        candidate_path_flag = False  # becomes True after the first match

        while i > 0 and j > 0:
            if s1[i - 1].node == s2.nodes[j - 1]:
                # Shared node: s2 joins the paths visiting it.
                prev_paths = s1[i - 1].paths
                prev_paths.append(s2)
                candidate_paths = s1[i - 1].candidate_paths
                candidate_paths.add(s2)
                candidate_path_flag = True

                index.append(Profile(s1[i - 1].node, prev_paths, candidate_paths,
                                     s1[i - 1].node.node.id in prev))
                prev.add(s1[i - 1].node.node.id)
                i -= 1
                j -= 1
            elif dp[i - 1][j] > dp[i][j - 1]:
                # Node only in the existing profile.
                prev_paths = s1[i - 1].paths
                candidate_paths = s1[i - 1].candidate_paths
                if candidate_path_flag:
                    candidate_paths.add(s2)
                index.append(Profile(s1[i - 1].node, prev_paths, candidate_paths,
                                     s1[i - 1].node.node.id in prev))
                prev.add(s1[i - 1].node.node.id)
                i -= 1
            else:
                # Node only in s2.
                candidate_paths = {s2}
                # NOTE(review): ``i > n`` can never hold here because i starts
                # at n and only decreases, so this branch is dead. Left as
                # written; confirm whether ``i < n`` was intended.
                if i > n and s1[i]:
                    candidate_paths |= s1[i].candidate_paths
                if s1[i - 1]:
                    candidate_paths |= s1[i - 1].candidate_paths
                index.append(Profile(s2.nodes[j - 1], [s2], candidate_paths,
                                     s2.nodes[j - 1].node.id in prev))
                prev.add(s2.nodes[j - 1].node.id)
                j -= 1

        # Leftover head of the existing profile.
        while i > 0:
            prev_paths = s1[i - 1].paths
            prev_candidates = s1[i - 1].candidate_paths
            index.append(Profile(s1[i - 1].node, prev_paths, prev_candidates,
                                 s1[i - 1].node.node.id in prev))
            prev.add(s1[i - 1].node.node.id)
            i -= 1

        # Leftover head of s2; these entries are never flagged as duplicates.
        while j > 0:
            prev.add(s2.nodes[j - 1].node.id)
            index.append(Profile(s2.nodes[j - 1], [s2], {s2}, False))
            j -= 1

        index.reverse()

        return index

    @staticmethod
    def _add_gap_node(current_slice, candidate_paths, covered_paths):
        """Add an empty-sequence node for candidate paths not yet in the slice."""
        missing = candidate_paths - covered_paths
        if missing:
            current_slice.add_node(Node("", missing))

    def to_slices(self, profile: List[Profile]) -> List[Slice]:
        """Group a merged profile into a list of Slice objects.

        An entry whose paths are as numerous as its candidate paths closes
        the slice in progress and gets a slice of its own. Other entries
        join the slice in progress unless one of their paths is already in
        it, in which case a new slice is started. Before a slice is closed,
        an empty-sequence node is added for candidate paths missing from it.

        Note: ``candidate_paths`` of each entry is extended in place with the
        candidate paths of the following entry.
        """
        factory_input = []
        current_slice = Slice([])
        current_paths = []

        for index, prof in enumerate(profile):
            paths = list(prof.paths)
            all_path_set = set(current_paths)
            # In-place union: prof.candidate_paths itself grows, and that
            # grown set is what the gap-node computation below sees.
            candidate_paths_set = prof.candidate_paths
            if index + 1 != len(profile):
                candidate_paths_set |= profile[index + 1].candidate_paths

            if len(prof.paths) == len(candidate_paths_set):
                if len(current_slice.nodes) > 0:
                    self._add_gap_node(current_slice, prof.candidate_paths, all_path_set)
                    factory_input.append(current_slice)
                factory_input.append(Slice([Node(prof.node.node.seq, paths, prof.node.node.id)]))
                current_slice = Slice([])
                current_paths = []
            elif not all_path_set.isdisjoint(prof.paths):
                # A path would appear twice in the slice: start a new one.
                if len(current_slice.nodes) > 0:
                    self._add_gap_node(current_slice, prof.candidate_paths, all_path_set)
                    factory_input.append(current_slice)
                current_slice = Slice([Node(prof.node.node.seq, paths, prof.node.node.id)])
                # Copy so later extend() calls cannot alter the list given to Node.
                current_paths = list(paths)
            else:
                current_slice.add_node(Node(prof.node.node.seq, paths, prof.node.node.id))
                current_paths.extend(paths)

        if len(current_slice.nodes) > 0:
            self._add_gap_node(current_slice, profile[-1].candidate_paths, set(current_paths))
            factory_input.append(current_slice)
        return factory_input

    def to_graph(self, profiles: List[Profile]):
        """Convert a merged profile into a SlicedGraph over ``self.paths``."""
        factory_input = self.to_slices(profiles)
        base_graph = SlicedGraph.load_from_slices(factory_input, self.paths)
        return base_graph
Loading

0 comments on commit 8ea35d9

Please sign in to comment.