Skip to content

Commit

Permalink
revert: revert the decision to put ContextManager in parent
Browse files Browse the repository at this point in the history
  • Loading branch information
clintval committed Nov 12, 2024
1 parent 421e2ed commit bd8313c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 46 deletions.
34 changes: 14 additions & 20 deletions prymer/api/variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@
import logging
from abc import ABC
from abc import abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass
from dataclasses import field
from enum import auto
from enum import unique
from pathlib import Path
from types import TracebackType
from typing import ContextManager
from typing import Optional
from typing import final

Expand All @@ -76,7 +76,6 @@
from pysam import VariantFile
from pysam import VariantRecord
from strenum import UppercaseStrEnum
from typing_extensions import override

from prymer.api.span import Span
from prymer.api.span import Strand
Expand Down Expand Up @@ -236,7 +235,7 @@ def build(simple_variant: SimpleVariant) -> "_VariantInterval":
)


class VariantLookup(AbstractContextManager, ABC):
class VariantLookup(ABC):
"""Base class to represent a variant from a given genomic range.
Attributes:
Expand Down Expand Up @@ -322,23 +321,8 @@ def to_variants(
def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
"""Subclasses must implement this method."""

def close(self) -> None:
"""Close this variant lookup."""
return None

def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Exit this context manager by closing the variant lookup."""
super().__exit__(exc_type, exc_value, traceback)
self.close()
return None


class FileBasedVariantLookup(VariantLookup):
class FileBasedVariantLookup(ContextManager, VariantLookup):
"""Implementation of `VariantLookup` that queries against indexed VCF files each time a query is
performed. Assumes the index is located adjacent to the VCF file and has the same base name with
either a .csi or .tbi suffix.
Expand Down Expand Up @@ -381,12 +365,22 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
simple_variants.extend(self.to_variants(variants, source_vcf=path))
return sorted(simple_variants, key=lambda x: x.pos)

@override
def close(self) -> None:
"""Close the underlying VCF file handles."""
for handle in self._readers:
handle.close()

def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Exit this context manager while closing the underlying VCF handles."""
super().__exit__(exc_type, exc_value, traceback)
self.close()
return None


class VariantOverlapDetector(VariantLookup):
"""Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting
Expand Down
44 changes: 24 additions & 20 deletions tests/api/test_picking.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,25 +383,29 @@ def test_build_primer_pairs_fails_when_primers_on_wrong_reference(
assert next(picks) is not None

with pytest.raises(ValueError, match="Left primers exist on different reference"):
_picks = list(picking.build_primer_pairs(
left_primers=invalid_lefts,
right_primers=valid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
))
_picks = list(
picking.build_primer_pairs(
left_primers=invalid_lefts,
right_primers=valid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
)
)

with pytest.raises(ValueError, match="Right primers exist on different reference"):
_picks = list(picking.build_primer_pairs(
left_primers=valid_lefts,
right_primers=invalid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
))
_picks = list(
picking.build_primer_pairs(
left_primers=valid_lefts,
right_primers=invalid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
)
)
6 changes: 0 additions & 6 deletions tests/api/test_variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,6 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None:
vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=True
)

# test that we can close the variant overlap detector and make no side effects
variant_overlap_detector.close()

# query for all variants
assert VALID_SIMPLE_VARIANTS_APPROX == variant_overlap_detector_query(
variant_overlap_detector, refname="chr2", start=8000, end=9101
Expand All @@ -566,9 +563,6 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None:
variant_overlap_detector, refname="chr2", start=8000, end=9000
) == get_simple_variant_approx_by_id("complex-variant-sv-1/1", "rare-dbsnp-snp1-1/1")

# test that we can close the variant overlap detector and make no side effects
variant_overlap_detector.close()


@pytest.mark.parametrize("include_missing_mafs", [False, True])
def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs: bool) -> None:
Expand Down

0 comments on commit bd8313c

Please sign in to comment.