diff --git a/src/graph.py b/src/graph.py index 31e22e9..ffda621 100644 --- a/src/graph.py +++ b/src/graph.py @@ -251,9 +251,6 @@ def compute_slices(self): return SlicedGraph.from_graph(self) -from sort import DAGify - - class SlicedGraph(Graph): def __init__(self, paths): super(SlicedGraph, self).__init__(paths) @@ -295,6 +292,8 @@ def compute_slices(self): def compute_slices_by_dagify(self): """This method uses DAGify algorithm to compute slices.""" + from src.sort import DAGify # help avoid circular import + if not self.paths: return self dagify = DAGify(self.paths) diff --git a/src/sort.py b/src/sort.py index 2636e2a..946c42e 100644 --- a/src/sort.py +++ b/src/sort.py @@ -1,6 +1,9 @@ -from src.graph import * - +import sys import dataclasses +from typing import List + +from src.graph import NodeTraversal, Path, Slice, Node, SlicedGraph + @dataclasses.dataclass class Profile: @@ -13,11 +16,12 @@ def __repr__(self): return "["+str(self.node.node) + str(self.paths)+":"+str(self.candidate_paths) +"]" class DAGify: - def __init__(self, paths: List[Path], nodes={}): + def __init__(self, paths: List[Path], nodes=None): """ - :type paths: List[Path] """ + if nodes is None: + nodes = {} self.paths = paths self.nodes = nodes @@ -102,7 +106,7 @@ def lcs(self, s1: List[Profile], s2: Path) -> List[Profile]: return index - def to_slices(self, profile: List[Profile]) -> List[Path]: + def to_slices(self, profile: List[Profile]) -> List[Slice]: factory_input = [] current_slice = Slice([]) current_paths = [] @@ -138,11 +142,12 @@ def to_slices(self, profile: List[Profile]) -> List[Path]: if len(current_slice.nodes) > 0: all_path_set = set([x for x in current_paths]) if profile[-1].candidate_paths - all_path_set != set(): + print(prof) current_slice.add_node(Node("", prof.candidate_paths - all_path_set)) factory_input.append(current_slice) return factory_input - def to_graph(self, profile: List[Profile]): - factory_input = self.to_slices(profile) + def to_graph(self, profiles: List[Profile]): + factory_input = self.to_slices(profiles) base_graph = SlicedGraph.load_from_slices(factory_input, self.paths) return base_graph diff --git a/src/test.py b/src/test.py index 77e958a..ec30b6c 100644 --- a/src/test.py +++ b/src/test.py @@ -219,8 +219,6 @@ def test_load_gfa_to_graph(self): self.assertEqual(len(graph.nodes), 15) def test_gfa_to_sliced_graph(self): - #TODO: this is currently close but not quite there. - # Slices must be fully defined in SlicedGraph.compute_slices() graph, gfa = self.make_graph_from_gfa() slices = SlicedGraph.from_graph(graph) x = 'x'