Skip to content

Commit

Permalink
Deal with multiple contigs and sequence lengths
Browse files Browse the repository at this point in the history
Introduces a `contig_id` parameter to variant_data, as described in #949. Fixes #249
  • Loading branch information
hyanwong committed Sep 10, 2024
1 parent 7f23758 commit f135465
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
- Document that the `zarr-vcf` dataset can be either a path or an in-memory zarr group.
(feature introduced in {pr}`966`, documented in {pr}`974`, {user}`hyanwong`)

- Allow a contig to be selected by name (`contig_id`), and get the `sequence_length`
of the contig associated with the unmasked sites, if contig lengths are provided
({pr}`964`, {user}`hyanwong`)

**Fixes**

- Properly account for "N" as an unknown ancestral state, and ban "" from being
Expand Down
9 changes: 5 additions & 4 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ onto branches by {meth}`parsimony<tskit.Tree.map_mutations>`.
It is also possible to *completely* exclude sites and samples, by specifing a boolean
`site_mask` and/or a `sample_mask` when creating the `VariantData` object. Sites or samples with
a mask value of `True` will be completely omitted both from inference and the final tree sequence.
This can be useful, for example, if your VCF file contains multiple chromosomes (in which case
`tsinfer` will need to be run separately on each chromosome) or if you wish to select only a subset
of the chromosome for inference (e.g. to reduce computational load). If a `site_mask` is provided,
note that the ancestral alleles array only specifies alleles for the unmasked sites.
This can be useful, for example, if you wish to select only a subset of the chromosome for
inference, e.g. to reduce computational load. You can also use it to subset inference to a
particular contig, if your dataset contains multiple contigs (although this can be more easily
done using the `contig_id` parameter). Note that if a `site_mask` is provided,
the ancestral states array should only specify alleles for the unmasked sites.

Below, for instance, is an example of including only sites up to position six in the contig
labelled "chr1" in the `example_data.vcz` file:
Expand Down
128 changes: 125 additions & 3 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from tsinfer import formats


def ts_to_dataset(ts, chunks=None, samples=None):
def ts_to_dataset(ts, chunks=None, samples=None, contigs=None):
"""
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
Convert the specified tskit tree sequence into an sgkit dataset.
Expand All @@ -63,7 +63,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
genotypes = np.expand_dims(genotypes, axis=2)

ds = sgkit.create_genotype_call_dataset(
variant_contig_names=["1"],
variant_contig_names=["1"] if contigs is None else contigs,
variant_contig=np.zeros(len(tables.sites), dtype=int),
variant_position=tables.sites.position.astype(int),
variant_allele=alleles,
Expand Down Expand Up @@ -289,9 +289,83 @@ def test_simulate_genotype_call_dataset(tmp_path):
assert np.all(v.genotypes == sd_v)


class TestMultiContig:
def make_two_ts_dataset(self, path):
# split ts into 2; put them as different contigs in the same dataset
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
split_at_site = 7
assert ts.num_sites > 10
site_break = ts.site(split_at_site).position
ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
variant_contig = ds["variant_contig"][:]
variant_contig[split_at_site:] = 1
ds.update({"variant_contig": variant_contig})
variant_position = ds["variant_position"].values
variant_position[split_at_site:] -= int(site_break)
ds.update({"variant_position": ds["variant_position"]})
ds.update(
{"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
)
ds.to_zarr(path, mode="w")
return ts1, ts2

def test_unmasked(self, tmp_path):
self.make_two_ts_dataset(tmp_path)
with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
tsinfer.VariantData(tmp_path, "variant_ancestral_allele")

def test_mask(self, tmp_path):
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
vdata = tsinfer.VariantData(
tmp_path,
"variant_ancestral_allele",
site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
)
assert np.all(ts2.sites_position == vdata.sites_position)
assert vdata.contig_id == "chr2"
assert vdata.sequence_length == ts2.sequence_length

@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
def test_contig_id_param(self, contig_id, tmp_path):
tree_seqs = {}
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
vdata = tsinfer.VariantData(
tmp_path, "variant_ancestral_allele", contig_id=contig_id
)
assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
assert vdata.contig_id == contig_id
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length

def test_contig_id_param_and_mask(self, tmp_path):
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
vdata = tsinfer.VariantData(
tmp_path,
"variant_ancestral_allele",
site_mask=np.array(
(ts1.num_sites + 1) * [True] + (ts2.num_sites - 1) * [False]
),
contig_id="chr2",
)
assert np.all(ts2.sites_position[1:] == vdata.sites_position)
assert vdata.contig_id == "chr2"

@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
def test_contig_length(self, contig_id, tmp_path):
tree_seqs = {}
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
vdata = tsinfer.VariantData(
tmp_path, "variant_ancestral_allele", contig_id=contig_id
)
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length


@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
class TestSgkitMask:
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0]])
def test_sgkit_variant_mask(self, tmp_path, sites):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
Expand Down Expand Up @@ -831,3 +905,51 @@ def test_unimplemented_from_tree_sequence(self):
# Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924
with pytest.raises(NotImplementedError):
tsinfer.VariantData.from_tree_sequence(None)

def test_multiple_contigs(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["contig_id"] = (
ds["contig_id"].dims,
np.array(["c10", "c11"], dtype="<U3"),
)
ds["variant_contig"] = (
ds["variant_contig"].dims,
np.array([0, 0, 1], dtype=ds["variant_contig"].dtype),
)
sgkit.save_dataset(ds, path)
with pytest.raises(
ValueError, match=r'Sites belong to multiple contigs \("c10", "c11"\)'
):
tsinfer.VariantData(path, ds["variant_allele"][:, 0].astype(str))

def test_all_masked(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="All sites have been masked out"):
tsinfer.VariantData(
path, ds["variant_allele"][:, 0].astype(str), site_mask=np.ones(3, bool)
)

def test_bad_contig_param(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match='"XX" not found'):
tsinfer.VariantData(
path, ds["variant_allele"][:, 0].astype(str), contig_id="XX"
)

def test_multiple_contig_param(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["contig_id"] = (
ds["contig_id"].dims,
np.array(["chr1", "chr1"], dtype="<U4"),
)
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match='Multiple contigs named "chr1"'):
tsinfer.VariantData(
path, ds["variant_allele"][:, 0].astype(str), contig_id="chr1"
)
61 changes: 59 additions & 2 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2326,7 +2326,8 @@ class VariantData(SampleData):
:param Union(array, str) site_mask: A numpy array of booleans specifying which
sites to mask out (exclude) from the dataset. Alternatively, a string
can be provided, giving the name of an array in the input dataset which contains
the site mask. If ``None`` (default), all sites are included.
the site mask. If ``None`` (default), all sites are included (unless restricted
to a ``contig_id``, see below).
:param Union(array, str) sites_time: A numpy array of floats specifying the relative
time of occurrence of the mutation to the derived state at each site. This must
be of the same length as the number of unmasked sites. Alternatively, a
Expand All @@ -2336,6 +2337,11 @@ class VariantData(SampleData):
reasonable approximation to the relative order of ancestors used for inference.
Time values are ignored for sites not used in inference, such as singletons,
sites with more than two alleles, or sites with an unknown ancestral state.
:param str contig_id: The name of the contig to use (e.g. "chr1"), if the .vcz file
contains multiple contigs; contig names can be found in the `.contig_id array
of the input dataset. If provided, sites associated with any other contigs will
be added to the sites that are masked out. If ``None`` (default), do not mark
out sites on the basis of their contig ID.
"""

FORMAT_NAME = "tsinfer-variant-data"
Expand All @@ -2349,6 +2355,7 @@ def __init__(
sample_mask=None,
site_mask=None,
sites_time=None,
contig_id=None,
):
try:
if len(path_or_zarr.call_genotype.shape) == 3:
Expand Down Expand Up @@ -2382,8 +2389,24 @@ def __init__(
raise ValueError(
"Site mask array must be the same length as the number of unmasked sites"
)
if contig_id is not None:
contig_index = np.where(self.data.contig_id[:] == contig_id)[0]
if len(contig_index) == 0:
raise ValueError(
f'"{contig_id}" not found among the available contig IDs: '
+ ",".join(f"{n}" for n in self.data.contig_id[:])
)
elif len(contig_index) > 1:
raise ValueError(f'Multiple contigs named "{contig_id}"')
contig_index = contig_index[0]
site_mask = np.logical_or(
site_mask, self.data["variant_contig"][:] != contig_index
)

# We negate the mask as it is much easier in numpy to have True=keep
self.sites_select = ~site_mask.astype(bool)
if np.sum(self.sites_select) == 0:
raise ValueError("All sites have been masked out. Please unmask some")

if sample_mask is None:
sample_mask = np.full(self._num_individuals_before_mask, False, dtype=bool)
Expand Down Expand Up @@ -2413,6 +2436,20 @@ def __init__(
" zarr dataset, indicating that all the genotypes are"
" unphased"
)

used_contigs = self.data.variant_contig[:][self.sites_select]
self._contig_index = used_contigs[0]
self._contig_id = self.data.contig_id[self._contig_index]

if np.any(used_contigs != self._contig_index):
contig_names = ", ".join(
f'"{self.data.contig_id[c]}"' for c in np.unique(used_contigs)
)
raise ValueError(
f"Sites belong to multiple contigs ({contig_names}). Please restrict "
"sites to one contig e.g. via the `contig_id` argument."
)

if np.any(np.diff(self.sites_position) <= 0):
raise ValueError(
"Values taken from the variant_position array are not strictly "
Expand Down Expand Up @@ -2517,10 +2554,30 @@ def finalised(self):

@functools.cached_property
def sequence_length(self):
"""
The sequence length of the contig associated with sites used in the dataset.
If the dataset has a "sequence_length" attribute, this is always used, otherwise
if the dataset has recorded contig lengths, the appropriate length is taken,
otherwise the length is calculated from the maximum variant position plus one.
"""
try:
return self.data.attrs["sequence_length"]
except KeyError:
return int(np.max(self.data["variant_position"])) + 1
if self._contig_index is not None:
try:
if self._contig_index < len(self.data.contig_length):
return self.data.contig_length[self._contig_index]
except AttributeError:
pass # contig_length is optional, fall back to calculating length
return int(np.max(self.data["variant_position"])) + 1

@property
def contig_id(self):
"""
The contig ID (name) for all used sites, or None if no
contig IDs were provided
"""
return self._contig_id

@property
def num_sites(self):
Expand Down

0 comments on commit f135465

Please sign in to comment.