Skip to content

Commit

Permalink
add average ties for compare method
Browse files Browse the repository at this point in the history
  • Loading branch information
hfr1tz3 committed Dec 22, 2024
1 parent 3956713 commit 22fb7b5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
7 changes: 5 additions & 2 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def naive_node_span(ts):
return node_spans


def naive_compare(ts, other, transform=None):
def naive_compare(ts, other, transform=None, ties="average"):
"""
Inefficient but transparent function to compute dissimilarity
and root-mean-square-error between two tree sequences.
Expand Down Expand Up @@ -125,7 +125,10 @@ def f(t):
best_match_spans = np.zeros((ts.num_nodes,))
time_discrepancies = np.zeros((ts.num_nodes,))
for i, j in enumerate(best_match):
best_match_spans[i] = shared_spans[i, j]
if ties is 'average':
best_match_spans[i] = share_spans[i, j]/np.bincount(best_match)[j]
if ties is None:
best_match_spans[i] = shared_spans[i, j]
time_discrepancies[i] = time_array[i, j]
node_span = naive_node_span(ts)
total_node_spans = np.sum(node_span)
Expand Down
19 changes: 17 additions & 2 deletions tscompare/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def __str__(self):
return out


def compare(ts, other, transform=None):
def compare(ts, other, transform=None, ties="average"):
"""
For two tree sequences `ts` and `other`,
this method returns an object of type :class:`.ARFResult`.
Expand Down Expand Up @@ -379,10 +379,22 @@ def compare(ts, other, transform=None):
root-mean-squared error (see :class:`.ARFResult`); the default
is `log(1 + t)`.
The `ties` argument determines how matches are counted when several
nodes in `ts` have the same best-matching node in `other`.
Current options are `None` and `"average"`.
With `None`, multiple nodes in `ts` may match a single node in `other`,
each contributing its full shared span; the similarity between `ts`
and `other` can then exceed the total node span of `other`.
With `"average"`, for each node in `other` the shared span is averaged
over all nodes in `ts` that match it, so the similarity does not
exceed the total node span of `other`.
The default is `"average"`.
:param ts: The focal tree sequence.
:param other: The tree sequence we compare to.
:param transform: A callable that can take an array of times and
return another array of numbers.
:param ties: Either `None` or the string `"average"`; determines how
nodes in `ts` that match the same node in `other` are counted.
:return: The three quantities above.
:rtype: ARFResult
"""
Expand Down Expand Up @@ -426,7 +438,10 @@ def f(t):
)
# Between each pair of nodes, find the maximum shared span
best_match = best_match_matrix.argmax(axis=1).A1
best_match_spans = shared_spans[np.arange(len(best_match)), best_match].reshape(-1)
if ties == "average"
best_match_spans = (shared_spans[np.arange(len(best_match)), best_match].reshape(-1))/np.bincount(best_match)
if ties is None:
best_match_spans = share_spans[np.arange(len(best_match)), best_match].reshape(-1)
total_match_span = np.sum(best_match_spans)
ts_node_spans = node_spans(ts)
total_span_ts = np.sum(ts_node_spans)
Expand Down

0 comments on commit 22fb7b5

Please sign in to comment.