diff --git a/.pylintrc b/.pylintrc index 4f16f2ec..b3f7a572 100644 --- a/.pylintrc +++ b/.pylintrc @@ -28,7 +28,7 @@ unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code -extension-pkg-whitelist=biograph._capi,tabix,pysam,intervaltree,edlib,setproctitle +extension-pkg-whitelist=tabix,pysam,intervaltree,edlib,setproctitle,pyabpoa,pywfa [MESSAGES CONTROL] diff --git a/repo_utils/answer_key/phab/phab_result_poa.vcf.gz b/repo_utils/answer_key/phab/phab_result_poa.vcf.gz index 588a0548..0a25a9a6 100644 Binary files a/repo_utils/answer_key/phab/phab_result_poa.vcf.gz and b/repo_utils/answer_key/phab/phab_result_poa.vcf.gz differ diff --git a/truvari/msatovcf.py b/truvari/msatovcf.py index 83c4649f..6224d9fc 100644 --- a/truvari/msatovcf.py +++ b/truvari/msatovcf.py @@ -144,7 +144,7 @@ def msa2vcf(msa, anchor_base='N'): >>> msa_dir = "repo_utils/test_files/external/fake_mafft/lookup/" >>> msa_file = "fm_ca43b50e2a5d770bb34202d8a7b62421.msa" >>> seqs = open(msa_dir + msa_file).read() - >>> fasta = {n:s.decode() for n, s in fasta_reader(seqs, False)} + >>> fasta = dict(fasta_reader(seqs)) >>> m_entries_str = truvari.msa2vcf(fasta) """ ref_key = [_ for _ in msa.keys() if _.startswith("ref_")][0] diff --git a/truvari/phab.py b/truvari/phab.py index 21ea2271..bde708dc 100644 --- a/truvari/phab.py +++ b/truvari/phab.py @@ -8,7 +8,7 @@ import logging import argparse import multiprocessing -from io import BytesIO +from io import BytesIO, StringIO from functools import partial from collections import defaultdict @@ -16,7 +16,7 @@ import pyabpoa from pysam import samtools from intervaltree import IntervalTree -from pywfa.align import WavefrontAligner # pylint: disable=no-name-in-module +from pywfa.align import WavefrontAligner import truvari DEFAULT_MAFFT_PARAM = "--auto --thread 1" @@ -153,28 +153,21 @@ def make_haplotype_jobs(base_vcf, bSamples=None, comp_vcf=None, cSamples=None, p return ret, samp_names -def fasta_reader(fa_str, name_entries=True): +def fasta_reader(fa_str): """ Parses a fasta file as a string and yields tuples of (location, entry) - if name_entries, the entry names are written to the value and location is honored """ cur_name = None - cur_entry = BytesIO() + cur_entry = StringIO() for i in fa_str.split('\n'): if not i.startswith(">"): - cur_entry.write(i.encode()) + cur_entry.write(i) continue if cur_name is not None: - cur_entry.write(b'\n') cur_entry.seek(0) yield cur_name, cur_entry.read() cur_name = i[1:] - cur_entry = BytesIO() - if name_entries: - cur_name = cur_name.split('_')[-1] - cur_entry.write((i + '\n').encode()) - - cur_entry.write(b'\n') + cur_entry = StringIO() cur_entry.seek(0) yield cur_name, cur_entry.read() @@ -228,7 +221,7 @@ def wfa_to_vars(seq_bytes): Align haplotypes independently with WFA Much faster than mafft, but may be less accurate at finding parsimonous representations """ - fasta = {k: v.decode() for k, v in fasta_reader(seq_bytes.decode(), False)} + fasta = dict(fasta_reader(seq_bytes.decode())) ref_key = [_ for _ in fasta.keys() if _.startswith("ref_")][0] reference = fasta[ref_key] @@ -265,9 +258,7 @@ def mafft_to_vars(seq_bytes, params=DEFAULT_MAFFT_PARAM): with open("repo_utils/test_files/external/fake_mafft/lookup/fm_" + dev_name + ".msa", 'w') as fout: fout.write(ret.stdout) - fasta = {} - for name, sequence in fasta_reader(ret.stdout, name_entries=False): - fasta[name] = sequence.decode() + fasta = dict(fasta_reader(ret.stdout)) return truvari.msa2vcf(fasta) def poa_to_vars(seq_bytes): @@ -275,45 +266,16 @@ def poa_to_vars(seq_bytes): Run partial order alignment to create msa """ parts = [] - for k,v in fasta_reader(seq_bytes.decode(), False): - s = v.decode().strip() - parts.append((len(s), s, k)) - parts.sort() + for k,v in fasta_reader(seq_bytes.decode()): + parts.append((len(v), v, k)) + parts.sort(reverse=True) _, seqs, names = zip(*parts) aligner = pyabpoa.msa_aligner() aln_result = aligner.msa(seqs, False, True) return truvari.msa2vcf(dict(zip(names, aln_result.msa_seq))) -def harmonize_variants(harm_jobs, mafft_params, base_vcf, samp_names, output_fn, threads, method="mafft"): - """ - Parallel processing of variants to harmonize. Writes to output - """ - if method == "mafft": - align_method = partial(mafft_to_vars, params=mafft_params) - elif method == "poa": - align_method = poa_to_vars - else: - align_method = wfa_to_vars - - with open(output_fn[:-len(".gz")], 'w') as fout: - fout.write( - '##fileformat=VCFv4.1\n##FORMAT=\n') - for ctg in pysam.VariantFile(base_vcf).header.contigs.values(): - fout.write(str(ctg.header_record)) - fout.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t") - fout.write("\t".join(samp_names) + "\n") - - with multiprocessing.Pool(threads) as pool: - for result in pool.imap_unordered(align_method, harm_jobs): - fout.write(result) - pool.close() - pool.join() - - truvari.compress_index_vcf(output_fn[:-len(".gz")], output_fn) - - -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, too-many-locals # This is just how many arguments it takes def phab(var_regions, base_vcf, ref_fn, output_fn, bSamples=None, buffer=100, mafft_params=DEFAULT_MAFFT_PARAM, comp_vcf=None, cSamples=None, @@ -357,9 +319,29 @@ def phab(var_regions, base_vcf, ref_fn, output_fn, bSamples=None, buffer=100, haplotypes = collect_haplotypes(ref_haps_fn, hap_jobs, threads) logging.info("Harmonizing variants") - harmonize_variants(haplotypes, mafft_params, base_vcf, - samp_names, output_fn, threads, method) -# pylint: enable=too-many-arguments + if method == "mafft": + align_method = partial(mafft_to_vars, params=mafft_params) + elif method == "wfa": + align_method = wfa_to_vars + else: + align_method = poa_to_vars + + with open(output_fn[:-len(".gz")], 'w') as fout: + fout.write(('##fileformat=VCFv4.1\n' + '##FORMAT=\n')) + for ctg in pysam.VariantFile(base_vcf).header.contigs.values(): + fout.write(str(ctg.header_record)) + fout.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t") + fout.write("\t".join(samp_names) + "\n") + + with multiprocessing.Pool(threads) as pool: + for result in pool.imap_unordered(align_method, haplotypes): + fout.write(result) + pool.close() + pool.join() + + truvari.compress_index_vcf(output_fn[:-len(".gz")], output_fn) +# pylint: enable=too-many-arguments, too-many-locals ###### # UI #