diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml new file mode 100644 index 000000000..2b9cde4e8 --- /dev/null +++ b/.github/workflows/cubed.yml @@ -0,0 +1,33 @@ +name: Cubed + +on: + push: + pull_request: + # manual trigger + workflow_dispatch: + +jobs: + build: + # This workflow only runs on the origin org + # if: github.repository_owner == 'sgkit-dev' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install deps and sgkit + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt -r requirements-dev.txt + python -m pip install -U git+https://github.com/cubed-dev/cubed.git -U git+https://github.com/cubed-dev/cubed-xarray.git -U git+https://github.com/pydata/xarray.git + + - name: Test with pytest + run: | + pytest -v sgkit/tests/test_aggregation.py -k "test_count_call_alleles" --use-cubed diff --git a/conftest.py b/conftest.py index bcccb6e8a..9343393f3 100644 --- a/conftest.py +++ b/conftest.py @@ -2,6 +2,30 @@ collect_ignore_glob = ["benchmarks/**", "sgkit/io/vcf/*.py", ".github/scripts/*.py"] +def pytest_addoption(parser): + parser.addoption( + "--use-cubed", action="store_true", default=False, help="run with cubed" + ) + + +def use_cubed(): + import dask + import xarray as xr + + # set xarray to use cubed by default + xr.set_options(chunk_manager="cubed") + + # ensure that dask compute raises if it is ever called + class AlwaysRaiseScheduler: + def __call__(self, dsk, keys, **kwargs): + raise RuntimeError("Dask 'compute' was called") + + dask.config.set(scheduler=AlwaysRaiseScheduler()) + + def pytest_configure(config) -> None: # type: ignore # Add "gpu" marker config.addinivalue_line("markers", "gpu:Run tests that run on GPU") + + if config.getoption("--use-cubed"): + use_cubed() diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 41adcc221..8bbb738b5 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -76,17 +76,24 @@ def count_call_alleles( variables.validate(ds, {call_genotype: variables.call_genotype_spec}) n_alleles = ds.sizes["alleles"] - G = da.asarray(ds[call_genotype]) - shape = (G.chunks[0], G.chunks[1], n_alleles) # use numpy array to avoid dask task dependencies between chunks N = np.empty(n_alleles, dtype=np.uint8) + AC = xr.apply_ufunc( + count_alleles, + ds[call_genotype], + N, + input_core_dims=[["ploidy"], ["alleles"]], + output_core_dims=[["alleles"]], + exclude_dims={"ploidy"}, + dask="parallelized", + dask_gufunc_kwargs=dict(allow_rechunk=False), + output_dtypes=np.uint8, + ) new_ds = create_dataset( { variables.call_allele_count: ( ("variants", "samples", "alleles"), - da.map_blocks( - count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2 - ), + AC.data, ) } ) diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index a5e68645e..d4d3db1e5 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -139,8 +139,10 @@ def test_count_variant_alleles__chunked(using): calls = rs.randint(0, 1, size=(50, 10, 2)) ds = get_dataset(calls) ac1 = count_variant_alleles(ds, using=using) - # Coerce from numpy to multiple chunks in all dimensions - ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) + # Coerce from numpy to multiple chunks in all non-core dimensions + ds["call_genotype"] = ds["call_genotype"].chunk( + chunks={"variants": 5, "samples": 5} + ) ac2 = count_variant_alleles(ds, using=using) assert isinstance(ac2["variant_allele_count"].data, da.Array) xr.testing.assert_equal(ac1, ac2) @@ -265,10 +267,12 @@ def test_count_call_alleles__chunked(): calls = rs.randint(0, 1, size=(50, 10, 2)) ds = get_dataset(calls) ac1 = count_call_alleles(ds) - # Coerce from numpy to multiple chunks in all dimensions - ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1)) + # Coerce from numpy to multiple chunks in all non-core dimensions + ds["call_genotype"] = ds["call_genotype"].chunk( + chunks={"variants": 5, "samples": 5} + ) ac2 = count_call_alleles(ds) - assert isinstance(ac2["call_allele_count"].data, da.Array) + assert hasattr(ac2["call_allele_count"].data, "chunks") xr.testing.assert_equal(ac1, ac2) diff --git a/sgkit/tests/test_ibs.py b/sgkit/tests/test_ibs.py index e53fab7b7..c68a05569 100644 --- a/sgkit/tests/test_ibs.py +++ b/sgkit/tests/test_ibs.py @@ -1,6 +1,5 @@ import pathlib -import dask.array as da import numpy as np import pytest @@ -42,13 +41,9 @@ def test_identity_by_state__diploid_biallelic(method, chunks, skipna): pass elif method == "frequencies": ds = count_call_alleles(ds) - ds["call_allele_count"] = ( - ds.call_allele_count.dims, - ds.call_allele_count.data.rechunk(chunks), - ) + ds["call_allele_count"] = ds["call_allele_count"].chunk(chunks) else: - gt = da.array(ds.call_genotype.data) - ds["call_genotype"] = ds.call_genotype.dims, gt.rechunk(chunks) + ds["call_genotype"] = ds["call_genotype"].chunk(chunks) ds = identity_by_state(ds, method=method, skipna=skipna) actual = ds.stat_identity_by_state.values expect = np.nanmean( @@ -93,13 +88,9 @@ def test_identity_by_state__tetraploid_multiallelic(method, chunks, skipna): if chunks is None: pass elif method == "frequencies": - ds["call_allele_count"] = ( - ds.call_allele_count.dims, - ds.call_allele_count.data.rechunk(chunks), - ) + ds["call_allele_count"] = ds["call_allele_count"].chunk(chunks) else: - gt = da.array(ds.call_genotype.data) - ds["call_genotype"] = ds.call_genotype.dims, gt.rechunk(chunks) + ds["call_genotype"] = ds["call_genotype"].chunk(chunks) ds = identity_by_state(ds, method=method, skipna=skipna) actual = ds.stat_identity_by_state.values if skipna: @@ -147,7 +138,7 @@ def test_identity_by_state__reference_implementation(ploidy, method, chunks, ski seed=0, ) ds = call_allele_frequencies(ds) - ds.chunk(variants=chunks[0], samples=chunks[1], alleles=chunks[2]) + ds = ds.chunk(variants=chunks[0], samples=chunks[1], alleles=chunks[2]) # reference implementation AF = ds.call_allele_frequency.data if skipna: diff --git a/sgkit/tests/test_popgen.py b/sgkit/tests/test_popgen.py index 50fc9bb4c..32a9a8e6a 100644 --- a/sgkit/tests/test_popgen.py +++ b/sgkit/tests/test_popgen.py @@ -533,7 +533,7 @@ def test_Garud_h__raise_on_no_windows(): @pytest.mark.filterwarnings("ignore::RuntimeWarning") -@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (2, 2))]) +@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (4,))]) def test_observed_heterozygosity(chunks): ds = simulate_genotype_call_dataset( n_variant=4, @@ -599,7 +599,7 @@ def test_observed_heterozygosity(chunks): @pytest.mark.filterwarnings("ignore::RuntimeWarning") -@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (2, 2))]) +@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (4,))]) @pytest.mark.parametrize( "cohorts,expectation", [