From e2d71aa29f5cdab91aa3fa21d55a04196bd0caed Mon Sep 17 00:00:00 2001 From: Erin McAuley Date: Wed, 27 Nov 2024 09:04:41 -0500 Subject: [PATCH] feat: public VariantLookup class --- prymer/api/__init__.py | 12 +- prymer/api/variant_lookup.py | 174 +++++++++++++--------- prymer/primer3/primer3.py | 37 +++-- tests/api/test_variant_lookup.py | 50 ++++--- tests/primer3/test_primer3.py | 240 ++++++++++++++++++++----------- 5 files changed, 319 insertions(+), 194 deletions(-) diff --git a/prymer/api/__init__.py b/prymer/api/__init__.py index 4d937a2..3720df4 100644 --- a/prymer/api/__init__.py +++ b/prymer/api/__init__.py @@ -8,13 +8,13 @@ from prymer.api.span import BedLikeCoords from prymer.api.span import Span from prymer.api.span import Strand -from prymer.api.variant_lookup import FileBasedVariantLookup from prymer.api.variant_lookup import SimpleVariant + +# from prymer.api.variant_lookup import _DiskBasedLookup +# from prymer.api.variant_lookup import _InMemoryLookup +# from prymer.api.variant_lookup import _VariantLookup from prymer.api.variant_lookup import VariantLookup -from prymer.api.variant_lookup import VariantOverlapDetector from prymer.api.variant_lookup import VariantType -from prymer.api.variant_lookup import cached -from prymer.api.variant_lookup import disk_based __all__ = [ "ClusteredIntervals", @@ -30,8 +30,4 @@ "VariantType", "SimpleVariant", "VariantLookup", - "FileBasedVariantLookup", - "VariantOverlapDetector", - "cached", - "disk_based", ] diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index a08a6a4..ce4944c 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -7,20 +7,12 @@ [`query()`][prymer.api.variant_lookup.VariantLookup.query] method for retrieving variants that overlap the given range. 
-Two concrete implementations are provided that both take a list of VCF files to be queried: - -- [`FileBasedVariantLookup`][prymer.api.variant_lookup.FileBasedVariantLookup] -- performs -disk-based retrieval of variants (using a VCF index). This class is recommended for large VCFs. The -[`disk_based()`][prymer.api.variant_lookup.disk_based] alternative constructor is -provided for easy construction of this object. -- [`VariantOverlapDetector`][prymer.api.variant_lookup.VariantOverlapDetector] -- reads in -variants into memory and uses an +[`VariantLookup`][prymer.api.variant_lookup.VariantLookup] needs a list of VCF files to be queried, +a minimum minor allele frequency (MAF) with which to optionally filter variants, a boolean `include_missing_mafs` flag, and a +boolean `cached` flag. It is recommended to set `cached` to `False` for large VCFs. When `cached` +is `True`, variants from smaller VCFs are loaded into memory. [`pybedlite.overlap_detector.OverlapDetector`](https://pybedlite.readthedocs.io/en/latest/api.html#pybedlite.overlap_detector.OverlapDetector) -for querying. This class is recommended for small VCFs. The -[`cached()`][prymer.api.variant_lookup.cached] alternative constructor is provided for -easy construction of this object. - -Each class can also use minor allele frequency (MAF) to filter variants. +is used for querying. The helper class `SimpleVariant` is included to facilitate VCF querying and reporting out results. 
@@ -28,7 +20,7 @@ ```python >>> from pathlib import Path ->>> lookup = cached(vcf_paths=[Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.00, include_missing_mafs=True) +>>> lookup = VariantLookup(vcf_paths=[Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.00, include_missing_mafs=True) >>> lookup.query(refname="chr2", start=7999, end=8000) [SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='', end=8000, variant_type=, maf=None)] >>> variants = lookup.query(refname="chr2", start=7999, end=9900) @@ -67,7 +59,6 @@ from types import TracebackType from typing import ContextManager from typing import Optional -from typing import final import pysam from fgpyo.vcf import reader @@ -237,7 +228,7 @@ def build(simple_variant: SimpleVariant) -> "_VariantInterval": ) -class VariantLookup(ABC): +class _VariantLookup(ABC): """Base class to represent a variant from a given genomic range. Attributes: @@ -258,7 +249,11 @@ def __init__( self.min_maf: Optional[float] = min_maf self.include_missing_mafs: bool = include_missing_mafs - @final + @abstractmethod + def close(self) -> None: + pass + + @abstractmethod def query( self, refname: str, @@ -277,20 +272,7 @@ def query( include_missing_mafs: whether to include variants with a missing MAF (overrides self.include_missing_mafs) """ - if maf is None: - maf = self.min_maf - if include_missing_mafs is None: - include_missing_mafs = self.include_missing_mafs - - variants = self._query(refname=refname, start=start, end=end) - if len(variants) == 0: - _logger.debug(f"No variants extracted from region of interest: {refname}:{start}-{end}") - if maf is None or maf <= 0.0: - return variants - elif include_missing_mafs: # return variants with a MAF above threshold or missing - return [v for v in variants if (v.maf is None or v.maf >= maf)] - else: - return [v for v in variants if v.maf is not None and v.maf >= maf] + pass @staticmethod def to_variants( @@ -319,12 +301,76 @@ def 
to_variants( simple_vars.extend(simple_variants) return sorted(simple_vars, key=lambda v: (v.pos, v.id)) - @abstractmethod - def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: - """Subclasses must implement this method.""" + +class VariantLookup(ContextManager): + def __init__( + self, + vcf_paths: list[Path], + min_maf: Optional[float], + include_missing_mafs: bool, + cached: bool = True, + ): + self.vcf_paths: list[Path] = vcf_paths + self.min_maf: Optional[float] = min_maf + self.include_missing_mafs: bool = include_missing_mafs + self.cached: bool = cached + + self._lookup: _VariantLookup + + if cached: + self._lookup = _InMemoryLookup( + vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs + ) + else: + self._lookup = _DiskBasedLookup( + vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs + ) + + def __enter__(self) -> "VariantLookup": + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self._lookup.close() + + def query( + self, + refname: str, + start: int, + end: int, + maf: Optional[float] = None, + include_missing_mafs: bool = None, + ) -> list[SimpleVariant]: + + maf = maf if maf is not None else self.min_maf + include_missing_mafs = ( + include_missing_mafs if include_missing_mafs is not None else self.include_missing_mafs + ) + variants: list[SimpleVariant] = self._lookup.query( + refname=refname, + start=start, + end=end, + ) + + if len(variants) == 0: + _logger.debug(f"No variants extracted from region of interest: {refname}:{start}-{end}") + if maf is None or maf <= 0.0: + return variants + elif include_missing_mafs: # return variants with a MAF above threshold or missing + return [v for v in variants if (v.maf is None or v.maf >= maf)] + else: + return [v for v in variants if v.maf is not None and v.maf >= maf] + + def close(self) -> None: + 
"""Close the underlying VCF file handles.""" + self._lookup.close() -class FileBasedVariantLookup(ContextManager, VariantLookup): +class _DiskBasedLookup(_VariantLookup): """Implementation of `VariantLookup` that queries against indexed VCF files each time a query is performed. Assumes the index is located adjacent to the VCF file and has the same base name with either a .csi or .tbi suffix. @@ -332,7 +378,7 @@ class FileBasedVariantLookup(ContextManager, VariantLookup): Example: ```python - >>> with FileBasedVariantLookup([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0, include_missing_mafs=False) as lookup: + >>> with _DiskBasedLookup([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0, include_missing_mafs=False) as lookup: ... lookup.query(refname="chr2", start=7999, end=8000) [SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='', end=8000, variant_type=, maf=None)] @@ -355,7 +401,7 @@ def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_miss open_fh = pysam.VariantFile(str(path)) self._readers.append(open_fh) - def __enter__(self) -> "FileBasedVariantLookup": + def __enter__(self) -> "_DiskBasedLookup": """Enter the context manager.""" return self @@ -369,7 +415,14 @@ def __exit__( self.close() return None - def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: + def query( + self, + refname: str, + start: int, + end: int, + maf: Optional[float] = None, + include_missing_mafs: bool = None, + ) -> list[SimpleVariant]: """Queries variants from the VCFs used by this lookup and returns a `SimpleVariant`.""" simple_variants: list[SimpleVariant] = [] for fh, path in zip(self._readers, self.vcf_paths, strict=True): @@ -387,7 +440,7 @@ def close(self) -> None: handle.close() -class VariantOverlapDetector(VariantLookup): +class _InMemoryLookup(_VariantLookup): """Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting Variants 
into an `OverlapDetector`.""" @@ -415,7 +468,14 @@ def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_miss ) self._overlap_detector.add_all(variant_intervals) - def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: + def query( + self, + refname: str, + start: int, + end: int, + maf: Optional[float] = None, + include_missing_mafs: bool = None, + ) -> list[SimpleVariant]: """Queries variants from the VCFs used by this lookup.""" query = Interval( refname=refname, start=start - 1, end=end @@ -426,6 +486,9 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: ) return sorted(overlapping_variants, key=lambda v: (v.pos, v.id)) + def close(self) -> None: + pass + # module-level functions def calc_maf_from_filter(variant: pysam.VariantRecord) -> Optional[float]: @@ -460,34 +523,3 @@ def calc_maf_from_filter(variant: pysam.VariantRecord) -> Optional[float]: maf = num_alt / len(gts) return maf - - -def cached( - vcf_paths: list[Path], min_maf: float, include_missing_mafs: bool = False -) -> VariantOverlapDetector: - """Constructs a `VariantLookup` that caches all variants in memory for fast lookup. - Appropriate for small VCFs.""" - return VariantOverlapDetector( - vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs - ) - - -def disk_based( - vcf_paths: list[Path], min_maf: float, include_missing_mafs: bool = False -) -> FileBasedVariantLookup: - """Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup. - - Appropriate for large VCFs. - - Example: - - ```python - >>> with disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0) as lookup: - ... 
lookup.query(refname="chr2", start=7999, end=8000) - [SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='', end=8000, variant_type=, maf=None)] - - ``` - """ # noqa: E501 - return FileBasedVariantLookup( - vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs - ) diff --git a/prymer/primer3/primer3.py b/prymer/primer3/primer3.py index 7dd2b92..241bb0c 100644 --- a/prymer/primer3/primer3.py +++ b/prymer/primer3/primer3.py @@ -17,15 +17,15 @@ nucleotide sequences can be retrieved. The full path to the `primer3` executable can provided, otherwise it is assumed to be on the PATH. Furthermore, optionally a [`VariantLookup`][prymer.api.variant_lookup.VariantLookup] may be provided to -hard-mask the design and target regions as to avoid design primers over polymorphic sites. +hard-mask the design and target regions so as to avoid designing primers over polymorphic sites. ```python >>> from pathlib import Path ->>> from prymer.api.variant_lookup import VariantLookup, VariantOverlapDetector +>>> from prymer.api.variant_lookup import _VariantLookup, _InMemoryLookup >>> genome_fasta = Path("./tests/primer3/data/miniref.fa") >>> genome_vcf = Path("./tests/primer3/data/miniref.variants.vcf.gz") ->>> variant_lookup: VariantLookup = VariantOverlapDetector(vcf_paths=[genome_vcf], min_maf=0.01, include_missing_mafs=False) ->>> designer = Primer3(genome_fasta=genome_fasta, variant_lookup=variant_lookup) +>>> variant_lookup: VariantLookup = VariantLookup(vcf_paths=[genome_vcf], min_maf=0.01, include_missing_mafs=False) +>>> designer = Primer3(genome_fasta=genome_fasta, variants=variant_lookup) ``` @@ -229,13 +229,14 @@ def __init__( self, genome_fasta: Path, executable: Optional[str] = None, - variant_lookup: Optional[VariantLookup] = None, + variants: list[Path] | VariantLookup | None = None, ) -> None: """ Args: genome_fasta: Path to reference genome .fasta file executable: string representation of the path to primer3_core - 
variant_lookup: VariantLookup object to facilitate hard-masking variants + variants: an optional list of VCF `Path`s or VariantLookup object to facilitate + hard-masking variants Assumes the sequence dictionary is located adjacent to the .fasta file and has the same base name with a .dict suffix. @@ -246,12 +247,30 @@ def __init__( ) command: list[str] = [f"{executable_path}"] - self.variant_lookup = variant_lookup + self.variant_lookup: Optional[VariantLookup] + # If no variants are given, or they are in a pre-constructed `VariantLookup` object, + # set the object attribute + if variants is None or isinstance(variants, VariantLookup): + self.variant_lookup = variants + # If `variants` is a list of `Path`s to indexed VCF files, create a `VariantLookup` object + # with default settings (no MAF filtering, do not include variants with missing MAFs) + elif isinstance(variants, list) and all( + isinstance(path, Path) for path in variants + ): # if provided path, create appropriate `VariantLookup` + self.variant_lookup = VariantLookup( + vcf_paths=variants, + cached=True, + min_maf=0.00, + include_missing_mafs=False, + ) + else: + raise ValueError( + "Variant lookup is required to be a list of `Path` objects or a" + f" pre-constructed `VariantLookup` object: received {variants}" + ) self._fasta = pysam.FastaFile(filename=f"{genome_fasta}") dict_path = genome_fasta.with_suffix(".dict") - # TODO: This is a placeholder while waiting for #160 to be resolved - # https://github.com/fulcrumgenomics/fgpyo/pull/160 with reader(dict_path, file_type=sam.SamFileType.SAM) as fh: self._dict: SequenceDictionary = SequenceDictionary.from_sam(header=fh.header) diff --git a/tests/api/test_variant_lookup.py b/tests/api/test_variant_lookup.py index 164ee16..0f6d38f 100644 --- a/tests/api/test_variant_lookup.py +++ b/tests/api/test_variant_lookup.py @@ -14,13 +14,13 @@ from prymer.api.span import Span from prymer.api.span import Strand -from prymer.api.variant_lookup import 
FileBasedVariantLookup from prymer.api.variant_lookup import SimpleVariant -from prymer.api.variant_lookup import VariantOverlapDetector +from prymer.api.variant_lookup import VariantLookup from prymer.api.variant_lookup import VariantType -from prymer.api.variant_lookup import cached +from prymer.api.variant_lookup import _DiskBasedLookup +from prymer.api.variant_lookup import _InMemoryLookup +from prymer.api.variant_lookup import _VariantLookup from prymer.api.variant_lookup import calc_maf_from_filter -from prymer.api.variant_lookup import disk_based @pytest.mark.parametrize( @@ -380,7 +380,7 @@ def get_simple_variant_approx_by_id(*variant_id: str) -> list[SimpleVariant]: def variant_overlap_detector_query( - detector: VariantOverlapDetector, + detector: VariantLookup | _VariantLookup, refname: str, start: int, end: int, @@ -423,7 +423,7 @@ def test_simple_variant_conversion(vcf_path: Path, sample_vcf: list[VariantRecor which is required for class instantiation. We use the in-memory `VariantBuilder` here to keep test data consistent.""" - variant_overlap_detector = VariantOverlapDetector( + variant_overlap_detector = _InMemoryLookup( vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=True ) # overcome rounding differences @@ -438,10 +438,13 @@ def test_simple_variant_conversion_logs_file_based( ) -> None: """Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert.""" caplog.set_level(logging.DEBUG) - with FileBasedVariantLookup( - vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False + with VariantLookup( + vcf_paths=[vcf_path], + min_maf=0.01, + include_missing_mafs=False, + cached=False, ) as variant_lookup: - variant_lookup.query(refname="foo", start=1, end=2) + variant_lookup.query(refname="chr2", start=1, end=2) assert "No variants extracted from region of interest" in caplog.text @@ -450,29 +453,29 @@ def test_simple_variant_conversion_logs_non_file_based( ) -> None: """Test that `to_variants()` logs a debug 
message with no pysam.VariantRecords to convert.""" caplog.set_level(logging.DEBUG) - variant_lookup = VariantOverlapDetector( - vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False + variant_lookup = VariantLookup( + vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False, cached=False ) - variant_lookup.query(refname="foo", start=1, end=2) + variant_lookup.query(refname="chr2", start=1, end=2) assert "No variants extracted from region of interest" in caplog.text def test_missing_index_file_raises(temp_missing_path: Path) -> None: """Test that both VariantLookup objects raise an error with a missing index file.""" with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"): - with disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False): + with VariantLookup(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False): pass with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"): - cached(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) + VariantLookup(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) def test_missing_vcf_files_raises() -> None: """Test that an error is raised when no VCF_paths are provided.""" with pytest.raises(ValueError, match="No VCF paths given to query"): - with disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False): + with VariantLookup(vcf_paths=[], min_maf=0.01, include_missing_mafs=False): pass with pytest.raises(ValueError, match="No VCF paths given to query"): - cached(vcf_paths=[], min_maf=0.01, include_missing_mafs=False) + VariantLookup(vcf_paths=[], min_maf=0.01, include_missing_mafs=False) @pytest.mark.parametrize("random_seed", [1, 10, 100, 1000, 10000]) @@ -491,7 +494,7 @@ def test_vcf_header_missing_chrom( caplog.set_level(logging.DEBUG) vcf_paths = [vcf_path, mini_chr1_vcf, mini_chr3_vcf] random.Random(random_seed).shuffle(vcf_paths) - with 
FileBasedVariantLookup( + with _DiskBasedLookup( vcf_paths=vcf_paths, min_maf=0.00, include_missing_mafs=True ) as variant_lookup: variants_of_interest = variant_lookup.query( @@ -534,7 +537,7 @@ def test_calc_maf_from_gt_only() -> None: def test_variant_overlap_detector_query(vcf_path: Path) -> None: """Test `VariantOverlapDetector.query()` positional filtering.""" - variant_overlap_detector = VariantOverlapDetector( + variant_overlap_detector = _InMemoryLookup( vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=True ) @@ -566,12 +569,12 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None: @pytest.mark.parametrize("include_missing_mafs", [False, True]) def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs: bool) -> None: - """Test that `VariantOverlapDetector.query()` MAF filtering is as expected. + """Test that `_InMemoryLookup.query()` MAF filtering is as expected. `include_missing_mafs` is parameterized in both the class constructor and in the query to demonstrate that it is only the query_method setting that changes the test results. 
""" - variant_overlap_detector = VariantOverlapDetector( - vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs + variant_overlap_detector: VariantLookup = VariantLookup( + vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs, cached=True ) query = variant_overlap_detector_query( variant_overlap_detector, @@ -598,9 +601,10 @@ def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs: @pytest.mark.parametrize("include_missing_mafs", [False, True]) def test_file_based_variant_query(vcf_path: Path, include_missing_mafs: bool) -> None: """Test that `FileBasedVariantLookup.query()` MAF filtering is as expected.""" - with FileBasedVariantLookup( - vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs + with VariantLookup( + vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs, cached=False ) as file_based_vcf_query: + # assert isinstance(file_based_vcf_query._lookup, _DiskBasedLookup) query = [ _round_simple_variant(simple_variant) for simple_variant in file_based_vcf_query.query( diff --git a/tests/primer3/test_primer3.py b/tests/primer3/test_primer3.py index c74cb13..48a1bee 100644 --- a/tests/primer3/test_primer3.py +++ b/tests/primer3/test_primer3.py @@ -11,7 +11,9 @@ from prymer.api.primer_pair import PrimerPair from prymer.api.span import Span from prymer.api.span import Strand -from prymer.api.variant_lookup import cached +from prymer.api.variant_lookup import VariantLookup +from prymer.api.variant_lookup import _DiskBasedLookup +from prymer.api.variant_lookup import _InMemoryLookup from prymer.primer3.primer3 import Primer3 from prymer.primer3.primer3 import Primer3Failure from prymer.primer3.primer3 import Primer3Result @@ -35,6 +37,32 @@ def vcf_path() -> Path: return Path(__file__).parent / "data" / "miniref.variants.vcf.gz" +@pytest.fixture() +def variant_masking_test_data() -> list[tuple[Span, str, str]]: + return [ + ( + Span(refname="chr2", 
start=9000, end=9110), + # 9000 9010 9020 9030 9040 9050 9060 9070 9080 9090 9100 9110 # noqa + "AATATTCTTGNTGCTTATGCNGCTGACATTGTTGCCCTCCCTAAAGCAACNAAGTAGCCTNTATTTCCCANAGTGAAAGANNACGCTGGCNNNTCAGTTANNNTACAAAAG", + "AATATTCTTGCTGCTTATGCAGCTGACATTGTTGCCCTCCCTAAAGCAACCAAGTAGCCTTTATTTCCCACAGTGAAAGAAAACGCTGGCCTATCAGTTACATTACAAAAG", + ), # expected masked positions: 9010, 9020, 9050, 9060, 9070, + # 9080 (2bp insertion: 3 bases), 9090 (2bp deletion: 2 bases), 9100 (mixed: 3 bases) + # do not expect positions 9000 (MAF = 0.001), 9030 (MAF = 0.001), or 9040 (MAF = 0.0004814) + # to be masked (MAF below the provided min_maf) + ( + Span(refname="chr2", start=9095, end=9120), + "AGTTANNNTACAAAAGGCAGATTTCA", + "AGTTACATTACAAAAGGCAGATTTCA", + ), + # 9100 (common-mixed -- alt1: CA->GG, and alt2: CA->CACACA). The first alt masks the + # positions [9100,9101], and the second alt masks the positions [9100,9102] (an extra + # base for the insertion). But the second alt is not added to variant lookup, while the + # first variant is classified as OTHER, so [9100,9102] are masked. FIXME: this could be + # improved by more faithfully parsing the input VCF and representing each alternate as its + # own simple variant. 
+ ] + + @pytest.fixture def single_primer_params() -> PrimerAndAmpliconParameters: return PrimerAndAmpliconParameters( @@ -142,6 +170,11 @@ def valid_primer_pairs( return primer_pairs +################################################################################ +# Tests for Primer3.py:: design() +################################################################################ + + def test_design_raises( genome_ref: Path, single_primer_params: PrimerAndAmpliconParameters, @@ -370,53 +403,133 @@ def test_fasta_close_valid( designer.design(design_input=design_input) -@pytest.mark.parametrize( - "region, expected_hard_masked, expected_soft_masked", - [ - ( - Span(refname="chr2", start=9000, end=9110), - # 9000 9010 9020 9030 9040 9050 9060 9070 9080 9090 9100 9110 # noqa - "AATATTCTTGNTGCTTATGCNGCTGACATTGTTGCCCTCCCTAAAGCAACNAAGTAGCCTNTATTTCCCANAGTGAAAGANNACGCTGGCNNNTCAGTTANNNTACAAAAG", - "AATATTCTTGCTGCTTATGCAGCTGACATTGTTGCCCTCCCTAAAGCAACCAAGTAGCCTTTATTTCCCACAGTGAAAGAAAACGCTGGCCTATCAGTTACATTACAAAAG", - ), # expected masked positions: 9010, 9020, 9050, 9060, 9070, - # 9080 (2bp insertion: 3 bases), 9090 (2bp deletion: 2 bases), 9100 (mixed: 3 bases) - # do not expect positions 9000 (MAF = 0.001), 9030 (MAF = 0.001), or 9040 (MAF = 0.0004814) - # to be masked (MAF below the provided min_maf) - ( - Span(refname="chr2", start=9095, end=9120), - "AGTTANNNTACAAAAGGCAGATTTCA", - "AGTTACATTACAAAAGGCAGATTTCA", - ), - # 9100 (common-mixed -- alt1: CA->GG, and alt2: CA->CACACA). The first alt masks the - # positions [9100,9101], and the second alt masks the positions [9100,9102] (an extra - # base for the insertion). But the second alt is not added to variant lookup, while the - # first variant is classified as OTHER, so [9100,9102] are masked. FIXME: this could be - # improved by more faithfully parsing the input VCF and representing each alternate as its - # own simple variant. 
- ], -) -def test_variant_lookup( +################################################################################ +# Tests for variant_lookup.py::VariantLookup() +################################################################################ + + +@pytest.mark.parametrize("variant_test_data_index", [0, 1]) +def test_cached_variant_lookup( + variant_masking_test_data: list[tuple[Span, str, str]], + variant_test_data_index: int, genome_ref: Path, vcf_path: Path, - region: Span, - expected_hard_masked: str, - expected_soft_masked: str, ) -> None: - """Test that MAF filtering and masking are working as expected.""" - with Primer3( - genome_fasta=genome_ref, variant_lookup=cached([vcf_path], min_maf=0.01) - ) as designer: + """Test that _InMemoryLookup() is instantiated correctly and works as expected. + + Parameterizes test data from the `variant_masking_test_data` fixture (2 test cases).""" + region: Span = variant_masking_test_data[variant_test_data_index][0] + expected_hard_masked: str = variant_masking_test_data[variant_test_data_index][1] + expected_soft_masked: str = variant_masking_test_data[variant_test_data_index][2] + cached_variant_lookup: VariantLookup = VariantLookup( + vcf_paths=[vcf_path], min_maf=0.01, cached=True, include_missing_mafs=True + ) + + with Primer3(genome_fasta=genome_ref, variants=cached_variant_lookup) as designer: actual_soft_masked, actual_hard_masked = designer.get_design_sequences(region=region) + + assert isinstance(designer.variant_lookup._lookup, _InMemoryLookup) + assert not isinstance(designer.variant_lookup._lookup, _DiskBasedLookup) assert actual_hard_masked == expected_hard_masked assert actual_soft_masked == expected_soft_masked # with no variant lookup should all be soft-masked - with Primer3(genome_fasta=genome_ref, variant_lookup=None) as designer: + with Primer3(genome_fasta=genome_ref, variants=None) as designer: actual_soft_masked, actual_hard_masked = designer.get_design_sequences(region=region) assert 
actual_hard_masked == expected_soft_masked assert actual_soft_masked == expected_soft_masked +@pytest.mark.parametrize("variant_test_data_index", [0, 1]) +def test_diskbased_variant_lookup( + variant_masking_test_data: list[tuple[Span, str, str]], + variant_test_data_index: int, + genome_ref: Path, + vcf_path: Path, +) -> None: + """Test that _DiskBasedLookup() is instantiated correctly and works as expected. + + Parameterizes test data from the `variant_masking_test_data` fixture (2 test cases).""" + region: Span = variant_masking_test_data[variant_test_data_index][0] + expected_hard_masked: str = variant_masking_test_data[variant_test_data_index][1] + expected_soft_masked: str = variant_masking_test_data[variant_test_data_index][2] + disk_variant_lookup: VariantLookup = VariantLookup( + vcf_paths=[vcf_path], min_maf=0.01, cached=False, include_missing_mafs=True + ) + with Primer3(genome_fasta=genome_ref, variants=disk_variant_lookup) as designer: + actual_soft_masked, actual_hard_masked = designer.get_design_sequences(region=region) + + assert isinstance(designer.variant_lookup._lookup, _DiskBasedLookup) + assert not isinstance(designer.variant_lookup._lookup, _InMemoryLookup) + assert actual_hard_masked == expected_hard_masked + assert actual_soft_masked == expected_soft_masked + + # with no variant lookup --> should all be soft-masked + with Primer3(genome_fasta=genome_ref, variants=None) as designer: + actual_soft_masked, actual_hard_masked = designer.get_design_sequences(region=region) + assert actual_hard_masked == expected_soft_masked + assert actual_soft_masked == expected_soft_masked + + +def test_invalid_variant_lookup(genome_ref: Path, vcf_path: Path) -> None: + """Assert that we raise an error if `variants` is not a list of `Path` objects.""" + with pytest.raises(ValueError, match="Variant lookup is required to be"): + Primer3(genome_fasta=genome_ref, variants=[vcf_path, "invalid_item"]) # type: ignore + + 
+################################################################################ +# Tests for Primer3 methods: _create_design_region(), _screen_pair_results(), _build_failures() +################################################################################ + + +@pytest.mark.parametrize("max_amplicon_length", [100, 101]) +def test_create_design_region(max_amplicon_length: int, genome_ref: Path) -> None: + """If the target region is shorter than the max amplicon length, it should be padded to fit.""" + target_region = Span(refname="chr1", start=201, end=250, strand=Strand.POSITIVE) + + with Primer3(genome_fasta=genome_ref) as designer: + design_region: Span = designer._create_design_region( + target_region=target_region, + max_amplicon_length=max_amplicon_length, + min_primer_length=10, + ) + + assert design_region.length == 2 * max_amplicon_length - target_region.length + + +def test_create_design_region_raises_when_target_region_exceeds_max_amplicon_length( + genome_ref: Path, +) -> None: + """ + `_create_design_region()` should raise a ValueError when the target region is larger than the + max amplicon length. + """ + target_region = Span(refname="chr1", start=201, end=250, strand=Strand.POSITIVE) + + with Primer3(genome_fasta=genome_ref) as designer: + with pytest.raises(ValueError, match="exceeds the maximum size"): + designer._create_design_region( + target_region=target_region, max_amplicon_length=10, min_primer_length=10 + ) + + +def test_create_design_region_raises_when_primers_would_not_fit_in_design_region( + genome_ref: Path, +) -> None: + """ + `_create_design_region()` should raise a ValueError when the design region does not include + sufficient space flanking the target for a primer to be designed. (i.e. when this space is less + than the specified minimum primer length.) 
+ """ + target_region = Span(refname="chr1", start=201, end=250, strand=Strand.POSITIVE) + + with Primer3(genome_fasta=genome_ref) as designer: + with pytest.raises(ValueError, match="exceeds the maximum size"): + designer._create_design_region( + target_region=target_region, max_amplicon_length=55, min_primer_length=10 + ) + + def test_screen_pair_results( valid_primer_pairs: list[PrimerPair], genome_ref: Path, @@ -542,6 +655,12 @@ def test_build_failures_debugs( assert expected_error_msg in caplog.text +################################################################################ +# Tests for Primer3Result methods: primers(), primer_pairs(), +# as_primer_result(), as_primer_pair_result() +################################################################################ + + def test_primer3_result_primers_ok( valid_left_primers: list[Oligo], valid_right_primers: list[Oligo] ) -> None: @@ -585,54 +704,9 @@ def test_primer3_result_as_primer_pair_result_exception( result.as_primer_pair_result() -@pytest.mark.parametrize("max_amplicon_length", [100, 101]) -def test_create_design_region(max_amplicon_length: int, genome_ref: Path) -> None: - """If the target region is shorter than the max amplicon length, it should be padded to fit.""" - target_region = Span(refname="chr1", start=201, end=250, strand=Strand.POSITIVE) - - with Primer3(genome_fasta=genome_ref) as designer: - design_region: Span = designer._create_design_region( - target_region=target_region, - max_amplicon_length=max_amplicon_length, - min_primer_length=10, - ) - - assert design_region.length == 2 * max_amplicon_length - target_region.length - - -def test_create_design_region_raises_when_target_region_exceeds_max_amplicon_length( - genome_ref: Path, -) -> None: - """ - `_create_design_region()` should raise a ValueError when the target region is larger than the - max amplicon length. 
- """ - target_region = Span(refname="chr1", start=201, end=250, strand=Strand.POSITIVE) - - with Primer3(genome_fasta=genome_ref) as designer: - with pytest.raises(ValueError, match="exceeds the maximum size"): - designer._create_design_region( - target_region=target_region, max_amplicon_length=10, min_primer_length=10 - ) - - -def test_create_design_region_raises_when_primers_would_not_fit_in_design_region( - genome_ref: Path, -) -> None: - """ - `_create_design_region()` should raise a ValueError when the design region does not include - sufficient space flanking the target for a primer to be designed. (i.e. when this space is less - than the specified minimum primer length.) - """ - target_region = Span(refname="chr1", start=201, end=250, strand=Strand.POSITIVE) - - with Primer3(genome_fasta=genome_ref) as designer: - with pytest.raises(ValueError, match="exceeds the maximum size"): - designer._create_design_region( - target_region=target_region, max_amplicon_length=55, min_primer_length=10 - ) - - +################################################################################ +# Tests for Primer3 probe design +################################################################################ def test_probe_design_raises(genome_ref: Path, valid_probe_params: ProbeParameters) -> None: """Test that we raise an error when the target region is smaller than the minimal probe size.""" target = Span(refname="chr1", start=201, end=217, strand=Strand.POSITIVE)