
Commit 45f1267

tomwhite authored and mergify[bot] committed
Add concat algorithm parameter to vcf_to_zarr (#365)
Use variable-length strings for storing alleles in Zarr (#643)
1 parent 19f0b1b commit 45f1267

File tree

4 files changed: +141 −31 lines changed
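For orientation, here is a minimal sketch of how the new parameter is meant to be used; the file paths are hypothetical, while `vcf_to_zarr`, `regions`, and `concat_algorithm` are from the API changed in this commit.

# Minimal sketch of the new parameter; "sample.vcf.gz" and "out.zarr"
# are hypothetical paths.
from sgkit.io.vcf import vcf_to_zarr

# Default (concat_algorithm=None): the optimized concatenation suitable
# for large files, which now stores alleles as variable-length strings.
vcf_to_zarr("sample.vcf.gz", "out.zarr", regions=["20", "21"])

# Opt back in to the Xarray-based concatenation (can use much more
# memory, see https://github.com/dask/dask/issues/6745).
vcf_to_zarr(
    "sample.vcf.gz",
    "out.zarr",
    regions=["20", "21"],
    concat_algorithm="xarray_internal",
)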

sgkit/io/vcf/vcf_reader.py

Lines changed: 80 additions & 8 deletions
@@ -22,13 +22,17 @@
 import xarray as xr
 from cyvcf2 import VCF, Variant
 from numcodecs import PackBits
+from typing_extensions import Literal

 from sgkit import variables
 from sgkit.io.dataset import load_dataset
 from sgkit.io.utils import zarrs_to_dataset
 from sgkit.io.vcf import partition_into_regions
 from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
-from sgkit.io.vcfzarr_reader import vcf_number_to_dimension_and_size
+from sgkit.io.vcfzarr_reader import (
+    concat_zarrs_optimized,
+    vcf_number_to_dimension_and_size,
+)
 from sgkit.model import (
     DIM_PLOIDY,
     DIM_SAMPLE,
@@ -529,6 +533,7 @@ def vcf_to_zarr_parallel(
     fields: Optional[Sequence[str]] = None,
     exclude_fields: Optional[Sequence[str]] = None,
     field_defs: Optional[Dict[str, Dict[str, Any]]] = None,
+    concat_algorithm: Optional[Literal["xarray_internal"]] = None,
 ) -> None:
     """Convert specified regions of one or more VCF files to zarr files, then concat, rechunk, write to zarr"""

@@ -557,11 +562,15 @@
         field_defs=field_defs,
     )

-    ds = zarrs_to_dataset(paths, chunk_length, chunk_width, tempdir_storage_options)
-
-    # Ensure Dask task graph is efficient, see https://github.com/dask/dask/issues/5105
-    with dask.config.set({"optimization.fuse.ave-width": dask_fuse_avg_width}):
-        ds.to_zarr(output, mode="w")
+    concat_zarrs(
+        paths,
+        output,
+        concat_algorithm=concat_algorithm,
+        chunk_length=chunk_length,
+        chunk_width=chunk_width,
+        storage_options=tempdir_storage_options,
+        dask_fuse_avg_width=dask_fuse_avg_width,
+    )


 def vcf_to_zarrs(
@@ -703,6 +712,64 @@ def vcf_to_zarrs(
     return parts


+def concat_zarrs(
+    urls: Sequence[str],
+    output: Union[PathType, MutableMapping[str, bytes]],
+    *,
+    concat_algorithm: Optional[Literal["xarray_internal"]] = None,
+    chunk_length: int = 10_000,
+    chunk_width: int = 1_000,
+    storage_options: Optional[Dict[str, str]] = None,
+    dask_fuse_avg_width: int = 50,
+) -> None:
+    """Concatenate multiple Zarr stores into a single Zarr store.
+
+    The Zarr stores are concatenated and rechunked to produce a single combined store.
+
+    Parameters
+    ----------
+    urls
+        A list of URLs to the Zarr stores to combine, typically the return value of
+        :func:`vcf_to_zarrs`.
+    output
+        Zarr store or path to directory in file system.
+    concat_algorithm
+        The algorithm to use to concatenate and rechunk Zarr files. The default None means
+        use the optimized version suitable for large files, whereas ``xarray_internal`` will
+        use built-in Xarray APIs, which can exhibit high memory usage, see
+        https://github.com/dask/dask/issues/6745.
+    chunk_length
+        Length (number of variants) of chunks in which data are stored, by default 10,000.
+        This is only used when ``concat_algorithm`` is ``xarray_internal``.
+    chunk_width
+        Width (number of samples) to use when storing chunks in output, by default 1,000.
+        This is only used when ``concat_algorithm`` is ``xarray_internal``.
+    storage_options
+        Any additional parameters for the storage backend (see ``fsspec.open``).
+    dask_fuse_avg_width
+        Setting for Dask's ``optimization.fuse.ave-width``, see
+        https://github.com/dask/dask/issues/5105.
+    """
+    if concat_algorithm == "xarray_internal":
+        ds = zarrs_to_dataset(urls, chunk_length, chunk_width, storage_options)
+
+        with dask.config.set({"optimization.fuse.ave-width": dask_fuse_avg_width}):
+            ds.to_zarr(output, mode="w")
+    else:
+
+        vars_to_rechunk = []
+        vars_to_copy = []
+        storage_options = storage_options or {}
+        ds = xr.open_zarr(  # type: ignore[no-untyped-call]
+            fsspec.get_mapper(urls[0], **storage_options), concat_characters=False
+        )
+        for (var, arr) in ds.data_vars.items():
+            if arr.dims[0] == "variants":
+                vars_to_rechunk.append(var)
+            else:
+                vars_to_copy.append(var)
+
+        concat_zarrs_optimized(urls, output, vars_to_rechunk, vars_to_copy)
+
+
 def vcf_to_zarr(
     input: Union[PathType, Sequence[PathType]],
     output: Union[PathType, MutableMapping[str, bytes]],
@@ -723,6 +790,7 @@
     fields: Optional[Sequence[str]] = None,
     exclude_fields: Optional[Sequence[str]] = None,
     field_defs: Optional[Dict[str, Dict[str, Any]]] = None,
+    concat_algorithm: Optional[Literal["xarray_internal"]] = None,
 ) -> None:
     """Convert VCF files to a single Zarr on-disk store.
@@ -735,8 +803,7 @@
     is None.

     For more control over these two steps, consider using :func:`vcf_to_zarrs` followed by
-    :func:`zarrs_to_dataset`, then saving the dataset using Xarray's
-    :meth:`xarray.Dataset.to_zarr` method.
+    :func:`concat_zarrs`.

     Parameters
     ----------
@@ -811,6 +878,10 @@
         (which is defined as Number 2 in the VCF header) as ``haplotypes``.
         (Note that Number ``A`` is the number of alternate alleles, see section 1.4.2 of the
         VCF spec https://samtools.github.io/hts-specs/VCFv4.3.pdf.)
+    concat_algorithm
+        The algorithm to use to concatenate and rechunk Zarr files. The default None means
+        use the optimized version suitable for large files, whereas ``xarray_internal`` will
+        use built-in Xarray APIs, which can exhibit high memory usage, see
+        https://github.com/dask/dask/issues/6745.
     """

     if temp_chunk_length is not None:
@@ -842,6 +913,7 @@
             temp_chunk_length=temp_chunk_length,
             tempdir=tempdir,
             tempdir_storage_options=tempdir_storage_options,
+            concat_algorithm=concat_algorithm,
         )
     convert_func(
         input,  # type: ignore
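Taken together, the docstring now points users at a two-step pipeline: `vcf_to_zarrs` then `concat_zarrs`. A minimal sketch of that flow, assuming hypothetical paths and importing `concat_zarrs` from the module that defines it in this commit:

# Sketch of the two-step pipeline referenced in the docstring above;
# "sample.vcf.gz" and "combined.zarr" are hypothetical paths.
import tempfile

from sgkit.io.vcf import vcf_to_zarrs
from sgkit.io.vcf.vcf_reader import concat_zarrs

with tempfile.TemporaryDirectory() as tmpdir:
    # Step 1: convert each region to its own intermediate Zarr store.
    paths = vcf_to_zarrs("sample.vcf.gz", tmpdir, regions=["20", "21"])
    # Step 2: concatenate and rechunk into a single store; the default
    # (None) algorithm is the optimized one added in this commit.
    concat_zarrs(paths, "combined.zarr")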

sgkit/io/vcfzarr_reader.py

Lines changed: 38 additions & 14 deletions
@@ -1,9 +1,20 @@
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, Hashable, List, Optional, Tuple
+from typing import (
+    Any,
+    Dict,
+    Hashable,
+    List,
+    MutableMapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)

 import dask
 import dask.array as da
+import numcodecs
 import xarray as xr
 import zarr
 from fsspec import get_mapper
@@ -151,8 +162,8 @@ def vcfzarr_to_zarr(
         ds.to_zarr(output, mode="w")
     else:
         # Use the optimized algorithm in `concatenate_and_rechunk`
-        _concat_zarrs_optimized(
-            zarr_files, output, vars_to_rechunk, vars_to_copy
+        concat_zarrs_optimized(
+            zarr_files, output, vars_to_rechunk, vars_to_copy, fix_strings=True
         )

@@ -291,18 +302,22 @@ def _get_max_len(zarr_groups: List[zarr.Group], attr_name: str) -> int:
     return max_len


-def _concat_zarrs_optimized(
-    zarr_files: List[str],
-    output: PathType,
+def concat_zarrs_optimized(
+    zarr_files: Sequence[str],
+    output: Union[PathType, MutableMapping[str, bytes]],
     vars_to_rechunk: List[Hashable],
     vars_to_copy: List[Hashable],
+    fix_strings: bool = False,
 ) -> None:
+    if isinstance(output, Path):
+        output = str(output)
+
     zarr_groups = [zarr.open_group(f) for f in zarr_files]

     first_zarr_group = zarr_groups[0]

     # create the top-level group
-    zarr.open_group(str(output), mode="w")
+    zarr.open_group(output, mode="w")

     # copy variables that are to be rechunked
     # NOTE: that this uses _to_zarr function defined here that is needed to avoid
@@ -311,38 +326,47 @@ def _concat_zarrs_optimized(
     delayed = []  # do all the rechunking operations in one computation
     for var in vars_to_rechunk:
         dtype = None
-        if var in {"variant_id", "variant_allele"}:
+        if fix_strings and var in {"variant_id", "variant_allele"}:
             max_len = _get_max_len(zarr_groups, f"max_length_{var}")
             dtype = f"S{max_len}"
-
         arr = concatenate_and_rechunk(
             [group[var] for group in zarr_groups], dtype=dtype
         )
+
+        _to_zarr_kwargs = dict(fill_value=None)
+        if not fix_strings and arr.dtype == "O":
+            # We assume that all object dtypes are variable length strings
+            _to_zarr_kwargs["object_codec"] = numcodecs.VLenUTF8()
+
         d = _to_zarr(  # type: ignore[no-untyped-call]
             arr,
-            str(output),
+            output,
             component=var,
             overwrite=True,
             compute=False,
-            fill_value=None,
             attrs=first_zarr_group[var].attrs.asdict(),
+            **_to_zarr_kwargs,
         )
         delayed.append(d)
     da.compute(*delayed)

     # copy unchanged variables and top-level metadata
-    with zarr.open_group(str(output)) as output_zarr:
+    with zarr.open_group(output) as output_zarr:

         # copy variables that are not rechunked (e.g. sample_id)
         for var in vars_to_copy:
             output_zarr[var] = first_zarr_group[var]
             output_zarr[var].attrs.update(first_zarr_group[var].attrs)

         # copy top-level attributes
-        output_zarr.attrs.update(first_zarr_group.attrs)
+        group_attrs = dict(first_zarr_group.attrs)
+        if "max_alt_alleles_seen" in group_attrs:
+            max_alt_alleles_seen = _get_max_len(zarr_groups, "max_alt_alleles_seen")
+            group_attrs["max_alt_alleles_seen"] = max_alt_alleles_seen
+        output_zarr.attrs.update(group_attrs)

     # consolidate metadata
-    zarr.consolidate_metadata(str(output))
+    zarr.consolidate_metadata(output)


 def _to_zarr(  # type: ignore[no-untyped-def]
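The key mechanism above is the `object_codec` supplied for object-dtype arrays. A standalone sketch of what `numcodecs.VLenUTF8` does, independent of sgkit (the allele values are made up for illustration):

# Standalone illustration: object-dtype (variable-length string) arrays
# need an explicit object codec when written to Zarr.
import numcodecs
import numpy as np
import zarr

alleles = np.array([["A", "TTT"], ["GAGACCC", "T"]], dtype=object)
z = zarr.array(alleles, object_codec=numcodecs.VLenUTF8())
assert z.dtype == object      # stored as variable-length strings
assert z[1, 0] == "GAGACCC"   # round-trips without fixed-width padding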

sgkit/tests/io/vcf/test_vcf_reader.py

Lines changed: 22 additions & 8 deletions
@@ -233,13 +233,23 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
     "is_path",
     [True, False],
 )
+@pytest.mark.parametrize(
+    "concat_algorithm",
+    [None, "xarray_internal"],
+)
 @pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
-def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path):
+def test_vcf_to_zarr__parallel(shared_datadir, is_path, concat_algorithm, tmp_path):
     path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
     output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
     regions = ["20", "21"]

-    vcf_to_zarr(path, output, regions=regions, chunk_length=5_000)
+    vcf_to_zarr(
+        path,
+        output,
+        regions=regions,
+        chunk_length=5_000,
+        concat_algorithm=concat_algorithm,
+    )
     ds = xr.open_zarr(output)

     assert ds["sample_id"].shape == (1,)
@@ -252,8 +262,12 @@ def test_vcf_to_zarr__parallel(shared_datadir, is_path, tmp_path):
     assert ds["variant_id_mask"].shape == (19910,)
     assert ds["variant_position"].shape == (19910,)

-    assert ds["variant_allele"].dtype == "S48"
-    assert ds["variant_id"].dtype == "S1"
+    if concat_algorithm is None:
+        assert ds["variant_allele"].dtype == "O"
+        assert ds["variant_id"].dtype == "O"
+    else:
+        assert ds["variant_allele"].dtype == "S48"
+        assert ds["variant_id"].dtype == "S1"


 @pytest.mark.parametrize(
@@ -305,8 +319,8 @@ def test_vcf_to_zarr__parallel_temp_chunk_length(shared_datadir, is_path, tmp_pa
     assert ds["variant_id_mask"].shape == (19910,)
     assert ds["variant_position"].shape == (19910,)

-    assert ds["variant_allele"].dtype == "S48"
-    assert ds["variant_id"].dtype == "S1"
+    assert ds["variant_allele"].dtype == "O"
+    assert ds["variant_id"].dtype == "O"


 def test_vcf_to_zarr__parallel_temp_chunk_length_not_divisible(
@@ -531,7 +545,7 @@ def test_vcf_to_zarr__mixed_ploidy_vcf(
     )
     ds = load_dataset(output)

-    variant_dtype = "|S1" if regions else "O"
+    variant_dtype = "O"
     assert ds.attrs["contigs"] == ["CHR1", "CHR2", "CHR3"]
     assert_array_equal(ds["variant_contig"], [0, 0])
     assert_array_equal(ds["variant_position"], [2, 7])
@@ -728,7 +742,7 @@ def test_vcf_to_zarr__parallel_with_fields(shared_datadir, tmp_path):
     assert_allclose(ds["variant_MQ"], [58.33, np.nan, 57.45])
     assert ds["variant_MQ"].attrs["comment"] == "RMS Mapping Quality"

-    assert_array_equal(ds["call_PGT"], [[b"0|1"], [b""], [b"0|1"]])
+    assert_array_equal(ds["call_PGT"], [["0|1"], [""], ["0|1"]])
     assert (
         ds["call_PGT"].attrs["comment"]
         == "Physical phasing haplotype information, describing how the alternate alleles are phased in relation to one another"

sgkit/tests/test_vcfzarr_reader.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def test_vcfzarr_to_zarr(
         consolidated=consolidated,
     )

-    ds = xr.open_zarr(output)
+    ds = xr.open_zarr(output, concat_characters=False)

     # Note that variant_allele values are byte strings, not unicode strings (unlike for read_vcfzarr)
     # We should make the two consistent.
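The `concat_characters=False` addition stops xarray from joining single-character ('S1') arrays along their last dimension during decoding, which would otherwise mangle the stored byte strings. A small self-contained sketch of that behavior using `xarray.decode_cf` (made-up data, not sgkit code):

# Why concat_characters matters: a self-contained sketch of xarray's
# CF character decoding.
import numpy as np
import xarray as xr

chars = np.array([[b"r", b"s", b"1"]], dtype="S1")
ds = xr.Dataset({"variant_id": (("variants", "id_chars"), chars)})

joined = xr.decode_cf(ds, concat_characters=True)
assert joined["variant_id"].shape == (1,)   # characters fused into b"rs1"

kept = xr.decode_cf(ds, concat_characters=False)
assert kept["variant_id"].shape == (1, 3)   # original layout preserved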
