From 00fa54b988710464150a8bed9d8196b51a6bbaa1 Mon Sep 17 00:00:00 2001 From: clintval Date: Thu, 10 Oct 2024 13:15:05 -0700 Subject: [PATCH 1/8] feat: ensure FileBasedVariantLookup is used as a context manager --- prymer/api/variant_lookup.py | 55 ++++++++++++++-- tests/api/test_variant_lookup.py | 106 +++++++++++++++++-------------- 2 files changed, 111 insertions(+), 50 deletions(-) diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index 82f5833..0e8980c 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -59,12 +59,15 @@ import logging from abc import ABC from abc import abstractmethod +from contextlib import AbstractContextManager from dataclasses import dataclass from dataclasses import field from enum import auto from enum import unique from pathlib import Path +from types import TracebackType from typing import Optional +from typing import Self from typing import final import pysam @@ -320,10 +323,20 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: """Subclasses must implement this method.""" -class FileBasedVariantLookup(VariantLookup): - """Implementation of VariantLookup that queries against indexed VCF files each time a query is +class FileBasedVariantLookup(VariantLookup, AbstractContextManager): + """Implementation of `VariantLookup` that queries against indexed VCF files each time a query is performed. Assumes the index is located adjacent to the VCF file and has the same base name with - either a .csi or .tbi suffix.""" + either a .csi or .tbi suffix. + + Example: + + ```python + >>> with FileBasedVariantLookup([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0, include_missing_mafs=False) as lookup: + ... lookup.query(refname="chr2", start=7999, end=8000) + [SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='', end=8000, variant_type=, maf=None)] + + ``` + """ # noqa: E501 def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_missing_mafs: bool): self._readers: list[VariantFile] = [] @@ -353,6 +366,26 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: simple_variants.extend(self.to_variants(variants, source_vcf=path)) return sorted(simple_variants, key=lambda x: x.pos) + def __enter__(self) -> Self: + """Enter this context manager.""" + super().__enter__() + return self + + def close(self) -> None: + """Close the underlying VCF file handles.""" + for handle in self._readers: + handle.close() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Exit this context manager and close all underlying VCF handles.""" + super().__exit__(exc_type, exc_value, traceback) + self.close() + class VariantOverlapDetector(VariantLookup): """Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting @@ -443,7 +476,21 @@ def disk_based( vcf_paths: list[Path], min_maf: float, include_missing_mafs: bool = False ) -> FileBasedVariantLookup: """Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup. - Appropriate for large VCFs.""" + + Appropriate for large VCFs. Ensure that you take advantage of [`contextlib.closing`](https://docs.python.org/3/library/contextlib.html#contextlib.closing) + for automatically closing the file-base variant lookup after it is used. See below for an + example. + + Example: + + ```python + >>> from contextlib import closing + >>> with closing(disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0)) as lookup: + ... lookup.query(refname="chr2", start=7999, end=8000) + [SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='', end=8000, variant_type=, maf=None)] + + ``` + """ # noqa: E501 return FileBasedVariantLookup( vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs ) diff --git a/tests/api/test_variant_lookup.py b/tests/api/test_variant_lookup.py index a6eda6c..298f3b4 100644 --- a/tests/api/test_variant_lookup.py +++ b/tests/api/test_variant_lookup.py @@ -1,10 +1,10 @@ import logging import random +from contextlib import closing from dataclasses import dataclass from dataclasses import replace from pathlib import Path from typing import Optional -from typing import Type import fgpyo.vcf.builder import pytest @@ -17,7 +17,6 @@ from prymer.api.span import Strand from prymer.api.variant_lookup import FileBasedVariantLookup from prymer.api.variant_lookup import SimpleVariant -from prymer.api.variant_lookup import VariantLookup from prymer.api.variant_lookup import VariantOverlapDetector from prymer.api.variant_lookup import VariantType from prymer.api.variant_lookup import cached @@ -435,13 +434,24 @@ def test_simple_variant_conversion(vcf_path: Path, sample_vcf: list[VariantRecor assert actual_simple_variants == VALID_SIMPLE_VARIANTS_APPROX -@pytest.mark.parametrize("variant_lookup_class", [FileBasedVariantLookup, VariantOverlapDetector]) -def test_simple_variant_conversion_logs( - variant_lookup_class: Type[VariantLookup], vcf_path: Path, caplog: pytest.LogCaptureFixture +def test_simple_variant_conversion_logs_file_based( + vcf_path: Path, caplog: pytest.LogCaptureFixture ) -> None: """Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert.""" caplog.set_level(logging.DEBUG) - variant_lookup = variant_lookup_class( + with FileBasedVariantLookup( + vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False + ) as variant_lookup: + variant_lookup.query(refname="foo", start=1, end=2) + assert "No variants extracted from region of interest" in caplog.text + + +def test_simple_variant_conversion_logs_non_file_based( + vcf_path: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert.""" + caplog.set_level(logging.DEBUG) + variant_lookup = VariantOverlapDetector( vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False ) variant_lookup.query(refname="foo", start=1, end=2) @@ -451,7 +461,10 @@ def test_simple_variant_conversion_logs( def test_missing_index_file_raises(temp_missing_path: Path) -> None: """Test that both VariantLookup objects raise an error with a missing index file.""" with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"): - disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) + with closing( + disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) + ): + pass with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"): cached(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) @@ -459,7 +472,8 @@ def test_missing_index_file_raises(temp_missing_path: Path) -> None: def test_missing_vcf_files_raises() -> None: """Test that an error is raised when no VCF_paths are provided.""" with pytest.raises(ValueError, match="No VCF paths given to query"): - disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False) + with closing(disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)): + pass with pytest.raises(ValueError, match="No VCF paths given to query"): cached(vcf_paths=[], min_maf=0.01, include_missing_mafs=False) @@ -480,17 +494,17 @@ def test_vcf_header_missing_chrom( caplog.set_level(logging.DEBUG) vcf_paths = [vcf_path, mini_chr1_vcf, mini_chr3_vcf] random.Random(random_seed).shuffle(vcf_paths) - variant_lookup = FileBasedVariantLookup( + with FileBasedVariantLookup( vcf_paths=vcf_paths, min_maf=0.00, include_missing_mafs=True - ) - variants_of_interest = variant_lookup.query( - refname="chr2", start=7999, end=9900 - ) # (chr2 only in vcf_path) - # Should find all 12 variants from vcf_path (no filtering), with two variants having two - # alternate alleles - assert len(variants_of_interest) == 14 - expected_error_msg = "does not contain chromosome" - assert expected_error_msg in caplog.text + ) as variant_lookup: + variants_of_interest = variant_lookup.query( + refname="chr2", start=7999, end=9900 + ) # (chr2 only in vcf_path) + # Should find all 12 variants from vcf_path (no filtering), with two variants having two + # alternate alleles + assert len(variants_of_interest) == 14 + expected_error_msg = "does not contain chromosome" + assert expected_error_msg in caplog.text @pytest.mark.parametrize("test_case", VALID_SIMPLE_VARIANT_TEST_CASES) @@ -587,32 +601,32 @@ def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs: @pytest.mark.parametrize("include_missing_mafs", [False, True]) def test_file_based_variant_query(vcf_path: Path, include_missing_mafs: bool) -> None: """Test that `FileBasedVariantLookup.query()` MAF filtering is as expected.""" - file_based_vcf_query = FileBasedVariantLookup( + with FileBasedVariantLookup( vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs - ) - query = [ - _round_simple_variant(simple_variant) - for simple_variant in file_based_vcf_query.query( - refname="chr2", - start=8000, - end=9100, # while "common-mixed-2/2" starts at 9101, in the VCf is starts at 9100 - maf=0.05, - include_missing_mafs=include_missing_mafs, - ) - ] - - if not include_missing_mafs: - assert query == get_simple_variant_approx_by_id( - "common-multiallelic-1/2", - "common-multiallelic-2/2", - "common-mixed-1/2", - "common-mixed-2/2", - ) - else: - assert query == get_simple_variant_approx_by_id( - "complex-variant-sv-1/1", - "common-multiallelic-1/2", - "common-multiallelic-2/2", - "common-mixed-1/2", - "common-mixed-2/2", - ) + ) as file_based_vcf_query: + query = [ + _round_simple_variant(simple_variant) + for simple_variant in file_based_vcf_query.query( + refname="chr2", + start=8000, + end=9100, # while "common-mixed-2/2" starts at 9101, in the VCf is starts at 9100 + maf=0.05, + include_missing_mafs=include_missing_mafs, + ) + ] + + if not include_missing_mafs: + assert query == get_simple_variant_approx_by_id( + "common-multiallelic-1/2", + "common-multiallelic-2/2", + "common-mixed-1/2", + "common-mixed-2/2", + ) + else: + assert query == get_simple_variant_approx_by_id( + "complex-variant-sv-1/1", + "common-multiallelic-1/2", + "common-multiallelic-2/2", + "common-mixed-1/2", + "common-mixed-2/2", + ) From 47ceaa5a2128f62e45a6c5b63bc49f80dc26dfde Mon Sep 17 00:00:00 2001 From: clintval Date: Tue, 15 Oct 2024 08:10:56 -0700 Subject: [PATCH 2/8] chore: address review comments --- prymer/api/variant_lookup.py | 2 +- tests/api/test_variant_lookup.py | 40 ++++++++++++++++---------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index 0e8980c..2bc080a 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -478,7 +478,7 @@ def disk_based( """Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup. Appropriate for large VCFs. Ensure that you take advantage of [`contextlib.closing`](https://docs.python.org/3/library/contextlib.html#contextlib.closing) - for automatically closing the file-base variant lookup after it is used. See below for an + for automatically closing the file-based variant lookup after it is used. See below for an example. Example: diff --git a/tests/api/test_variant_lookup.py b/tests/api/test_variant_lookup.py index 298f3b4..14bfa43 100644 --- a/tests/api/test_variant_lookup.py +++ b/tests/api/test_variant_lookup.py @@ -500,11 +500,11 @@ def test_vcf_header_missing_chrom( variants_of_interest = variant_lookup.query( refname="chr2", start=7999, end=9900 ) # (chr2 only in vcf_path) - # Should find all 12 variants from vcf_path (no filtering), with two variants having two - # alternate alleles - assert len(variants_of_interest) == 14 - expected_error_msg = "does not contain chromosome" - assert expected_error_msg in caplog.text + # Should find all 12 variants from vcf_path (no filtering), with two variants having two + # alternate alleles + assert len(variants_of_interest) == 14 + expected_error_msg = "does not contain chromosome" + assert expected_error_msg in caplog.text @pytest.mark.parametrize("test_case", VALID_SIMPLE_VARIANT_TEST_CASES) @@ -615,18 +615,18 @@ def test_file_based_variant_query(vcf_path: Path, include_missing_mafs: bool) -> ) ] - if not include_missing_mafs: - assert query == get_simple_variant_approx_by_id( - "common-multiallelic-1/2", - "common-multiallelic-2/2", - "common-mixed-1/2", - "common-mixed-2/2", - ) - else: - assert query == get_simple_variant_approx_by_id( - "complex-variant-sv-1/1", - "common-multiallelic-1/2", - "common-multiallelic-2/2", - "common-mixed-1/2", - "common-mixed-2/2", - ) + if not include_missing_mafs: + assert query == get_simple_variant_approx_by_id( + "common-multiallelic-1/2", + "common-multiallelic-2/2", + "common-mixed-1/2", + "common-mixed-2/2", + ) + else: + assert query == get_simple_variant_approx_by_id( + "complex-variant-sv-1/1", + "common-multiallelic-1/2", + "common-multiallelic-2/2", + "common-mixed-1/2", + "common-mixed-2/2", + ) From ae2ee2300a939eb52f83f7c5d88412ffa370bc54 Mon Sep 17 00:00:00 2001 From: clintval Date: Tue, 15 Oct 2024 08:12:36 -0700 Subject: [PATCH 3/8] fix: remove now-failing Mambaforge --- .github/workflows/tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e53bd31..8e0eca3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,7 +35,6 @@ jobs: - name: Set up miniconda uses: conda-incubator/setup-miniconda@v3 with: - miniforge-variant: Mambaforge miniforge-version: latest channels: conda-forge,bioconda activate-environment: prymer From 680ff17507c95f7703dddf05423e171d5d8d8cea Mon Sep 17 00:00:00 2001 From: clintval Date: Tue, 15 Oct 2024 16:37:04 -0700 Subject: [PATCH 4/8] chore: appease the coderabbit --- prymer/api/variant_lookup.py | 14 ++------------ tests/api/test_variant_lookup.py | 7 ++----- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index 2bc080a..e7b151e 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -67,7 +67,6 @@ from pathlib import Path from types import TracebackType from typing import Optional -from typing import Self from typing import final import pysam @@ -366,11 +365,6 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: simple_variants.extend(self.to_variants(variants, source_vcf=path)) return sorted(simple_variants, key=lambda x: x.pos) - def __enter__(self) -> Self: - """Enter this context manager.""" - super().__enter__() - return self - def close(self) -> None: """Close the underlying VCF file handles.""" for handle in self._readers: @@ -383,7 +377,6 @@ def __exit__( traceback: Optional[TracebackType], ) -> None: """Exit this context manager and close all underlying VCF handles.""" - super().__exit__(exc_type, exc_value, traceback) self.close() @@ -477,15 +470,12 @@ def disk_based( ) -> FileBasedVariantLookup: """Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup. - Appropriate for large VCFs. Ensure that you take advantage of [`contextlib.closing`](https://docs.python.org/3/library/contextlib.html#contextlib.closing) - for automatically closing the file-based variant lookup after it is used. See below for an - example. + Appropriate for large VCFs. Example: ```python - >>> from contextlib import closing - >>> with closing(disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0)) as lookup: + >>> with disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0) as lookup: ... lookup.query(refname="chr2", start=7999, end=8000) [SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='', end=8000, variant_type=, maf=None)] diff --git a/tests/api/test_variant_lookup.py b/tests/api/test_variant_lookup.py index 14bfa43..164ee16 100644 --- a/tests/api/test_variant_lookup.py +++ b/tests/api/test_variant_lookup.py @@ -1,6 +1,5 @@ import logging import random -from contextlib import closing from dataclasses import dataclass from dataclasses import replace from pathlib import Path @@ -461,9 +460,7 @@ def test_simple_variant_conversion_logs_non_file_based( def test_missing_index_file_raises(temp_missing_path: Path) -> None: """Test that both VariantLookup objects raise an error with a missing index file.""" with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"): - with closing( - disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) - ): + with disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False): pass with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"): cached(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False) @@ -472,7 +469,7 @@ def test_missing_index_file_raises(temp_missing_path: Path) -> None: def test_missing_vcf_files_raises() -> None: """Test that an error is raised when no VCF_paths are provided.""" with pytest.raises(ValueError, match="No VCF paths given to query"): - with closing(disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)): + with disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False): pass with pytest.raises(ValueError, match="No VCF paths given to query"): cached(vcf_paths=[], min_maf=0.01, include_missing_mafs=False) From 3e6ad5ad1422b8e971951273fff6198423ecf000 Mon Sep 17 00:00:00 2001 From: clintval Date: Wed, 16 Oct 2024 10:09:20 -0700 Subject: [PATCH 5/8] feat: make VariantLookup the context manager instead --- prymer/api/variant_lookup.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index e7b151e..965b83b 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -76,6 +76,7 @@ from pysam import VariantFile from pysam import VariantRecord from strenum import UppercaseStrEnum +from typing_extensions import override from prymer.api.span import Span from prymer.api.span import Strand @@ -235,7 +236,7 @@ def build(simple_variant: SimpleVariant) -> "_VariantInterval": ) -class VariantLookup(ABC): +class VariantLookup(AbstractContextManager, ABC): """Base class to represent a variant from a given genomic range. Attributes: @@ -321,8 +322,23 @@ def to_variants( def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: """Subclasses must implement this method.""" + def close(self) -> None: + """Close this variant lookup.""" + return None + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Exit this context manager by closing the variant lookup.""" + super().__exit__(exc_type, exc_value, traceback) + self.close() + return None -class FileBasedVariantLookup(VariantLookup, AbstractContextManager): + +class FileBasedVariantLookup(VariantLookup): """Implementation of `VariantLookup` that queries against indexed VCF files each time a query is performed. Assumes the index is located adjacent to the VCF file and has the same base name with either a .csi or .tbi suffix. @@ -365,20 +381,12 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: simple_variants.extend(self.to_variants(variants, source_vcf=path)) return sorted(simple_variants, key=lambda x: x.pos) + @override def close(self) -> None: """Close the underlying VCF file handles.""" for handle in self._readers: handle.close() - def __exit__( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - """Exit this context manager and close all underlying VCF handles.""" - self.close() - class VariantOverlapDetector(VariantLookup): """Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting From 909a4b5698a210daa8a87ea2b341ef406cf66121 Mon Sep 17 00:00:00 2001 From: clintval Date: Wed, 16 Oct 2024 10:17:20 -0700 Subject: [PATCH 6/8] chore: try to increase test coverage over new method --- tests/api/test_variant_lookup.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/api/test_variant_lookup.py b/tests/api/test_variant_lookup.py index 164ee16..3b20a45 100644 --- a/tests/api/test_variant_lookup.py +++ b/tests/api/test_variant_lookup.py @@ -538,6 +538,9 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None: vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=True ) + # test that we can close the variant overlap detector and make no side effects + variant_overlap_detector.close() + # query for all variants assert VALID_SIMPLE_VARIANTS_APPROX == variant_overlap_detector_query( variant_overlap_detector, refname="chr2", start=8000, end=9101 @@ -563,6 +566,9 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None: variant_overlap_detector, refname="chr2", start=8000, end=9000 ) == get_simple_variant_approx_by_id("complex-variant-sv-1/1", "rare-dbsnp-snp1-1/1") + # test that we can close the variant overlap detector and make no side effects + variant_overlap_detector.close() + @pytest.mark.parametrize("include_missing_mafs", [False, True]) def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs: bool) -> None: From bd8313c2614ba171dff63cc88889341ceb8b5ccc Mon Sep 17 00:00:00 2001 From: clintval Date: Tue, 12 Nov 2024 12:31:31 -0800 Subject: [PATCH 7/8] revert: revert the decision to put ContextManager in parent --- prymer/api/variant_lookup.py | 34 ++++++++++-------------- tests/api/test_picking.py | 44 +++++++++++++++++--------------- tests/api/test_variant_lookup.py | 6 ----- 3 files changed, 38 insertions(+), 46 deletions(-) diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index 965b83b..891ed39 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -59,13 +59,13 @@ import logging from abc import ABC from abc import abstractmethod -from contextlib import AbstractContextManager from dataclasses import dataclass from dataclasses import field from enum import auto from enum import unique from pathlib import Path from types import TracebackType +from typing import ContextManager from typing import Optional from typing import final @@ -76,7 +76,6 @@ from pysam import VariantFile from pysam import VariantRecord from strenum import UppercaseStrEnum -from typing_extensions import override from prymer.api.span import Span from prymer.api.span import Strand @@ -236,7 +235,7 @@ def build(simple_variant: SimpleVariant) -> "_VariantInterval": ) -class VariantLookup(AbstractContextManager, ABC): +class VariantLookup(ABC): """Base class to represent a variant from a given genomic range. Attributes: @@ -322,23 +321,8 @@ def to_variants( def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: """Subclasses must implement this method.""" - def close(self) -> None: - """Close this variant lookup.""" - return None - - def __exit__( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - """Exit this context manager by closing the variant lookup.""" - super().__exit__(exc_type, exc_value, traceback) - self.close() - return None - -class FileBasedVariantLookup(VariantLookup): +class FileBasedVariantLookup(ContextManager, VariantLookup): """Implementation of `VariantLookup` that queries against indexed VCF files each time a query is performed. Assumes the index is located adjacent to the VCF file and has the same base name with either a .csi or .tbi suffix. @@ -381,12 +365,22 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: simple_variants.extend(self.to_variants(variants, source_vcf=path)) return sorted(simple_variants, key=lambda x: x.pos) - @override def close(self) -> None: """Close the underlying VCF file handles.""" for handle in self._readers: handle.close() + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Exit this context manager while closing the underlying VCF handles.""" + super().__exit__(exc_type, exc_value, traceback) + self.close() + return None + class VariantOverlapDetector(VariantLookup): """Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting diff --git a/tests/api/test_picking.py b/tests/api/test_picking.py index 164a5ed..cdb5da0 100644 --- a/tests/api/test_picking.py +++ b/tests/api/test_picking.py @@ -383,25 +383,29 @@ def test_build_primer_pairs_fails_when_primers_on_wrong_reference( assert next(picks) is not None with pytest.raises(ValueError, match="Left primers exist on different reference"): - _picks = list(picking.build_primer_pairs( - left_primers=invalid_lefts, - right_primers=valid_rights, - target=target, - amplicon_sizes=MinOptMax(0, 100, 500), - amplicon_tms=MinOptMax(0, 80, 150), - max_heterodimer_tm=None, - weights=weights, - fasta_path=fasta, - )) + _picks = list( + picking.build_primer_pairs( + left_primers=invalid_lefts, + right_primers=valid_rights, + target=target, + amplicon_sizes=MinOptMax(0, 100, 500), + amplicon_tms=MinOptMax(0, 80, 150), + max_heterodimer_tm=None, + weights=weights, + fasta_path=fasta, + ) + ) with pytest.raises(ValueError, match="Right primers exist on different reference"): - _picks = list(picking.build_primer_pairs( - left_primers=valid_lefts, - right_primers=invalid_rights, - target=target, - amplicon_sizes=MinOptMax(0, 100, 500), - amplicon_tms=MinOptMax(0, 80, 150), - max_heterodimer_tm=None, - weights=weights, - fasta_path=fasta, - )) + _picks = list( + picking.build_primer_pairs( + left_primers=valid_lefts, + right_primers=invalid_rights, + target=target, + amplicon_sizes=MinOptMax(0, 100, 500), + amplicon_tms=MinOptMax(0, 80, 150), + max_heterodimer_tm=None, + weights=weights, + fasta_path=fasta, + ) + ) diff --git a/tests/api/test_variant_lookup.py b/tests/api/test_variant_lookup.py index 3b20a45..164ee16 100644 --- a/tests/api/test_variant_lookup.py +++ b/tests/api/test_variant_lookup.py @@ -538,9 +538,6 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None: vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=True ) - # test that we can close the variant overlap detector and make no side effects - variant_overlap_detector.close() - # query for all variants assert VALID_SIMPLE_VARIANTS_APPROX == variant_overlap_detector_query( variant_overlap_detector, refname="chr2", start=8000, end=9101 @@ -566,9 +563,6 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None: variant_overlap_detector, refname="chr2", start=8000, end=9000 ) == get_simple_variant_approx_by_id("complex-variant-sv-1/1", "rare-dbsnp-snp1-1/1") - # test that we can close the variant overlap detector and make no side effects - variant_overlap_detector.close() - @pytest.mark.parametrize("include_missing_mafs", [False, True]) def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs: bool) -> None: From b82e455533360a75b372a4566c59f2e064a2e6c0 Mon Sep 17 00:00:00 2001 From: clintval Date: Tue, 12 Nov 2024 12:44:29 -0800 Subject: [PATCH 8/8] chore: fixup a few review issues --- prymer/api/variant_lookup.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index 891ed39..74794e5 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -353,6 +353,20 @@ def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_miss open_fh = pysam.VariantFile(str(path)) self._readers.append(open_fh) + def __enter__(self) -> "FileBasedVariantLookup": + """Enter the context manager.""" + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Exit this context manager while closing the underlying VCF handles.""" + self.close() + return None + def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: """Queries variants from the VCFs used by this lookup and returns a `SimpleVariant`.""" simple_variants: list[SimpleVariant] = [] @@ -370,17 +384,6 @@ def close(self) -> None: for handle in self._readers: handle.close() - def __exit__( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - """Exit this context manager while closing the underlying VCF handles.""" - super().__exit__(exc_type, exc_value, traceback) - self.close() - return None - class VariantOverlapDetector(VariantLookup): """Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting