Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ensure FileBasedVariantLookup is used as a context manager #71

Merged
merged 10 commits into from
Nov 13, 2024
53 changes: 49 additions & 4 deletions prymer/api/variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@
import logging
from abc import ABC
from abc import abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass
from dataclasses import field
from enum import auto
from enum import unique
from pathlib import Path
from types import TracebackType
from typing import Optional
from typing import final

Expand All @@ -74,6 +76,7 @@
from pysam import VariantFile
from pysam import VariantRecord
from strenum import UppercaseStrEnum
from typing_extensions import override

from prymer.api.span import Span
from prymer.api.span import Strand
Expand Down Expand Up @@ -233,7 +236,7 @@ def build(simple_variant: SimpleVariant) -> "_VariantInterval":
)


class VariantLookup(ABC):
class VariantLookup(AbstractContextManager, ABC):
"""Base class to represent a variant from a given genomic range.

Attributes:
Expand Down Expand Up @@ -319,11 +322,36 @@ def to_variants(
def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
"""Subclasses must implement this method."""

def close(self) -> None:
    """Release any resources held by this variant lookup.

    The base implementation holds nothing and therefore does nothing; a subclass that owns
    resources overrides this method to release them.
    """
    return

def __exit__(
    self,
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
    traceback: Optional[TracebackType],
) -> None:
    """Leave the `with` block and then close this variant lookup.

    The exception details are first forwarded to the parent class's `__exit__`, after which
    `close()` is invoked. Nothing truthy is returned, so an exception raised inside the `with`
    block is not suppressed.

    Args:
        exc_type: the type of the exception raised in the `with` block, if any.
        exc_value: the exception instance raised in the `with` block, if any.
        traceback: the traceback of the exception raised in the `with` block, if any.
    """
    super().__exit__(exc_type, exc_value, traceback)
    self.close()

clintval marked this conversation as resolved.
Show resolved Hide resolved

class FileBasedVariantLookup(VariantLookup):
"""Implementation of VariantLookup that queries against indexed VCF files each time a query is
"""Implementation of `VariantLookup` that queries against indexed VCF files each time a query is
performed. Assumes the index is located adjacent to the VCF file and has the same base name with
either a .csi or .tbi suffix."""
either a .csi or .tbi suffix.

Example:

```python
>>> with FileBasedVariantLookup([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0, include_missing_mafs=False) as lookup:
... lookup.query(refname="chr2", start=7999, end=8000)
[SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='<DEL>', end=8000, variant_type=<VariantType.OTHER: 'OTHER'>, maf=None)]

```
""" # noqa: E501

def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_missing_mafs: bool):
self._readers: list[VariantFile] = []
Expand Down Expand Up @@ -353,6 +381,12 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
simple_variants.extend(self.to_variants(variants, source_vcf=path))
return sorted(simple_variants, key=lambda x: x.pos)

@override
def close(self) -> None:
    """Close the underlying VCF file handles.

    Every handle in `self._readers` is attempted even when closing one of them fails, so a
    single failure cannot leave the remaining files open.

    Raises:
        Exception: the first error raised while closing a handle, re-raised only after all of
            the handles have been attempted.
    """
    first_error: Optional[Exception] = None
    for handle in self._readers:
        try:
            handle.close()
        except Exception as error:
            # Remember the earliest failure but keep going: aborting here would leak every
            # handle that comes after this one.
            if first_error is None:
                first_error = error
    if first_error is not None:
        raise first_error
clintval marked this conversation as resolved.
Show resolved Hide resolved


class VariantOverlapDetector(VariantLookup):
"""Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting
Expand Down Expand Up @@ -443,7 +477,18 @@ def disk_based(
vcf_paths: list[Path], min_maf: float, include_missing_mafs: bool = False
) -> FileBasedVariantLookup:
"""Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup.
Appropriate for large VCFs."""

Appropriate for large VCFs.

Example:

```python
>>> with disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0) as lookup:
... lookup.query(refname="chr2", start=7999, end=8000)
[SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='<DEL>', end=8000, variant_type=<VariantType.OTHER: 'OTHER'>, maf=None)]

```
""" # noqa: E501
return FileBasedVariantLookup(
vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs
)
67 changes: 42 additions & 25 deletions tests/api/test_variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dataclasses import replace
from pathlib import Path
from typing import Optional
from typing import Type

import fgpyo.vcf.builder
import pytest
Expand All @@ -17,7 +16,6 @@
from prymer.api.span import Strand
from prymer.api.variant_lookup import FileBasedVariantLookup
from prymer.api.variant_lookup import SimpleVariant
from prymer.api.variant_lookup import VariantLookup
from prymer.api.variant_lookup import VariantOverlapDetector
from prymer.api.variant_lookup import VariantType
from prymer.api.variant_lookup import cached
Expand Down Expand Up @@ -435,13 +433,24 @@ def test_simple_variant_conversion(vcf_path: Path, sample_vcf: list[VariantRecor
assert actual_simple_variants == VALID_SIMPLE_VARIANTS_APPROX


@pytest.mark.parametrize("variant_lookup_class", [FileBasedVariantLookup, VariantOverlapDetector])
def test_simple_variant_conversion_logs(
variant_lookup_class: Type[VariantLookup], vcf_path: Path, caplog: pytest.LogCaptureFixture
def test_simple_variant_conversion_logs_file_based(
vcf_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert."""
caplog.set_level(logging.DEBUG)
variant_lookup = variant_lookup_class(
with FileBasedVariantLookup(
vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False
) as variant_lookup:
variant_lookup.query(refname="foo", start=1, end=2)
assert "No variants extracted from region of interest" in caplog.text


def test_simple_variant_conversion_logs_non_file_based(
vcf_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert."""
caplog.set_level(logging.DEBUG)
variant_lookup = VariantOverlapDetector(
vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False
)
variant_lookup.query(refname="foo", start=1, end=2)
Expand All @@ -451,15 +460,17 @@ def test_simple_variant_conversion_logs(
def test_missing_index_file_raises(temp_missing_path: Path) -> None:
"""Test that both VariantLookup objects raise an error with a missing index file."""
with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"):
disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)
with disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False):
pass
with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"):
cached(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)


def test_missing_vcf_files_raises() -> None:
"""Test that an error is raised when no VCF paths are provided."""
with pytest.raises(ValueError, match="No VCF paths given to query"):
disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)
with disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False):
pass
with pytest.raises(ValueError, match="No VCF paths given to query"):
cached(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)

Expand All @@ -480,12 +491,12 @@ def test_vcf_header_missing_chrom(
caplog.set_level(logging.DEBUG)
vcf_paths = [vcf_path, mini_chr1_vcf, mini_chr3_vcf]
random.Random(random_seed).shuffle(vcf_paths)
variant_lookup = FileBasedVariantLookup(
with FileBasedVariantLookup(
vcf_paths=vcf_paths, min_maf=0.00, include_missing_mafs=True
)
variants_of_interest = variant_lookup.query(
refname="chr2", start=7999, end=9900
) # (chr2 only in vcf_path)
) as variant_lookup:
variants_of_interest = variant_lookup.query(
refname="chr2", start=7999, end=9900
) # (chr2 only in vcf_path)
# Should find all 12 variants from vcf_path (no filtering), with two variants having two
# alternate alleles
assert len(variants_of_interest) == 14
Expand Down Expand Up @@ -527,6 +538,9 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None:
vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=True
)

# test that closing the variant overlap detector has no side effects
variant_overlap_detector.close()

# query for all variants
assert VALID_SIMPLE_VARIANTS_APPROX == variant_overlap_detector_query(
variant_overlap_detector, refname="chr2", start=8000, end=9101
Expand All @@ -552,6 +566,9 @@ def test_variant_overlap_detector_query(vcf_path: Path) -> None:
variant_overlap_detector, refname="chr2", start=8000, end=9000
) == get_simple_variant_approx_by_id("complex-variant-sv-1/1", "rare-dbsnp-snp1-1/1")

# test that closing the variant overlap detector has no side effects
variant_overlap_detector.close()


@pytest.mark.parametrize("include_missing_mafs", [False, True])
def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs: bool) -> None:
Expand Down Expand Up @@ -587,19 +604,19 @@ def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs:
@pytest.mark.parametrize("include_missing_mafs", [False, True])
def test_file_based_variant_query(vcf_path: Path, include_missing_mafs: bool) -> None:
"""Test that `FileBasedVariantLookup.query()` MAF filtering is as expected."""
file_based_vcf_query = FileBasedVariantLookup(
with FileBasedVariantLookup(
vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs
)
query = [
_round_simple_variant(simple_variant)
for simple_variant in file_based_vcf_query.query(
refname="chr2",
start=8000,
end=9100, # while "common-mixed-2/2" starts at 9101, in the VCf is starts at 9100
maf=0.05,
include_missing_mafs=include_missing_mafs,
)
]
) as file_based_vcf_query:
query = [
_round_simple_variant(simple_variant)
for simple_variant in file_based_vcf_query.query(
refname="chr2",
start=8000,
end=9100, # while "common-mixed-2/2" starts at 9101, in the VCF it starts at 9100
maf=0.05,
include_missing_mafs=include_missing_mafs,
)
]

if not include_missing_mafs:
assert query == get_simple_variant_approx_by_id(
Expand Down