diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index 82f5833..74794e5 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -64,6 +64,8 @@ from enum import auto from enum import unique from pathlib import Path +from types import TracebackType +from typing import ContextManager from typing import Optional from typing import final @@ -320,10 +322,20 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: """Subclasses must implement this method.""" -class FileBasedVariantLookup(VariantLookup): - """Implementation of VariantLookup that queries against indexed VCF files each time a query is +class FileBasedVariantLookup(ContextManager, VariantLookup): + """Implementation of `VariantLookup` that queries against indexed VCF files each time a query is performed. Assumes the index is located adjacent to the VCF file and has the same base name with - either a .csi or .tbi suffix.""" + either a .csi or .tbi suffix. + + Example: + + ```python + >>> with FileBasedVariantLookup([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0, include_missing_mafs=False) as lookup: + ... lookup.query(refname="chr2", start=7999, end=8000) + [SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='', end=8000, variant_type=, maf=None)] + + ``` + """ # noqa: E501 def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_missing_mafs: bool): self._readers: list[VariantFile] = [] @@ -341,6 +353,20 @@ def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_miss open_fh = pysam.VariantFile(str(path)) self._readers.append(open_fh) + def __enter__(self) -> "FileBasedVariantLookup": + """Enter the context manager.""" + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Exit this context manager while closing the underlying VCF handles.""" + self.close() + return None + def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: """Queries variants from the VCFs used by this lookup and returns a `SimpleVariant`.""" simple_variants: list[SimpleVariant] = [] @@ -353,6 +379,11 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: simple_variants.extend(self.to_variants(variants, source_vcf=path)) return sorted(simple_variants, key=lambda x: x.pos) + def close(self) -> None: + """Close the underlying VCF file handles.""" + for handle in self._readers: + handle.close() + class VariantOverlapDetector(VariantLookup): """Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting @@ -443,7 +474,18 @@ def disk_based( vcf_paths: list[Path], min_maf: float, include_missing_mafs: bool = False ) -> FileBasedVariantLookup: """Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup. - Appropriate for large VCFs.""" + + Appropriate for large VCFs. + + Example: + + ```python + >>> with disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0) as lookup: + ... lookup.query(refname="chr2", start=7999, end=8000) + [SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='', end=8000, variant_type=, maf=None)] + + ``` + """ # noqa: E501 return FileBasedVariantLookup( vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs ) diff --git a/tests/api/test_picking.py b/tests/api/test_picking.py index 164a5ed..cdb5da0 100644 --- a/tests/api/test_picking.py +++ b/tests/api/test_picking.py @@ -383,25 +383,29 @@ def test_build_primer_pairs_fails_when_primers_on_wrong_reference( assert next(picks) is not None with pytest.raises(ValueError, match="Left primers exist on different reference"): - _picks = list(picking.build_primer_pairs( - left_primers=invalid_lefts, - right_primers=valid_rights, - target=target, - amplicon_sizes=MinOptMax(0, 100, 500), - amplicon_tms=MinOptMax(0, 80, 150), - max_heterodimer_tm=None, - weights=weights, - fasta_path=fasta, - )) + _picks = list( + picking.build_primer_pairs( + left_primers=invalid_lefts, + right_primers=valid_rights, + target=target, + amplicon_sizes=MinOptMax(0, 100, 500), + amplicon_tms=MinOptMax(0, 80, 150), + max_heterodimer_tm=None, + weights=weights, + fasta_path=fasta, + ) + ) with pytest.raises(ValueError, match="Right primers exist on different reference"): - _picks = list(picking.build_primer_pairs( - left_primers=valid_lefts, - right_primers=invalid_rights, - target=target, - amplicon_sizes=MinOptMax(0, 100, 500), - amplicon_tms=MinOptMax(0, 80, 150), - max_heterodimer_tm=None, - weights=weights, - fasta_path=fasta, - )) + _picks = list( + picking.build_primer_pairs( + left_primers=valid_lefts, + right_primers=invalid_rights, + target=target, + amplicon_sizes=MinOptMax(0, 100, 500), + amplicon_tms=MinOptMax(0, 80, 150), + max_heterodimer_tm=None, + weights=weights, + fasta_path=fasta, + ) + ) diff --git a/tests/api/test_variant_lookup.py b/tests/api/test_variant_lookup.py index a6eda6c..164ee16 100644 --- a/tests/api/test_variant_lookup.py +++ b/tests/api/test_variant_lookup.py @@ -4,7 +4,6 @@ from dataclasses import replace from pathlib import Path from typing import Optional -from typing import Type import fgpyo.vcf.builder import pytest @@ -17,7 +16,6 @@ from prymer.api.span import Strand from prymer.api.variant_lookup import FileBasedVariantLookup from prymer.api.variant_lookup import SimpleVariant -from prymer.api.variant_lookup import VariantLookup from prymer.api.variant_lookup import VariantOverlapDetector from prymer.api.variant_lookup import VariantType from prymer.api.variant_lookup import cached @@ -435,13 +433,24 @@ def test_simple_variant_conversion(vcf_path: Path, sample_vcf: list[VariantRecor assert actual_simple_variants == VALID_SIMPLE_VARIANTS_APPROX -@pytest.mark.parametrize("variant_lookup_class", [FileBasedVariantLookup, VariantOverlapDetector]) -def test_simple_variant_conversion_logs( - variant_lookup_class: Type[VariantLookup], vcf_path: Path, caplog: pytest.LogCaptureFixture +def test_simple_variant_conversion_logs_file_based( + vcf_path: Path, caplog: pytest.LogCaptureFixture ) -> None: """Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert.""" caplog.set_level(logging.DEBUG) - variant_lookup = variant_lookup_class( + with FileBasedVariantLookup( + vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False + ) as variant_lookup: + variant_lookup.query(refname="foo", start=1, end=2) + assert "No variants extracted from region of interest" in caplog.text + + +def test_simple_variant_conversion_logs_non_file_based( + vcf_path: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert.""" + caplog.set_level(logging.DEBUG) + variant_lookup = VariantOverlapDetector( vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False ) variant_lookup.query(refname="foo", start=1, end=2) @@ -451,7 +460,8 @@ def test_simple_variant_conversion_logs( def test_missing_index_file_raises(temp_missing_path: Path) -> None: """Test that both VariantLookup objects raise an error with a missing index file.""" with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"): - disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) + with disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False): + pass with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"): cached(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) @@ -459,7 +469,8 @@ def test_missing_index_file_raises(temp_missing_path: Path) -> None: def test_missing_vcf_files_raises() -> None: """Test that an error is raised when no VCF_paths are provided.""" with pytest.raises(ValueError, match="No VCF paths given to query"): - disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False) + with disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False): + pass with pytest.raises(ValueError, match="No VCF paths given to query"): cached(vcf_paths=[], min_maf=0.01, include_missing_mafs=False) @@ -480,12 +491,12 @@ def test_vcf_header_missing_chrom( caplog.set_level(logging.DEBUG) vcf_paths = [vcf_path, mini_chr1_vcf, mini_chr3_vcf] random.Random(random_seed).shuffle(vcf_paths) - variant_lookup = FileBasedVariantLookup( + with FileBasedVariantLookup( vcf_paths=vcf_paths, min_maf=0.00, include_missing_mafs=True - ) - variants_of_interest = variant_lookup.query( - refname="chr2", start=7999, end=9900 - ) # (chr2 only in vcf_path) + ) as variant_lookup: + variants_of_interest = variant_lookup.query( + refname="chr2", start=7999, end=9900 + ) # (chr2 only in vcf_path) # Should find all 12 variants from vcf_path (no filtering), with two variants having two # alternate alleles assert len(variants_of_interest) == 14 @@ -587,19 +598,19 @@ def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs: @pytest.mark.parametrize("include_missing_mafs", [False, True]) def test_file_based_variant_query(vcf_path: Path, include_missing_mafs: bool) -> None: """Test that `FileBasedVariantLookup.query()` MAF filtering is as expected.""" - file_based_vcf_query = FileBasedVariantLookup( + with FileBasedVariantLookup( vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs - ) - query = [ - _round_simple_variant(simple_variant) - for simple_variant in file_based_vcf_query.query( - refname="chr2", - start=8000, - end=9100, # while "common-mixed-2/2" starts at 9101, in the VCf is starts at 9100 - maf=0.05, - include_missing_mafs=include_missing_mafs, - ) - ] + ) as file_based_vcf_query: + query = [ + _round_simple_variant(simple_variant) + for simple_variant in file_based_vcf_query.query( + refname="chr2", + start=8000, + end=9100, # while "common-mixed-2/2" starts at 9101, in the VCf is starts at 9100 + maf=0.05, + include_missing_mafs=include_missing_mafs, + ) + ] if not include_missing_mafs: assert query == get_simple_variant_approx_by_id(