Skip to content

Commit

Permalink
Fixup pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Dec 16, 2024
1 parent 90c1768 commit db15cfc
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 57 deletions.
6 changes: 6 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[flake8]
# Based directly on Black's recommendations:
# https://black.readthedocs.io/en/stable/the_black_code_style.html#line-length
max-line-length = 81
select = A,C,E,F,W,B,B950
ignore = E203, E501, W503
29 changes: 29 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-merge-conflict
- id: debug-statements
- id: mixed-line-ending
- id: check-case-conflict
- id: check-yaml
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.14.0
hooks:
- id: reorder-python-imports
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
hooks:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 7.1.1
hooks:
- id: flake8
args: [--config=.flake8]
additional_dependencies: ["flake8-bugbear==24.12.12", "flake8-builtins==2.5.0"]
55 changes: 35 additions & 20 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"""
Test tools for mapping between node sets of different tree sequences
"""

from collections import defaultdict
from itertools import combinations

Expand Down Expand Up @@ -99,10 +98,13 @@ def naive_compare(ts, other, transform=None):
Inefficient but transparent function to compute dissimilarity
and root-mean-square-error between two tree sequences.
"""

def f(t):
return np.log(1 + t)
if transform is not None:
f = transform

if transform is None:
transform = f

shared_spans = naive_shared_node_spans(ts, other).toarray()
max_span = np.max(shared_spans, axis=1)
assert len(max_span) == ts.num_nodes
Expand All @@ -115,7 +117,9 @@ def f(t):
else:
for j in range(other.num_nodes):
if shared_spans[i, j] == max_span[i]:
time_array[i, j] = np.abs(f(ts.nodes_time[i]) - f(other.nodes_time[j]))
time_array[i, j] = np.abs(
transform(ts.nodes_time[i]) - transform(other.nodes_time[j])
)
dissimilarity_matrix[i, j] = 1 / (1 + time_array[i, j])
best_match = np.argmax(dissimilarity_matrix, axis=1)
best_match_spans = np.zeros((ts.num_nodes,))
Expand Down Expand Up @@ -180,9 +184,7 @@ def test_node_spans(self, ts):
naive_ns = naive_node_span(ts)
assert np.allclose(eval_ns, naive_ns)

@pytest.mark.parametrize(
"pair", combinations([true_simpl, true_unary], 2)
)
@pytest.mark.parametrize("pair", combinations([true_simpl, true_unary], 2))
def test_shared_spans(self, pair):
"""
Check that efficient implementation returns same answer as naive
Expand All @@ -205,13 +207,16 @@ def test_match_self(self, ts):
assert np.allclose(time, ts.nodes_time)
assert np.array_equal(hit, np.arange(ts.num_nodes))


class TestDissimilarity:

def verify_compare(self, ts, other, transform=None):
match_span, ts_span, other_span, rmse = naive_compare(ts, other, transform=transform)
match_span, ts_span, other_span, rmse = naive_compare(
ts, other, transform=transform
)
dis = tscompare.compare(ts, other, transform=transform)
assert np.isclose(1.0 - match_span/ts_span, dis.arf)
assert np.isclose(match_span/other_span, dis.tpr)
assert np.isclose(1.0 - match_span / ts_span, dis.arf)
assert np.isclose(match_span / other_span, dis.tpr)
assert np.isclose(ts_span - match_span, dis.dissimilarity)
assert np.isclose(ts_span, dis.total_span[0])
assert np.isclose(other_span, dis.total_span[1])
Expand All @@ -235,15 +240,15 @@ def test_basic_comparison(self, pair):
def test_zero_dissimilarity(self, pair):
dis = tscompare.compare(pair[0], pair[1])
assert np.isclose(dis.dissimilarity, 0)
assert np.isclose(dis.arf, 0)
assert np.isclose(dis.arf, 0)
assert np.isclose(dis.rmse, 0)

def test_transform(self):
dis1 = tscompare.compare(true_simpl, true_simpl, transform=lambda t: t)
dis2 = tscompare.compare(true_simpl, true_simpl, transform=None)
assert dis1.dissimilarity == dis2.dissimilarity
assert dis1.rmse == dis2.rmse
self.verify_compare(true_simpl, true_ext, transform=lambda t: 1/(1 + t))
self.verify_compare(true_simpl, true_ext, transform=lambda t: 1 / (1 + t))

def get_simple_ts(self, samples=None, time=False, span=False, no_match=False):
# A simple tree sequence we can use to properly test various
Expand Down Expand Up @@ -397,12 +402,17 @@ def test_rmse(self):
true_total_span = 46
assert dis.total_span[0] == true_total_span
assert dis.total_span[1] == true_total_span

def f(t):
return np.log(1 + t)
true_rmse = np.sqrt((
2 * 6 * (f(500) - f(200))**2 # nodes 4, 5
+ 2 * 2 * (f(750) - f(600))**2 # nodes, 7, 8
) / true_total_span)

true_rmse = np.sqrt(
(
2 * 6 * (f(500) - f(200)) ** 2 # nodes 4, 5
+ 2 * 2 * (f(750) - f(600)) ** 2 # nodes, 7, 8
)
/ true_total_span
)
assert np.isclose(dis.arf, 0.0)
assert np.isclose(dis.tpr, 1.0)
assert np.isclose(dis.dissimilarity, 0.0)
Expand All @@ -414,12 +424,17 @@ def test_value_and_error(self):
dis = tscompare.compare(ts, other)
true_total_spans = (46, 47)
assert dis.total_span == true_total_spans

def f(t):
return np.log(1 + t)
true_rmse = np.sqrt((
2 * 6 * (f(500) - f(200))**2 # nodes 4, 5
+ 2 * 2 * (f(750) - f(600))**2 # nodes, 7, 8
) / true_total_spans[0])

true_rmse = np.sqrt(
(
2 * 6 * (f(500) - f(200)) ** 2 # nodes 4, 5
+ 2 * 2 * (f(750) - f(600)) ** 2 # nodes, 7, 8
)
/ true_total_spans[0]
)
assert np.isclose(dis.arf, 4 / true_total_spans[0])
assert np.isclose(dis.tpr, (true_total_spans[0] - 4) / true_total_spans[1])
assert np.isclose(dis.dissimilarity, 4)
Expand Down
9 changes: 7 additions & 2 deletions tscompare/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,10 @@
"""
Tools for comparing tree sequences
"""
from .methods import compare, node_spans, CladeMap, shared_node_spans, match_node_ages, ARFResult
from .provenance import __version__
from .methods import ARFResult # noqa F401
from .methods import CladeMap # noqa F401
from .methods import compare # noqa F401
from .methods import match_node_ages # noqa F401
from .methods import node_spans # noqa F401
from .methods import shared_node_spans # noqa F401
from .provenance import __version__ # noqa F401
70 changes: 36 additions & 34 deletions tscompare/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,16 @@
"""
Tools for comparing node times between tree sequences with different node sets
"""

from dataclasses import dataclass
import copy
from collections import defaultdict
from itertools import groupby, product
from dataclasses import dataclass
from itertools import product

import copy
import numpy as np
import scipy.sparse

import tskit


def node_spans(ts):
"""
Returns the array of "node spans", i.e., the `j`th entry gives
Expand Down Expand Up @@ -97,7 +96,7 @@ def _propagate(self, edge, downdate=False):
node = self.tree.parent(node)
return nodes

def next(self):
def next(self): # noqa: A003
"""
Advance to the next tree, returning the difference between trees as a
dictionary of the form `node : (last_clade, next_clade)`
Expand Down Expand Up @@ -254,18 +253,18 @@ def shared_node_spans(ts, other):

def match_node_ages(ts, other):
"""
For each node in `ts`, return the age of a matched node from `other`. Node
matching is accomplished as described in :func:`.compare`.
For each node in `ts`, return the age of a matched node from `other`. Node
matching is accomplished as described in :func:`.compare`.
Returns a tuple of three vectors of length `ts.num_nodes`, in this order:
the age of the best matching node in `other`;
the proportion of the node span in `ts` that is covered by the best match;
and the node id of the best match in `other`.
Returns a tuple of three vectors of length `ts.num_nodes`, in this order:
the age of the best matching node in `other`;
the proportion of the node span in `ts` that is covered by the best match;
and the node id of the best match in `other`.
:return: A tuple of arrays of length `ts.num_nodes` containing
(time of matching node, proportion overlap, and node ID of match).
:return: A tuple of arrays of length `ts.num_nodes` containing
(time of matching node, proportion overlap, and node ID of match).
"""

shared_spans = shared_node_spans(ts, other)
Expand All @@ -283,7 +282,6 @@ def match_node_ages(ts, other):

@dataclass
class ARFResult:

"""
The result of a call to tscompare.compare(ts, other),
returning metrics associated with the ARG Robinson-Foulds
Expand All @@ -302,7 +300,7 @@ class ARFResult:
`dissimilarity`:
The total span of `ts` that is not represented in `other`.
`total_span`:
The total of all node spans of the two tree sequences, in order (`ts`, `other`).
Expand All @@ -314,6 +312,7 @@ class ARFResult:
`transform`:
The transformation function used to transform times for computing `rmse`.
"""

arf: float
tpr: float
dissimilarity: float
Expand All @@ -329,25 +328,26 @@ def __str__(self):
out += f" ARF: {100*self.arf:.2f}%\n"
out += f" TPR: {100*self.tpr:.2f}%\n"
out += f" dissimilarity: {self.dissimilarity}\n"
out += f" total span (ts, other): {self.total_span[0]}, {self.total_span[1]}\n"
out += (
f" total span (ts, other): {self.total_span[0]}, {self.total_span[1]}\n"
)
out += f" time RMSE: {self.rmse}\n"
return out


def compare(ts, other, transform=None):

"""
For two tree sequences `ts` and `other`,
this method returns an object of type :class:`.ARFResult`.
The values reported summarize the degree to which nodes in `ts`
"match" corresponding nodes in `other`.
To match nodes,
for each node in `ts`, the best matching node(s) from `other`
has the longest matching span using :func:`.shared_node_spans`.
If there are multiple matches with the same longest shared span
for a single node, the best match is the match that is closest in time.
Then, :class:`.ARFResult` contains:
- (`dissimilarity`)
Expand All @@ -356,8 +356,8 @@ def compare(ts, other, transform=None):
samples as its best match in `other`.
- (`arf`)
The fraction of the total span of `ts` over which each node's
descendant sample set does not match its best match's descendant
The fraction of the total span of `ts` over which each node's
descendant sample set does not match its best match's descendant
sample set (i.e., the total *un*-matched span divided by the total
span of `ts`).
Expand Down Expand Up @@ -387,8 +387,11 @@ def compare(ts, other, transform=None):
:rtype: ARFResult
"""

def f(t):
return np.log(1 + t)

if transform is None:
transform = lambda t: np.log(1 + t)
transform = f

shared_spans = shared_node_spans(ts, other)
# Find all potential matches for a node based on max shared span length
Expand All @@ -403,7 +406,9 @@ def compare(ts, other, transform=None):
# determine best matches with the best_match_matrix
ts_times = ts.nodes_time[row_ind[match]]
other_times = other.nodes_time[col_ind[match]]
time_difference = np.absolute(np.asarray(transform(ts_times) - transform(other_times)))
time_difference = np.absolute(
np.asarray(transform(ts_times) - transform(other_times))
)
# If a node x in `ts` has no match then we set time_difference to zero
# This node then does not affect the rmse
for j in range(len(shared_spans.data[match])):
Expand Down Expand Up @@ -438,13 +443,10 @@ def compare(ts, other, transform=None):
product = np.multiply((time_discrepancies**2), ts_node_spans)
rmse = np.sqrt(np.sum(product) / total_span_ts)
return ARFResult(

arf = 1.0 - total_match_span / total_span_ts,
tpr = total_match_span / total_span_other,

dissimilarity = total_span_ts - total_match_span,
total_span = (total_span_ts, total_span_other),
rmse = rmse,
transform = transform,
arf=1.0 - total_match_span / total_span_ts,
tpr=total_match_span / total_span_other,
dissimilarity=total_span_ts - total_match_span,
total_span=(total_span_ts, total_span_other),
rmse=rmse,
transform=transform,
)

1 change: 0 additions & 1 deletion tscompare/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@
__version__ = get_version(root="..", relative_to=__file__)
except ImportError:
pass

0 comments on commit db15cfc

Please sign in to comment.