Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ensure FileBasedVariantLookup is used as a context manager #71

Merged
merged 10 commits into from
Nov 13, 2024
50 changes: 46 additions & 4 deletions prymer/api/variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
from enum import auto
from enum import unique
from pathlib import Path
from types import TracebackType
from typing import ContextManager
from typing import Optional
from typing import final

Expand Down Expand Up @@ -320,10 +322,20 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
"""Subclasses must implement this method."""


class FileBasedVariantLookup(VariantLookup):
"""Implementation of VariantLookup that queries against indexed VCF files each time a query is
class FileBasedVariantLookup(ContextManager, VariantLookup):
clintval marked this conversation as resolved.
Show resolved Hide resolved
"""Implementation of `VariantLookup` that queries against indexed VCF files each time a query is
performed. Assumes the index is located adjacent to the VCF file and has the same base name with
either a .csi or .tbi suffix."""
either a .csi or .tbi suffix.

Example:

```python
>>> with FileBasedVariantLookup([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0, include_missing_mafs=False) as lookup:
... lookup.query(refname="chr2", start=7999, end=8000)
[SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='<DEL>', end=8000, variant_type=<VariantType.OTHER: 'OTHER'>, maf=None)]

```
""" # noqa: E501

def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_missing_mafs: bool):
self._readers: list[VariantFile] = []
Expand All @@ -341,6 +353,20 @@ def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_miss
open_fh = pysam.VariantFile(str(path))
self._readers.append(open_fh)

def __enter__(self) -> "FileBasedVariantLookup":
    """Begin a `with` block and bind this lookup as the target of `as`.

    No additional work is performed on entry.

    Returns:
        This same lookup instance.
    """
    return self

def __exit__(
    self,
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
    traceback: Optional[TracebackType],
) -> None:
    """Leave the `with` block, closing the underlying VCF handles via `close()`.

    Args:
        exc_type: The class of the exception raised in the `with` body, or `None`.
        exc_value: The exception instance raised in the `with` body, or `None`.
        traceback: The traceback associated with that exception, or `None`.

    Returns:
        Always `None`, so an exception raised in the `with` body is not suppressed.
    """
    # The exception details are intentionally unused: cleanup is unconditional.
    self.close()

def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
"""Queries the VCFs used by this lookup and returns the `SimpleVariant`s sorted by position."""
simple_variants: list[SimpleVariant] = []
Expand All @@ -353,6 +379,11 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
simple_variants.extend(self.to_variants(variants, source_vcf=path))
return sorted(simple_variants, key=lambda x: x.pos)

def close(self) -> None:
    """Release every VCF reader held by this lookup.

    Calls `close()` on each handle stored in `self._readers`, in list order.
    """
    for reader in self._readers:
        reader.close()
clintval marked this conversation as resolved.
Show resolved Hide resolved


class VariantOverlapDetector(VariantLookup):
"""Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting
Expand Down Expand Up @@ -443,7 +474,18 @@ def disk_based(
vcf_paths: list[Path], min_maf: float, include_missing_mafs: bool = False
) -> FileBasedVariantLookup:
"""Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup.
Appropriate for large VCFs."""

Appropriate for large VCFs.

Example:

```python
>>> with disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0) as lookup:
... lookup.query(refname="chr2", start=7999, end=8000)
[SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='<DEL>', end=8000, variant_type=<VariantType.OTHER: 'OTHER'>, maf=None)]

```
""" # noqa: E501
return FileBasedVariantLookup(
vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs
)
44 changes: 24 additions & 20 deletions tests/api/test_picking.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,25 +383,29 @@ def test_build_primer_pairs_fails_when_primers_on_wrong_reference(
assert next(picks) is not None

with pytest.raises(ValueError, match="Left primers exist on different reference"):
_picks = list(picking.build_primer_pairs(
left_primers=invalid_lefts,
right_primers=valid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
))
_picks = list(
picking.build_primer_pairs(
left_primers=invalid_lefts,
right_primers=valid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
)
)

with pytest.raises(ValueError, match="Right primers exist on different reference"):
_picks = list(picking.build_primer_pairs(
left_primers=valid_lefts,
right_primers=invalid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
))
_picks = list(
picking.build_primer_pairs(
left_primers=valid_lefts,
right_primers=invalid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
)
)
61 changes: 36 additions & 25 deletions tests/api/test_variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dataclasses import replace
from pathlib import Path
from typing import Optional
from typing import Type

import fgpyo.vcf.builder
import pytest
Expand All @@ -17,7 +16,6 @@
from prymer.api.span import Strand
from prymer.api.variant_lookup import FileBasedVariantLookup
from prymer.api.variant_lookup import SimpleVariant
from prymer.api.variant_lookup import VariantLookup
from prymer.api.variant_lookup import VariantOverlapDetector
from prymer.api.variant_lookup import VariantType
from prymer.api.variant_lookup import cached
Expand Down Expand Up @@ -435,13 +433,24 @@ def test_simple_variant_conversion(vcf_path: Path, sample_vcf: list[VariantRecor
assert actual_simple_variants == VALID_SIMPLE_VARIANTS_APPROX


@pytest.mark.parametrize("variant_lookup_class", [FileBasedVariantLookup, VariantOverlapDetector])
def test_simple_variant_conversion_logs(
variant_lookup_class: Type[VariantLookup], vcf_path: Path, caplog: pytest.LogCaptureFixture
def test_simple_variant_conversion_logs_file_based(
vcf_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert."""
caplog.set_level(logging.DEBUG)
variant_lookup = variant_lookup_class(
with FileBasedVariantLookup(
vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False
) as variant_lookup:
variant_lookup.query(refname="foo", start=1, end=2)
assert "No variants extracted from region of interest" in caplog.text


def test_simple_variant_conversion_logs_non_file_based(
vcf_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert."""
caplog.set_level(logging.DEBUG)
variant_lookup = VariantOverlapDetector(
vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False
)
variant_lookup.query(refname="foo", start=1, end=2)
Expand All @@ -451,15 +460,17 @@ def test_simple_variant_conversion_logs(
def test_missing_index_file_raises(temp_missing_path: Path) -> None:
"""Test that both VariantLookup objects raise an error with a missing index file."""
with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"):
disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)
with disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False):
pass
with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"):
cached(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)


def test_missing_vcf_files_raises() -> None:
"""Test that an error is raised when no VCF paths are provided."""
with pytest.raises(ValueError, match="No VCF paths given to query"):
disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)
with disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False):
pass
with pytest.raises(ValueError, match="No VCF paths given to query"):
cached(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)

Expand All @@ -480,12 +491,12 @@ def test_vcf_header_missing_chrom(
caplog.set_level(logging.DEBUG)
vcf_paths = [vcf_path, mini_chr1_vcf, mini_chr3_vcf]
random.Random(random_seed).shuffle(vcf_paths)
variant_lookup = FileBasedVariantLookup(
with FileBasedVariantLookup(
vcf_paths=vcf_paths, min_maf=0.00, include_missing_mafs=True
)
variants_of_interest = variant_lookup.query(
refname="chr2", start=7999, end=9900
) # (chr2 only in vcf_path)
) as variant_lookup:
variants_of_interest = variant_lookup.query(
refname="chr2", start=7999, end=9900
) # (chr2 only in vcf_path)
# Should find all 12 variants from vcf_path (no filtering), with two variants having two
# alternate alleles
assert len(variants_of_interest) == 14
Expand Down Expand Up @@ -587,19 +598,19 @@ def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs:
@pytest.mark.parametrize("include_missing_mafs", [False, True])
def test_file_based_variant_query(vcf_path: Path, include_missing_mafs: bool) -> None:
"""Test that `FileBasedVariantLookup.query()` MAF filtering is as expected."""
file_based_vcf_query = FileBasedVariantLookup(
with FileBasedVariantLookup(
vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs
)
query = [
_round_simple_variant(simple_variant)
for simple_variant in file_based_vcf_query.query(
refname="chr2",
start=8000,
end=9100, # while "common-mixed-2/2" starts at 9101, in the VCf is starts at 9100
maf=0.05,
include_missing_mafs=include_missing_mafs,
)
]
) as file_based_vcf_query:
query = [
_round_simple_variant(simple_variant)
for simple_variant in file_based_vcf_query.query(
refname="chr2",
start=8000,
end=9100, # while "common-mixed-2/2" starts at 9101, in the VCF it starts at 9100
maf=0.05,
include_missing_mafs=include_missing_mafs,
)
]

if not include_missing_mafs:
assert query == get_simple_variant_approx_by_id(
Expand Down
Loading