Skip to content

Commit

Permalink
code clean
Browse files Browse the repository at this point in the history
mainly phab. consolidated harmonize_variants and simplified the fasta_reader
  • Loading branch information
ACEnglish committed Jan 6, 2024
1 parent a54a08f commit 062162a
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ unsafe-load-any-extension=no
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code
extension-pkg-whitelist=biograph._capi,tabix,pysam,intervaltree,edlib,setproctitle
extension-pkg-whitelist=tabix,pysam,intervaltree,edlib,setproctitle,pyabpoa,pywfa


[MESSAGES CONTROL]
Expand Down
Binary file modified repo_utils/answer_key/phab/phab_result_poa.vcf.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion truvari/msatovcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def msa2vcf(msa, anchor_base='N'):
>>> msa_dir = "repo_utils/test_files/external/fake_mafft/lookup/"
>>> msa_file = "fm_ca43b50e2a5d770bb34202d8a7b62421.msa"
>>> seqs = open(msa_dir + msa_file).read()
>>> fasta = {n:s.decode() for n, s in fasta_reader(seqs, False)}
>>> fasta = dict(fasta_reader(seqs))
>>> m_entries_str = truvari.msa2vcf(fasta)
"""
ref_key = [_ for _ in msa.keys() if _.startswith("ref_")][0]
Expand Down
88 changes: 35 additions & 53 deletions truvari/phab.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
import logging
import argparse
import multiprocessing
from io import BytesIO
from io import BytesIO, StringIO
from functools import partial
from collections import defaultdict

import pysam
import pyabpoa
from pysam import samtools
from intervaltree import IntervalTree
from pywfa.align import WavefrontAligner # pylint: disable=no-name-in-module
from pywfa.align import WavefrontAligner
import truvari

DEFAULT_MAFFT_PARAM = "--auto --thread 1"
Expand Down Expand Up @@ -153,28 +153,21 @@ def make_haplotype_jobs(base_vcf, bSamples=None, comp_vcf=None, cSamples=None, p
return ret, samp_names


def fasta_reader(fa_str, name_entries=True):
def fasta_reader(fa_str):
"""
Parses a fasta file as a string and yields tuples of (location, entry)
if name_entries, the entry names are written to the value and location is honored
"""
cur_name = None
cur_entry = BytesIO()
cur_entry = StringIO()
for i in fa_str.split('\n'):
if not i.startswith(">"):
cur_entry.write(i.encode())
cur_entry.write(i)
continue
if cur_name is not None:
cur_entry.write(b'\n')
cur_entry.seek(0)
yield cur_name, cur_entry.read()
cur_name = i[1:]
cur_entry = BytesIO()
if name_entries:
cur_name = cur_name.split('_')[-1]
cur_entry.write((i + '\n').encode())

cur_entry.write(b'\n')
cur_entry = StringIO()
cur_entry.seek(0)
yield cur_name, cur_entry.read()

Expand Down Expand Up @@ -228,7 +221,7 @@ def wfa_to_vars(seq_bytes):
Align haplotypes independently with WFA
Much faster than mafft, but may be less accurate at finding parsimonious representations
"""
fasta = {k: v.decode() for k, v in fasta_reader(seq_bytes.decode(), False)}
fasta = dict(fasta_reader(seq_bytes.decode()))
ref_key = [_ for _ in fasta.keys() if _.startswith("ref_")][0]
reference = fasta[ref_key]

Expand Down Expand Up @@ -265,55 +258,24 @@ def mafft_to_vars(seq_bytes, params=DEFAULT_MAFFT_PARAM):
with open("repo_utils/test_files/external/fake_mafft/lookup/fm_" + dev_name + ".msa", 'w') as fout:
fout.write(ret.stdout)

fasta = {}
for name, sequence in fasta_reader(ret.stdout, name_entries=False):
fasta[name] = sequence.decode()
fasta = dict(fasta_reader(ret.stdout))
return truvari.msa2vcf(fasta)

def poa_to_vars(seq_bytes):
"""
Run partial order alignment to create msa
"""
parts = []
for k,v in fasta_reader(seq_bytes.decode(), False):
s = v.decode().strip()
parts.append((len(s), s, k))
parts.sort()
for k,v in fasta_reader(seq_bytes.decode()):
parts.append((len(v), v, k))
parts.sort(reverse=True)
_, seqs, names = zip(*parts)
aligner = pyabpoa.msa_aligner()
aln_result = aligner.msa(seqs, False, True)
return truvari.msa2vcf(dict(zip(names, aln_result.msa_seq)))


def harmonize_variants(harm_jobs, mafft_params, base_vcf, samp_names, output_fn, threads, method="mafft"):
"""
Parallel processing of variants to harmonize. Writes to output
"""
if method == "mafft":
align_method = partial(mafft_to_vars, params=mafft_params)
elif method == "poa":
align_method = poa_to_vars
else:
align_method = wfa_to_vars

with open(output_fn[:-len(".gz")], 'w') as fout:
fout.write(
'##fileformat=VCFv4.1\n##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n')
for ctg in pysam.VariantFile(base_vcf).header.contigs.values():
fout.write(str(ctg.header_record))
fout.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t")
fout.write("\t".join(samp_names) + "\n")

with multiprocessing.Pool(threads) as pool:
for result in pool.imap_unordered(align_method, harm_jobs):
fout.write(result)
pool.close()
pool.join()

truvari.compress_index_vcf(output_fn[:-len(".gz")], output_fn)


# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments, too-many-locals
# This is just how many arguments it takes
def phab(var_regions, base_vcf, ref_fn, output_fn, bSamples=None, buffer=100,
mafft_params=DEFAULT_MAFFT_PARAM, comp_vcf=None, cSamples=None,
Expand Down Expand Up @@ -357,9 +319,29 @@ def phab(var_regions, base_vcf, ref_fn, output_fn, bSamples=None, buffer=100,
haplotypes = collect_haplotypes(ref_haps_fn, hap_jobs, threads)

logging.info("Harmonizing variants")
harmonize_variants(haplotypes, mafft_params, base_vcf,
samp_names, output_fn, threads, method)
# pylint: enable=too-many-arguments
if method == "mafft":
align_method = partial(mafft_to_vars, params=mafft_params)
elif method == "wfa":
align_method = wfa_to_vars
else:
align_method = poa_to_vars

with open(output_fn[:-len(".gz")], 'w') as fout:
fout.write(('##fileformat=VCFv4.1\n'
'##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n'))
for ctg in pysam.VariantFile(base_vcf).header.contigs.values():
fout.write(str(ctg.header_record))
fout.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t")
fout.write("\t".join(samp_names) + "\n")

with multiprocessing.Pool(threads) as pool:
for result in pool.imap_unordered(align_method, haplotypes):
fout.write(result)
pool.close()
pool.join()

truvari.compress_index_vcf(output_fn[:-len(".gz")], output_fn)
# pylint: enable=too-many-arguments, too-many-locals

######
# UI #
Expand Down

0 comments on commit 062162a

Please sign in to comment.