Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ensure FileBasedVariantLookup is used as a context manager #71

Merged
merged 10 commits into from
Nov 13, 2024
55 changes: 51 additions & 4 deletions prymer/api/variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@
import logging
from abc import ABC
from abc import abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass
from dataclasses import field
from enum import auto
from enum import unique
from pathlib import Path
from types import TracebackType
from typing import Optional
from typing import Self
clintval marked this conversation as resolved.
Show resolved Hide resolved
from typing import final

import pysam
Expand Down Expand Up @@ -320,10 +323,20 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
"""Subclasses must implement this method."""


class FileBasedVariantLookup(VariantLookup):
"""Implementation of VariantLookup that queries against indexed VCF files each time a query is
class FileBasedVariantLookup(VariantLookup, AbstractContextManager):
clintval marked this conversation as resolved.
Show resolved Hide resolved
"""Implementation of `VariantLookup` that queries against indexed VCF files each time a query is
performed. Assumes the index is located adjacent to the VCF file and has the same base name with
either a .csi or .tbi suffix."""
either a .csi or .tbi suffix.

Example:

```python
>>> with FileBasedVariantLookup([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0, include_missing_mafs=False) as lookup:
... lookup.query(refname="chr2", start=7999, end=8000)
[SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='<DEL>', end=8000, variant_type=<VariantType.OTHER: 'OTHER'>, maf=None)]

```
""" # noqa: E501

def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_missing_mafs: bool):
self._readers: list[VariantFile] = []
Expand Down Expand Up @@ -353,6 +366,26 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
simple_variants.extend(self.to_variants(variants, source_vcf=path))
return sorted(simple_variants, key=lambda x: x.pos)

def __enter__(self) -> Self:
    """Enter this context manager.

    Returns:
        This lookup itself, so that it can be bound with `with ... as lookup:`.
    """
    # Run the parent class's entry logic first; its return value is not used.
    super().__enter__()
    return self
clintval marked this conversation as resolved.
Show resolved Hide resolved

def close(self) -> None:
    """Close the underlying VCF file handles.

    Every handle in `self._readers` is given the chance to close, even when closing an earlier
    handle fails, so that one bad handle cannot leave the remaining files open.

    Raises:
        Exception: The first error raised while closing a handle, re-raised only after all
            handles have been attempted.
    """
    first_error: Optional[Exception] = None
    for handle in self._readers:
        try:
            handle.close()
        except Exception as error:
            # Remember the first failure but keep going so later handles are not leaked.
            if first_error is None:
                first_error = error
    if first_error is not None:
        raise first_error
clintval marked this conversation as resolved.
Show resolved Hide resolved

def __exit__(
    self,
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
    traceback: Optional[TracebackType],
) -> None:
    """Exit this context manager and close all underlying VCF handles.

    Args:
        exc_type: The type of the exception raised inside the `with` block, if any.
        exc_value: The exception instance raised inside the `with` block, if any.
        traceback: The traceback of the exception raised inside the `with` block, if any.
    """
    # The parent's exit result is discarded and nothing is returned here, so this method
    # never asks for an in-flight exception to be suppressed.
    super().__exit__(exc_type, exc_value, traceback)
    # Release the VCF readers only after the parent class has finished its own exit logic.
    self.close()
clintval marked this conversation as resolved.
Show resolved Hide resolved


class VariantOverlapDetector(VariantLookup):
"""Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting
Expand Down Expand Up @@ -443,7 +476,21 @@ def disk_based(
vcf_paths: list[Path], min_maf: float, include_missing_mafs: bool = False
) -> FileBasedVariantLookup:
"""Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup.
Appropriate for large VCFs."""

Appropriate for large VCFs. Ensure that you take advantage of [`contextlib.closing`](https://docs.python.org/3/library/contextlib.html#contextlib.closing)
for automatically closing the file-based variant lookup after it is used. See below for an
clintval marked this conversation as resolved.
Show resolved Hide resolved
example.

Example:

```python
>>> from contextlib import closing
>>> with closing(disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0)) as lookup:
... lookup.query(refname="chr2", start=7999, end=8000)
[SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='<DEL>', end=8000, variant_type=<VariantType.OTHER: 'OTHER'>, maf=None)]

```
""" # noqa: E501
return FileBasedVariantLookup(
vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs
)
106 changes: 60 additions & 46 deletions tests/api/test_variant_lookup.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import random
from contextlib import closing
from dataclasses import dataclass
from dataclasses import replace
from pathlib import Path
from typing import Optional
from typing import Type

import fgpyo.vcf.builder
import pytest
Expand All @@ -17,7 +17,6 @@
from prymer.api.span import Strand
from prymer.api.variant_lookup import FileBasedVariantLookup
from prymer.api.variant_lookup import SimpleVariant
from prymer.api.variant_lookup import VariantLookup
from prymer.api.variant_lookup import VariantOverlapDetector
from prymer.api.variant_lookup import VariantType
from prymer.api.variant_lookup import cached
Expand Down Expand Up @@ -435,13 +434,24 @@ def test_simple_variant_conversion(vcf_path: Path, sample_vcf: list[VariantRecor
assert actual_simple_variants == VALID_SIMPLE_VARIANTS_APPROX


@pytest.mark.parametrize("variant_lookup_class", [FileBasedVariantLookup, VariantOverlapDetector])
def test_simple_variant_conversion_logs(
variant_lookup_class: Type[VariantLookup], vcf_path: Path, caplog: pytest.LogCaptureFixture
def test_simple_variant_conversion_logs_file_based(
vcf_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert."""
caplog.set_level(logging.DEBUG)
variant_lookup = variant_lookup_class(
with FileBasedVariantLookup(
vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False
) as variant_lookup:
variant_lookup.query(refname="foo", start=1, end=2)
assert "No variants extracted from region of interest" in caplog.text


def test_simple_variant_conversion_logs_non_file_based(
vcf_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert."""
caplog.set_level(logging.DEBUG)
variant_lookup = VariantOverlapDetector(
vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False
)
variant_lookup.query(refname="foo", start=1, end=2)
Expand All @@ -451,15 +461,19 @@ def test_simple_variant_conversion_logs(
def test_missing_index_file_raises(temp_missing_path: Path) -> None:
"""Test that both VariantLookup objects raise an error with a missing index file."""
with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"):
disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)
with closing(
disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)
):
pass
with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"):
cached(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)


def test_missing_vcf_files_raises() -> None:
"""Test that an error is raised when no VCF_paths are provided."""
with pytest.raises(ValueError, match="No VCF paths given to query"):
disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)
with closing(disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)):
pass
with pytest.raises(ValueError, match="No VCF paths given to query"):
cached(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)

Expand All @@ -480,17 +494,17 @@ def test_vcf_header_missing_chrom(
caplog.set_level(logging.DEBUG)
vcf_paths = [vcf_path, mini_chr1_vcf, mini_chr3_vcf]
random.Random(random_seed).shuffle(vcf_paths)
variant_lookup = FileBasedVariantLookup(
with FileBasedVariantLookup(
vcf_paths=vcf_paths, min_maf=0.00, include_missing_mafs=True
)
variants_of_interest = variant_lookup.query(
refname="chr2", start=7999, end=9900
) # (chr2 only in vcf_path)
# Should find all 12 variants from vcf_path (no filtering), with two variants having two
# alternate alleles
assert len(variants_of_interest) == 14
expected_error_msg = "does not contain chromosome"
assert expected_error_msg in caplog.text
) as variant_lookup:
variants_of_interest = variant_lookup.query(
refname="chr2", start=7999, end=9900
) # (chr2 only in vcf_path)
# Should find all 12 variants from vcf_path (no filtering), with two variants having two
# alternate alleles
assert len(variants_of_interest) == 14
expected_error_msg = "does not contain chromosome"
assert expected_error_msg in caplog.text
clintval marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize("test_case", VALID_SIMPLE_VARIANT_TEST_CASES)
Expand Down Expand Up @@ -587,32 +601,32 @@ def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs:
@pytest.mark.parametrize("include_missing_mafs", [False, True])
def test_file_based_variant_query(vcf_path: Path, include_missing_mafs: bool) -> None:
"""Test that `FileBasedVariantLookup.query()` MAF filtering is as expected."""
file_based_vcf_query = FileBasedVariantLookup(
with FileBasedVariantLookup(
vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs
)
query = [
_round_simple_variant(simple_variant)
for simple_variant in file_based_vcf_query.query(
refname="chr2",
start=8000,
end=9100, # while "common-mixed-2/2" starts at 9101, in the VCF it starts at 9100
maf=0.05,
include_missing_mafs=include_missing_mafs,
)
]

if not include_missing_mafs:
assert query == get_simple_variant_approx_by_id(
"common-multiallelic-1/2",
"common-multiallelic-2/2",
"common-mixed-1/2",
"common-mixed-2/2",
)
else:
assert query == get_simple_variant_approx_by_id(
"complex-variant-sv-1/1",
"common-multiallelic-1/2",
"common-multiallelic-2/2",
"common-mixed-1/2",
"common-mixed-2/2",
)
) as file_based_vcf_query:
query = [
_round_simple_variant(simple_variant)
for simple_variant in file_based_vcf_query.query(
refname="chr2",
start=8000,
end=9100, # while "common-mixed-2/2" starts at 9101, in the VCF it starts at 9100
maf=0.05,
include_missing_mafs=include_missing_mafs,
)
]

if not include_missing_mafs:
assert query == get_simple_variant_approx_by_id(
"common-multiallelic-1/2",
"common-multiallelic-2/2",
"common-mixed-1/2",
"common-mixed-2/2",
)
else:
assert query == get_simple_variant_approx_by_id(
"complex-variant-sv-1/1",
"common-multiallelic-1/2",
"common-multiallelic-2/2",
"common-mixed-1/2",
"common-mixed-2/2",
)
Loading