Skip to content

Commit

Permalink
added test cases for 'average' ties
Browse files Browse the repository at this point in the history
  • Loading branch information
hfr1tz3 committed Dec 23, 2024
1 parent 4921a5d commit a8be5b4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
26 changes: 17 additions & 9 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ class TestDissimilarity:

def verify_compare(self, ts, other, transform=None):
match_span, ts_span, other_span, rmse = naive_compare(
ts, other, transform=transform
ts, other, transform=transform, ties=ties,
)
dis = tscompare.compare(ts, other, transform=transform)
dis = tscompare.compare(ts, other, transform=transform, ties=ties)
assert np.isclose(1.0 - match_span / ts_span, dis.arf)
assert np.isclose(match_span / other_span, dis.tpr)
assert np.isclose(ts_span - match_span, dis.dissimilarity)
Expand Down Expand Up @@ -247,11 +247,18 @@ def test_zero_dissimilarity(self, pair):
assert np.isclose(dis.rmse, 0)

def test_transform(self):
dis1 = tscompare.compare(true_simpl, true_simpl, transform=lambda t: t)
dis2 = tscompare.compare(true_simpl, true_simpl, transform=None)
dis1 = tscompare.compare(true_simpl, true_simpl, transform=lambda t: t, ties=None)
dis2 = tscompare.compare(true_simpl, true_simpl, transform=None, ties=None)
assert dis1.dissimilarity == dis2.dissimilarity
assert dis1.rmse == dis2.rmse
self.verify_compare(true_simpl, true_ext, transform=lambda t: 1 / (1 + t))
self.verify_compare(true_simpl, true_ext, transform=lambda t: 1 / (1 + t), ties=None)

def test_ties(self):
dis1 = tscompare.compare(true_simpl, true_ext, transform=None, ties="average")
dis2 = tscompare.compare(true_simpl, true_ext, transform=None, ties=None)
assert dis1.dissimilarity == dis2.dissimilarity
assert dis1.rmse == dis2.rmse
self.verify_compare(true_ext, true_simpl, transform=None, ties="average")

def get_simple_ts(self, samples=None, time=False, span=False, no_match=False):
# A simple tree sequence we can use to properly test various
Expand Down Expand Up @@ -389,19 +396,20 @@ def test_with_no_match(self):
ts = self.get_simple_ts()
other = self.get_simple_ts(span=True, time=True, no_match=True)
self.verify_compare(ts, other)
self.verify_compare(ts, other, transform=lambda t: np.sqrt(1 + t))
self.verify_compare(ts, other, transform=lambda t: np.sqrt(1 + t), ties=None)
self.verify_compare(ts, other, transform=lambda t: np.sqrt(1 + t), ties="average")

def test_dissimilarity_value(self):
ts = self.get_simple_ts()
other = self.get_simple_ts(span=True)
dis = tscompare.compare(ts, other)
dis = tscompare.compare(ts, other, transform=None, ties=None)
assert np.isclose(dis.arf, 4 / 46)
assert np.isclose(dis.rmse, 0.0)

def test_rmse(self):
ts = self.get_simple_ts()
other = self.get_simple_ts(time=True)
dis = tscompare.compare(ts, other)
dis = tscompare.compare(ts, other, transform=None, ties=None)
true_total_span = 46
assert dis.total_span[0] == true_total_span
assert dis.total_span[1] == true_total_span
Expand All @@ -424,7 +432,7 @@ def f(t):
def test_value_and_error(self):
ts = self.get_simple_ts()
other = self.get_simple_ts(span=True, time=True)
dis = tscompare.compare(ts, other)
dis = tscompare.compare(ts, other, transform=None, ties=None)
true_total_spans = (46, 47)
assert dis.total_span == true_total_spans

Expand Down
2 changes: 1 addition & 1 deletion tscompare/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def f(t):
)
# Between each pair of nodes, find the maximum shared span
best_match = best_match_matrix.argmax(axis=1).A1
if ties == "average"
if ties == "average":
best_match_spans = (shared_spans[np.arange(len(best_match)), best_match].reshape(-1))/np.bincount(best_match)
if ties is None:
best_match_spans = share_spans[np.arange(len(best_match)), best_match].reshape(-1)
Expand Down

0 comments on commit a8be5b4

Please sign in to comment.