From 22fb7b51d1b706a14573ed8d88b54c688620840c Mon Sep 17 00:00:00 2001
From: Halley Fritze <97766437+hfr1tz3@users.noreply.github.com>
Date: Sun, 22 Dec 2024 11:27:34 -0800
Subject: [PATCH] add average ties for compare method

Add a `ties` argument to `compare` and to the test helper
`naive_compare`.  With `ties="average"` (the new default) the span a
node in `ts` shares with its best match in `other` is divided by the
number of nodes in `ts` that chose the same match.  With `ties=None`
the shared span is used as-is, which is the previous behaviour.  Any
other value raises `ValueError`.

---
Review notes: the post-image blob ids on the `index` lines are those of
the original submission and do not reflect the fixes made in review.

 tests/test_methods.py |  9 +++++++--
 tscompare/methods.py  | 24 ++++++++++++++++++++++--
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/tests/test_methods.py b/tests/test_methods.py
index e4654f0..82d27cf 100644
--- a/tests/test_methods.py
+++ b/tests/test_methods.py
@@ -93,7 +93,7 @@ def naive_node_span(ts):
     return node_spans
 
 
-def naive_compare(ts, other, transform=None):
+def naive_compare(ts, other, transform=None, ties="average"):
     """
     Ineffiecient but transparent function to compute dissimilarity
     and root-mean-square-error between two tree sequences.
@@ -125,7 +125,12 @@ def f(t):
     best_match_spans = np.zeros((ts.num_nodes,))
     time_discrepancies = np.zeros((ts.num_nodes,))
     for i, j in enumerate(best_match):
-        best_match_spans[i] = shared_spans[i, j]
+        if ties == "average":
+            best_match_spans[i] = shared_spans[i, j] / np.bincount(best_match)[j]
+        elif ties is None:
+            best_match_spans[i] = shared_spans[i, j]
+        else:
+            raise ValueError(f"unknown ties option: {ties!r}")
         time_discrepancies[i] = time_array[i, j]
     node_span = naive_node_span(ts)
     total_node_spans = np.sum(node_span)
diff --git a/tscompare/methods.py b/tscompare/methods.py
index ee92669..6285e35 100644
--- a/tscompare/methods.py
+++ b/tscompare/methods.py
@@ -335,7 +335,7 @@ def __str__(self):
         return out
 
 
-def compare(ts, other, transform=None):
+def compare(ts, other, transform=None, ties="average"):
     """
     For two tree sequences `ts` and `other`,
     this method returns an object of type :class:`.ARFResult`.
@@ -379,10 +379,23 @@ def compare(ts, other, transform=None):
     root-mean-squared error (see :class:`.ARFResult`); the default
     is `log(1 + t)`.
 
+    The argument `ties` determines how matches are counted when several
+    nodes in `ts` share the same best match in `other`.
+    Current options are `None` and `"average"`.
+    With `None`, multiple nodes in `ts` may each be credited with their
+    full shared span with a single node in `other`, so the similarity
+    between `ts` and `other` can exceed the total node span of `other`.
+    With `"average"`, the shared span of each node in `ts` is divided by
+    the number of nodes in `ts` that share its match, so the similarity
+    does not exceed the total node span of `other`.
+    The default is `"average"`; any other value raises a `ValueError`.
+
     :param ts: The focal tree sequence.
     :param other: The tree sequence we compare to.
     :param transform: A callable that can take an array of times and
         return another array of numbers.
+    :param ties: How to treat nodes in `ts` that share a best match in
+        `other`: either `"average"` or `None`.
     :return: The three quantities above.
     :rtype: ARFResult
     """
@@ -426,7 +439,14 @@ def f(t):
     )
     # Between each pair of nodes, find the maximum shared span
     best_match = best_match_matrix.argmax(axis=1).A1
-    best_match_spans = shared_spans[np.arange(len(best_match)), best_match].reshape(-1)
+    matched_spans = shared_spans[np.arange(len(best_match)), best_match].reshape(-1)
+    if ties == "average":
+        # Divide each span by how many nodes of `ts` chose the same match.
+        best_match_spans = matched_spans / np.bincount(best_match)[best_match]
+    elif ties is None:
+        best_match_spans = matched_spans
+    else:
+        raise ValueError(f"unknown ties option: {ties!r}")
     total_match_span = np.sum(best_match_spans)
     ts_node_spans = node_spans(ts)
     total_span_ts = np.sum(ts_node_spans)