Skip to content

Commit

Permalink
Deal with multiple contigs and sequence lengths
Browse files Browse the repository at this point in the history
Introduces a `contig_id` parameter to variant_data, as described in #949. Fixes #249
  • Loading branch information
hyanwong committed Sep 4, 2024
1 parent a080e8f commit ba629de
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 3 deletions.
115 changes: 113 additions & 2 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from tsinfer import formats


def ts_to_dataset(ts, chunks=None, samples=None):
def ts_to_dataset(ts, chunks=None, samples=None, contigs=None):
"""
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
Convert the specified tskit tree sequence into an sgkit dataset.
Expand All @@ -61,7 +61,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
genotypes = np.expand_dims(genotypes, axis=2)

ds = sgkit.create_genotype_call_dataset(
variant_contig_names=["1"],
variant_contig_names=["1"] if contigs is None else contigs,
variant_contig=np.zeros(len(tables.sites), dtype=int),
variant_position=tables.sites.position.astype(int),
variant_allele=alleles,
Expand Down Expand Up @@ -292,6 +292,78 @@ def test_simulate_genotype_call_dataset(tmp_path):
assert np.all(v.genotypes == sd_v)


class TestMultiContig:
    """
    Tests for creating a VariantData object from a dataset whose sites are
    spread over more than one contig.
    """

    def make_two_ts_dataset(self, path):
        """
        Simulate a tree sequence, split it in two at a site, and save all the
        variants to ``path`` as one zarr dataset in which the two halves are
        the contigs "chr1" and "chr2". Positions on "chr2" are shifted left by
        the break position, to match the left-trimmed second tree sequence.

        Returns the two tree sequences, in contig order.
        """
        # split ts into 2; put them as different contigs in the same dataset
        ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
        ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
        split_at_site = 7
        assert ts.num_sites > 10
        site_break = ts.site(split_at_site).position
        ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
        ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
        ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
        ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})

        # Edit explicit copies and write them back, rather than relying on
        # in-place changes to arrays obtained from the dataset being visible
        # in the dataset itself.
        contig = ds["variant_contig"].values.copy()
        contig[split_at_site:] = 1
        position = ds["variant_position"].values.copy()
        position[split_at_site:] -= int(site_break)
        ds.update(
            {
                "variant_contig": ds["variant_contig"].copy(data=contig),
                "variant_position": ds["variant_position"].copy(data=position),
            }
        )
        # Store the lengths along the same dimension as the contig IDs
        ds["contig_length"] = (
            ds["contig_id"].dims,
            np.array([ts1.sequence_length, ts2.sequence_length]),
        )
        ds.to_zarr(path, mode="w")
        return ts1, ts2

    def test_unmasked(self, tmp_path):
        """
        Using sites from both contigs at once raises a ValueError.
        """
        self.make_two_ts_dataset(tmp_path)
        with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
            tsinfer.VariantData(tmp_path, "variant_ancestral_allele")

    def test_mask(self, tmp_path):
        """
        Masking out every "chr1" site leaves a valid single-contig dataset.
        """
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path,
            "variant_ancestral_allele",
            # True means "masked out", so only the ts2 sites are kept
            site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
        )
        assert np.all(ts2.sites_position == vdata.sites_position)
        assert vdata.contig_id == "chr2"

    @pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
    def test_contig_id_param(self, contig_id, tmp_path):
        """
        The ``contig_id`` argument selects the sites of the named contig.
        """
        tree_seqs = {}
        tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", contig_id=contig_id
        )
        assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
        assert vdata.contig_id == contig_id

    def test_contig_id_param_and_mask(self, tmp_path):
        """
        A site mask and ``contig_id`` combine: a site is dropped if it is
        masked out or lies on another contig.
        """
        ts1, ts2 = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path,
            "variant_ancestral_allele",
            # Also mask out the first site on chr2
            site_mask=np.array(
                (ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False]
            ),
            contig_id="chr2",
        )
        assert np.all(ts2.sites_position[1:] == vdata.sites_position)
        assert vdata.contig_id == "chr2"

    @pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
    def test_contig_length(self, contig_id, tmp_path):
        """
        The sequence length is taken from the length of the selected contig.
        """
        tree_seqs = {}
        tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
        vdata = tsinfer.VariantData(
            tmp_path, "variant_ancestral_allele", contig_id=contig_id
        )
        assert vdata.sequence_length == tree_seqs[contig_id].sequence_length


@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
class TestSgkitMask:
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
Expand Down Expand Up @@ -754,3 +826,42 @@ def test_empty_alleles_not_at_end(self, tmp_path):
samples = tsinfer.VariantData(path, "variant_ancestral_allele")
with pytest.raises(ValueError, match="Empty alleles must be at the end"):
tsinfer.infer(samples)

def test_multiple_contigs(self, tmp_path):
    """
    A dataset whose sites are assigned to two different contigs cannot be
    used without restricting the sites: a ValueError naming both contigs
    is raised.
    """
    zarr_path = tmp_path / "data.zarr"
    dataset = sgkit.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    # Declare two contigs, and put the last site on the second one
    contig_names = np.array(["c10", "c11"], dtype="<U3")
    dataset["contig_id"] = (dataset["contig_id"].dims, contig_names)
    old_contig = dataset["variant_contig"]
    dataset["variant_contig"] = (
        old_contig.dims,
        np.array([0, 0, 1], dtype=old_contig.dtype),
    )
    sgkit.save_dataset(dataset, zarr_path)
    ancestral_state = dataset["variant_allele"][:, 0].astype(str)
    expected_msg = r'Sites belong to multiple contigs \("c10", "c11"\)'
    with pytest.raises(ValueError, match=expected_msg):
        tsinfer.VariantData(zarr_path, ancestral_state)

def test_bad_contig_param(self, tmp_path):
    """
    Asking for a contig ID that is not present in the dataset raises a
    ValueError.
    """
    zarr_path = tmp_path / "data.zarr"
    dataset = sgkit.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    sgkit.save_dataset(dataset, zarr_path)
    ancestral_state = dataset["variant_allele"][:, 0].astype(str)
    with pytest.raises(ValueError, match='"XX" not found'):
        tsinfer.VariantData(zarr_path, ancestral_state, contig_id="XX")

def test_multiple_contig_param(self, tmp_path):
    """
    Asking for a contig ID that is shared by more than one contig in the
    dataset raises a ValueError.
    """
    zarr_path = tmp_path / "data.zarr"
    dataset = sgkit.simulate_genotype_call_dataset(
        n_variant=3, n_sample=3, phased=True
    )
    # Two contigs with the same name
    duplicate_names = np.array(["chr1", "chr1"], dtype="<U4")
    dataset["contig_id"] = (dataset["contig_id"].dims, duplicate_names)
    sgkit.save_dataset(dataset, zarr_path)
    ancestral_state = dataset["variant_allele"][:, 0].astype(str)
    with pytest.raises(ValueError, match='Multiple contigs named "chr1"'):
        tsinfer.VariantData(zarr_path, ancestral_state, contig_id="chr1")
45 changes: 44 additions & 1 deletion tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,6 +2305,7 @@ def __init__(
sample_mask=None,
site_mask=None,
sites_time=None,
contig_id=None,
):
self.path = path
self.data = zarr.open(path, mode="r")
Expand All @@ -2326,6 +2327,20 @@ def __init__(
raise ValueError(
"Site mask array must be the same length as the number of unmasked sites"
)
if contig_id is not None:
contig_index = np.where(self.data.contig_id[:] == contig_id)[0]
if len(contig_index) == 0:
raise ValueError(
f'"{contig_id}" not found among the available contig IDs: '
+ ",".join(f"{n}" for n in self.data.contig_id[:])
)
elif len(contig_index) > 1:
raise ValueError(f'Multiple contigs named "{contig_id}"')
contig_index = contig_index[0]
site_mask = np.logical_or(
site_mask, self.data["variant_contig"][:] != contig_index
)

# We negate the mask as it is much easier in numpy to have True=keep
self.sites_select = ~site_mask.astype(bool)

Expand Down Expand Up @@ -2357,6 +2372,21 @@ def __init__(
" sgkit dataset, indicating that all the genotypes are"
" unphased"
)
self._contig_index = None
self._contig_id = None
contig = self.data.variant_contig[:][self.sites_select]
try:
self._contig_index = contig[0]
self._contig_id = self.data.contig_id[self._contig_index]
except (IndexError, AttributeError):
pass
if self._contig_index is not None and np.any(contig != self._contig_index):
ctigs = ", ".join(f'"{self.data.contig_id[c]}"' for c in np.unique(contig))
raise ValueError(
f"Sites belong to multiple contigs ({ctigs}). Please restrict "
"sites to one contig e.g. via the `contig_id` argument."
)

if np.any(np.diff(self.sites_position) <= 0):
raise ValueError(
"Values taken from the variant_position array are not strictly "
Expand Down Expand Up @@ -2460,7 +2490,20 @@ def sequence_length(self):
try:
return self.data.attrs["sequence_length"]
except KeyError:
return int(np.max(self.data["variant_position"])) + 1
if self._contig_index is not None:
try:
return self.data.contig_length[self._contig_index]
except AttributeError:
pass
return int(np.max(self.data["variant_position"])) + 1

@property
def contig_id(self):
    """
    The ID (name) of the single contig to which all the used sites
    belong, or None if no contig ID was found for them.
    """
    return self._contig_id

@property
def num_sites(self):
Expand Down

0 comments on commit ba629de

Please sign in to comment.