Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

time windows in statistics #2948

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft

Conversation

petrelharp
Copy link
Contributor

Here @tforest and I are starting in on adding time windows to statistics. We're starting with what was sketched out in #683, and will explain things in more detail here when we're farther along (ignore this for now).

Copy link

codecov bot commented May 9, 2024

Codecov Report

Attention: Patch coverage is 63.91753% with 35 lines in your changes missing coverage. Please review.

Project coverage is 89.76%. Comparing base (16de381) to head (59ea266).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
c/tskit/trees.c 53.48% 18 Missing and 2 partials ⚠️
python/tskit/trees.py 66.66% 7 Missing and 6 partials ⚠️
python/_tskitmodule.c 86.66% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2948      +/-   ##
==========================================
- Coverage   89.85%   89.76%   -0.09%     
==========================================
  Files          29       29              
  Lines       32128    32200      +72     
  Branches     5763     5781      +18     
==========================================
+ Hits        28868    28905      +37     
- Misses       1859     1885      +26     
- Partials     1401     1410       +9     
Flag Coverage Δ
c-tests 86.59% <53.48%> (-0.11%) ⬇️
lwt-tests 80.78% <ø> (ø)
python-c-tests 89.03% <86.66%> (-0.02%) ⬇️
python-tests 98.81% <66.66%> (-0.17%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
python/_tskitmodule.c 89.03% <86.66%> (-0.02%) ⬇️
python/tskit/trees.py 98.27% <66.66%> (-0.53%) ⬇️
c/tskit/trees.c 90.38% <53.48%> (-0.30%) ⬇️

@petrelharp
Copy link
Contributor Author

petrelharp commented May 17, 2024

Note: it is not clear how to do this for site statistics, since the site stat is of the form
$$\sum_a f(w_a)$$
where the sum is over alleles, and $w_a$ is the weight of all samples with allele $a$;
however, it is mutations that have times, not alleles.

The proposal will probably be to compute a site stat that sums over mutations, not alleles, but we'll start with branch stats only for now.

@petrelharp
Copy link
Contributor Author

Next step:

  • do the AFS first, since it's less tangled up

Also maybe:

  • allow ts.decapitate( ) to take inf as an argument (that does nothing) ?

@andrewkern
Copy link
Member

a small nudge here that i mentioned to @petrelharp in passing-- it would be great to have an expectation from theory as to what time stratified quantities like the SFS should be under the (standard, neutral) coalescent

@tforest
Copy link
Collaborator

tforest commented Jul 15, 2024

Some thoughts after working on time windows.

After these edits the moment the output of, let's say, the AFS is a still 2D array of windows, same for time_windows, when using either of them individually. However, when using windows and time_windows at the same time, the output is a 3D array, with the following shape: [num_windows][num_time_windows][sample_size]. When windows or time_windows are None, associated dimensions are dropped accordingly.
As there is now two types of windows, it will become ambiguous that the historical "windows" parameter is in fact corresponding specifically to genomic spanning windows. We did not renamed it for now though, as it would break previous behavior.

Some ideas:

  • Add new benchmarks for summary stats to see if the implemented features are optimized both in terms of computational space and time complexity.
  • Add some plots for summary stats to observe how time windows impact them.

@petrelharp
Copy link
Contributor Author

A note on the potential confusion between windows and time_windows - often one endpoint of the time_windows will be Inf, so if we make sure we produce an informative error if the windows aren't finite, we'll help people avoid the mistake.

Copy link
Contributor Author

@petrelharp petrelharp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! One question about a possible refactor, and suggesting moving the "general stat" stuff to a different PR.

for u in tree.nodes()
)
sigma[tree.index, j, :] = s * tree.span
for j in range(1, len(time_windows) - 1):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for j in range(1, len(time_windows) - 1):
for j in range(1, tw):

return windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
if drop_time_windows:
# beware: this assumes the first dimension is windows
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment can be removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm but perhaps replaced by

assert len(out.shape) == 3

@@ -144,39 +144,93 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
return A


# Timewindows test
def naive_branch_general_stat(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function looks good; however, I think this PR is going to be just about the AFS, not general stats - so, maybe this code should be put aside in a separate PR?

if polarised:
s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
sigma = np.zeros((ts.num_trees, tw, m))
for j, upper_time in enumerate(time_windows[1:]):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes that time_windows[0] is 0, I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be fixed by setting sigma not to 0 but to -1 times the value calculated from ts.decapitate(time_windows[0]).

python/tests/test_tree_stats.py Show resolved Hide resolved
Comment on lines 3699 to 3700
# Warning: when using Windows and TimeWindows,
# the output has three dimensions
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can delete this

c = fold(c, out_dim)
index = tuple([window_index] + list(c))
result[index] += x
def update_result(window_index, u, right, time_windows):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

time_windows isn't being changed by this function

Suggested change
def update_result(window_index, u, right, time_windows):
def update_result(window_index, u, right):

# interval between child and parent inside the window
t_v = branch_length[u] + time[u]
tw_branch_length = min(time_windows[k_tw + 1], t_v) - max(
time_windows[0], time[u]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops?

Suggested change
time_windows[0], time[u]
time_windows[k_tw], time[u]

for k_tw, _ in enumerate(time_windows[:-1]):
if 0 < count[u, -1] < ts.num_samples:
# interval between child and parent inside the window
t_v = branch_length[u] + time[u]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we're losing the advantage of cacheing branch_length[u]; this might as well be time[v] (with v passed in also).

Alternatives:

  1. let branch_length be a (num nodes x num time windows) array instead of just a vector, so that we'd have
tw_branch_length = branch_length[u, k_tw]
  1. Do the calculation down a few lines, something like this (this is not right):
u = edge.child
v = edge.parent
t_c = time[u]
t_p = time[v]
time_window_index = 0
while t_p < time_windows[time_window_index + 1]:
    while v != -1:
        tw_branch_length = min(time_windows[k_tw + 1], t_p) - max(time_windows[k_tw], t_c)
        update_result(window_index, time_window_index, v, t_left, tw_branch_length)
        count[v] -= count[u]
        t_c = t_p
        v = parent[v]
        t_p = time[v]
    time_window_index += 1

The advantage to this is that computation isn't increased by a factor of (num time windows). The disadvantage might be that the code is harder to understand?

window_index += 1
tree_index += 1

assert window_index == windows.shape[0] - 1
if span_normalise:
for j in range(num_windows):
result[j] /= windows[j + 1] - windows[j]

if drop_time_windows:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see suggestions above

@benjeffery
Copy link
Member

I've added this work to the next release milestone. Hoping to get a release out in a week or two, if that is too ambitious for this let me know.

@petrelharp
Copy link
Contributor Author

Probably too ambitious, but we might have something in by then.

Copy link
Contributor Author

@petrelharp petrelharp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great! Some suggestions, mostly minor; let's discuss getting the tests in there.

@@ -7637,6 +7637,7 @@ def parse_windows(self, windows):
# Note: need to make sure windows is a string or we try to compare the
# target with a numpy array elementwise.
if windows is None:
# initiate default spanning windows
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# initiate default spanning windows

Comment on lines +7683 to +7686
if strip_win:
stat = stat[0, :, :]
elif strip_timewin:
stat = stat[:, 0, :]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks like you can't have both, ie windows=None, time_windows=None?

Comment on lines +7786 to +7788
if (stat.shape == () and windows is None) or (
stat.shape == () and time_windows is None
):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (stat.shape == () and windows is None) or (
stat.shape == () and time_windows is None
):
if (stat.shape == () and windows is None and time_windows is None):

I think the intention of this rule is so that if you do like

ts.diversity([0,1,2])

then you get a single number, not a length-1 array, but if anyone is supplying windows explicitly (or time windows!) then they should get an array with the number of dimensions they expect.

We should write the bit in the docs that includes time windows, so we've got this clear?

@@ -9077,7 +9077,7 @@ parse_windows(
npy_intp *shape;

windows_array = (PyArrayObject *) PyArray_FROMANY(
windows, NPY_FLOAT64, 1, 1, NPY_ARRAY_IN_ARRAY);
windows, NPY_FLOAT64, 1, 1, NPY_ARRAY_IN_ARRAY);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this change (and others like it) done by linting?

"span_normalise", "polarised", NULL };
PyObject *sample_set_sizes = NULL;
PyObject *sample_sets = NULL;
PyObject *windows = NULL;
char *mode = NULL;
PyObject *time_windows = NULL;
char *mode = "NULL";
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
char *mode = "NULL";
char *mode = NULL;

}
increment_nd_array_value(afs, num_sample_sets, result_dims, coordinate, x);
if (parent[u] != -1){
t_v = time[parent[u]];
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps also t_u here?

if (!polarised){
fold(coordinate, result_dims, num_sample_sets);
}
tw_branch_length = MIN(time_windows[time_window_index + 1], t_v) - MAX(time_windows[0], time[u]);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be

Suggested change
tw_branch_length = MIN(time_windows[time_window_index + 1], t_v) - MAX(time_windows[0], time[u]);
tw_branch_length = MIN(time_windows[time_window_index + 1], t_v) - MAX(time_windows[time_window_index], time[u]);

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm - the tests below should be catching this if it is indeed wrong, but it sure looks wrong to me - I'm not sure what's going on?

if (parent[u] != -1){
t_v = time[parent[u]];
if (0 < all_samples && all_samples < self->num_samples) {
for (time_window_index = 0; time_window_index < num_time_windows; time_window_index++){
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of edges are recent, so we might avoid substantial work if we do like

	    time_window_index = 0;
	    while (time_window_index < num_time_windows && time_windows[time_window_index] < t_v){
                   ...
                 time_window_index++;
             }



class TestTimeWindows(TestBranchAlleleFrequencySpectrum):
def test_four_taxa_test_case(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't really be in this class, since it's testing general_stat, not the AFS; perhaps leave a comment? Or move it along with the general_stat code above to a new PR for future work?

)
self.assertArrayAlmostEqual(x, true_x)

def test_afs_branch(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems very useful, but it's hard to tell exactly what's being tested. For instance, there's no call to ts.allele_frequency_spectrum here, I think? Perhaps this could be rearranged? Simplified? Commented? There's also some references to self.mode, which might be confusing since this is branch-only?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants