Skip to content

Commit 71b6844

Browse files
committed
Fix windowing for cubed in the regular-chunks case
1 parent 3744d3c commit 71b6844

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

.github/workflows/cubed.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ jobs:
3030
3131
- name: Test with pytest
3232
run: |
33-
pytest -v sgkit/tests/test_{aggregation,association,hwe,pca}.py \
33+
pytest -v sgkit/tests/test_{aggregation,association,hwe,pca,window}.py \
3434
-k "test_count_call_alleles or \
3535
test_gwas_linear_regression or \
3636
test_hwep or \
3737
test_sample_stats or \
3838
(test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or \
3939
(test_variant_stats and not test_variant_stats__chunks[chunks2-False]) or \
40-
(test_pca__array_backend and tsqr)" or \
40+
(test_pca__array_backend and tsqr) or \
41+
(test_window and not 12-5-4-4)" \
4142
--use-cubed

sgkit/tests/test_window.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import re
22

33
import allel
4-
import dask.array as da
54
import numpy as np
65
import pandas as pd
76
import pytest
87
import xarray as xr
98

9+
import sgkit.distarray as da
1010
from sgkit import (
1111
simulate_genotype_call_dataset,
1212
window_by_interval,

sgkit/window.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import functools
12
from typing import Any, Callable, Hashable, Iterable, Optional, Tuple, Union
23

3-
import dask.array as da
44
import numpy as np
55
from xarray import Dataset
66

7+
import sgkit.distarray as da
78
from sgkit import variables
89
from sgkit.model import get_contigs, num_contigs
910
from sgkit.utils import conditional_merge_datasets, create_dataset
@@ -510,8 +511,15 @@ def window_statistic(
510511
and len(window_stops) == 1
511512
and window_stops == np.array([values.shape[0]])
512513
):
514+
out = da.map_blocks(
515+
functools.partial(statistic, **kwargs),
516+
values,
517+
dtype=dtype,
518+
chunks=values.chunks[1:],
519+
drop_axis=0,
520+
)
513521
# call expand_dims to add back window dimension (size 1)
514-
return da.expand_dims(statistic(values, **kwargs), axis=0)
522+
return da.expand_dims(out, axis=0)
515523

516524
window_lengths = window_stops - window_starts
517525
depth = np.max(window_lengths) # type: ignore[no-untyped-call]
@@ -536,10 +544,10 @@ def window_statistic(
536544

537545
chunk_offsets = _sizes_to_start_offsets(windows_per_chunk)
538546

539-
def blockwise_moving_stat(x: ArrayLike, block_info: Any = None) -> ArrayLike:
540-
if block_info is None or len(block_info) == 0:
547+
def blockwise_moving_stat(x: ArrayLike, block_id: Any = None) -> ArrayLike:
548+
if block_id is None:
541549
return np.array([])
542-
chunk_number = block_info[0]["chunk-location"][0]
550+
chunk_number = block_id[0]
543551
chunk_offset_start = chunk_offsets[chunk_number]
544552
chunk_offset_stop = chunk_offsets[chunk_number + 1]
545553
chunk_window_starts = rel_window_starts[chunk_offset_start:chunk_offset_stop]
@@ -559,8 +567,9 @@ def blockwise_moving_stat(x: ArrayLike, block_info: Any = None) -> ArrayLike:
559567
depth = {0: depth}
560568
# new chunks are same except in first axis
561569
new_chunks = tuple([tuple(windows_per_chunk)] + list(desired_chunks[1:])) # type: ignore
562-
return values.map_overlap(
570+
return da.map_overlap(
563571
blockwise_moving_stat,
572+
values,
564573
dtype=dtype,
565574
chunks=new_chunks,
566575
depth=depth,

0 commit comments

Comments
 (0)