diff --git a/prymer/api/variant_lookup.py b/prymer/api/variant_lookup.py index 891ed39..74794e5 100644 --- a/prymer/api/variant_lookup.py +++ b/prymer/api/variant_lookup.py @@ -353,6 +353,20 @@ def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_miss open_fh = pysam.VariantFile(str(path)) self._readers.append(open_fh) + def __enter__(self) -> "FileBasedVariantLookup": + """Enter the context manager.""" + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Exit this context manager while closing the underlying VCF handles.""" + self.close() + return None + def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]: """Queries variants from the VCFs used by this lookup and returns a `SimpleVariant`.""" simple_variants: list[SimpleVariant] = [] @@ -370,17 +384,6 @@ def close(self) -> None: for handle in self._readers: handle.close() - def __exit__( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - """Exit this context manager while closing the underlying VCF handles.""" - super().__exit__(exc_type, exc_value, traceback) - self.close() - return None - class VariantOverlapDetector(VariantLookup): """Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting