Skip to content

Commit 69bb8c7

Browse files
hyanwong authored and mergify[bot] committed
Move stabilise_node_ordering into eval function
Fixes #709. I guess we don't need to add this to the changelog as it was never documented in the first place?
1 parent 4446741 commit 69bb8c7

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

tsinfer/eval_util.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,6 @@ def run_perfect_inference(
710710
extended_checks=extended_checks,
711711
progress_monitor=progress_monitor,
712712
)
713-
# If time_chunking is turned on we need to stabilise the node ordering in the output
714-
# to ensure that the node IDs are comparable.
715713
inferred_ts = inference.match_samples(
716714
sample_data,
717715
ancestors_ts,
@@ -720,13 +718,34 @@ def run_perfect_inference(
720718
num_threads=num_threads,
721719
extended_checks=extended_checks,
722720
progress_monitor=progress_monitor,
723-
stabilise_node_ordering=time_chunking and not path_compression,
721+
simplify=False, # Don't simplify until we have stabilised the node order below
724722
)
723+
# If time_chunking is turned on we need to stabilise the node ordering in the output
724+
# to ensure that the node IDs are comparable.
725+
if time_chunking and not path_compression:
726+
inferred_ts = stabilise_node_ordering(inferred_ts)
727+
725728
# to compare against the original, we need to remove unary nodes from the inferred TS
726729
inferred_ts = inferred_ts.simplify(keep_unary=False, filter_sites=False)
727730
return ts, inferred_ts
728731

729732

733+
def stabilise_node_ordering(ts):
    """
    Return a copy of ``ts`` in which nodes sharing an integer time are
    given distinct times, so that they keep stable IDs after simplifying.

    For each integer time ``t`` from 1 up to (but excluding) the integer
    part of node 0's time, the ``k`` nodes whose time equals ``t`` have
    ``(k - 1) / k, ..., 1 / k, 0`` added to their times, in increasing
    node ID order. The tables are then sorted and converted back into a
    tree sequence. This is used for comparing tree sequences produced by
    perfect inference; reversing the IDs within a time slice might be an
    alternative way to achieve the same thing.
    """
    tables = ts.dump_tables()
    node_times = tables.nodes.time
    # NOTE(review): the loop bound is taken from node 0, so this presumably
    # relies on node 0 being the oldest node -- confirm against callers.
    upper_bound = int(node_times[0])
    for time_slice in range(1, upper_bound):
        same_time = np.flatnonzero(node_times == time_slice)
        count = same_time.size
        # Fractional offsets stay below 1, so a shifted node can never
        # collide with the next integer time slice.
        offsets = np.arange(count)[::-1] / count
        node_times[same_time] += offsets
    tables.nodes.time = node_times
    tables.sort()
    return tables.tree_sequence()
747+
748+
730749
def count_sample_child_edges(ts):
731750
"""
732751
Returns an array counting the number of edges where each sample is a child.

tsinfer/inference.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,6 @@ def match_samples(
550550
recombination=None, # See :class:`Matcher`
551551
mismatch=None, # See :class:`Matcher`
552552
precision=None,
553-
stabilise_node_ordering=False,
554553
extended_checks=False,
555554
engine=constants.C_ENGINE,
556555
progress_monitor=None,
@@ -636,9 +635,7 @@ def match_samples(
636635
# we sometimes assume they are in the same order as in the file
637636

638637
manager.match_samples(sample_indexes, sample_times)
639-
ts = manager.finalise(
640-
simplify=simplify, stabilise_node_ordering=stabilise_node_ordering
641-
)
638+
ts = manager.finalise(simplify=simplify)
642639
return ts
643640

644641

@@ -1641,7 +1638,7 @@ def match_samples(self, sample_indexes, sample_times):
16411638
progress_monitor.update()
16421639
progress_monitor.close()
16431640

1644-
def finalise(self, simplify, stabilise_node_ordering):
1641+
def finalise(self, simplify):
16451642
logger.info("Finalising tree sequence")
16461643
ts = self.get_samples_tree_sequence()
16471644
if simplify:
@@ -1650,20 +1647,6 @@ def finalise(self, simplify, stabilise_node_ordering):
16501647
"filter_individuals=False, keep_unary=True) on "
16511648
f"{ts.num_nodes} nodes and {ts.num_edges} edges"
16521649
)
1653-
if stabilise_node_ordering:
1654-
# Ensure all the node times are distinct so that they will have
1655-
# stable IDs after simplifying. This could possibly also be done
1656-
# by reversing the IDs within a time slice. This is used for comparing
1657-
# tree sequences produced by perfect inference.
1658-
tables = ts.dump_tables()
1659-
times = tables.nodes.time
1660-
for t in range(1, int(times[0])):
1661-
index = np.where(times == t)[0]
1662-
k = index.shape[0]
1663-
times[index] += np.arange(k)[::-1] / k
1664-
tables.nodes.time = times
1665-
tables.sort()
1666-
ts = tables.tree_sequence()
16671650
ts = ts.simplify(
16681651
samples=list(self.sample_id_map.values()),
16691652
filter_sites=False,

0 commit comments

Comments
 (0)