Skip to content

Commit 28868ef

Browse files
tomwhite authored and mergify[bot] committed
Test zarr compressor and filters are correct in parallel case
1 parent b8326aa commit 28868ef

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

sgkit/io/vcfzarr_reader.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -333,10 +333,20 @@ def concat_zarrs_optimized(
333333
[group[var] for group in zarr_groups], dtype=dtype
334334
)
335335

336-
_to_zarr_kwargs = dict(fill_value=None)
336+
_to_zarr_kwargs = dict(
337+
compressor=first_zarr_group[var].compressor,
338+
filters=first_zarr_group[var].filters,
339+
fill_value=None,
340+
)
337341
if not fix_strings and arr.dtype == "O":
338342
# We assume that all object dtypes are variable length strings
339-
_to_zarr_kwargs["object_codec"] = numcodecs.VLenUTF8()
343+
var_len_str_codec = numcodecs.VLenUTF8()
344+
_to_zarr_kwargs["object_codec"] = var_len_str_codec
345+
# Remove from filters to avoid double encoding error
346+
if var_len_str_codec in first_zarr_group[var].filters:
347+
filters = list(first_zarr_group[var].filters)
348+
filters.remove(var_len_str_codec)
349+
_to_zarr_kwargs["filters"] = filters
340350

341351
d = _to_zarr( # type: ignore[no-untyped-call]
342352
arr,

sgkit/tests/io/vcf/test_vcf_reader.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -229,6 +229,43 @@ def test_vcf_to_zarr__compressor_and_filters(shared_datadir, is_path, tmp_path):
229229
assert z["variant_id_mask"].filters is None
230230

231231

232+
@pytest.mark.parametrize(
233+
"is_path",
234+
[True, False],
235+
)
236+
@pytest.mark.filterwarnings("ignore::sgkit.io.vcf.MaxAltAllelesExceededWarning")
237+
def test_vcf_to_zarr__parallel_compressor_and_filters(
238+
shared_datadir, is_path, tmp_path
239+
):
240+
path = path_for_test(shared_datadir, "CEUTrio.20.21.gatk3.4.g.vcf.bgz", is_path)
241+
output = tmp_path.joinpath("vcf_concat.zarr").as_posix()
242+
regions = ["20", "21"]
243+
244+
default_compressor = Blosc("zlib", 1, Blosc.NOSHUFFLE)
245+
variant_id_compressor = Blosc("zlib", 2, Blosc.NOSHUFFLE)
246+
encoding = dict(
247+
variant_id=dict(compressor=variant_id_compressor),
248+
variant_id_mask=dict(filters=None),
249+
)
250+
vcf_to_zarr(
251+
path,
252+
output,
253+
regions=regions,
254+
chunk_length=5_000,
255+
compressor=default_compressor,
256+
encoding=encoding,
257+
)
258+
259+
# look at actual Zarr store to check compressor and filters
260+
z = zarr.open(output)
261+
assert z["call_genotype"].compressor == default_compressor
262+
assert z["call_genotype"].filters is None
263+
assert z["call_genotype_mask"].filters == [PackBits()]
264+
265+
assert z["variant_id"].compressor == variant_id_compressor
266+
assert z["variant_id_mask"].filters is None
267+
268+
232269
@pytest.mark.parametrize(
233270
"is_path",
234271
[True, False],

0 commit comments

Comments
 (0)