Skip to content

Commit

Permalink
Got ProgressManager working for schemecreate
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisgKent committed Apr 25, 2024
1 parent a6f5788 commit 6bcc6d5
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 64 deletions.
4 changes: 4 additions & 0 deletions primalscheme3/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# Module imports
from primalscheme3.__init__ import __version__
from primalscheme3.core.progress_tracker import ProgressManager
from primalscheme3.flu.flu_main import create_flu
from primalscheme3.interaction.interaction import visulise_interactions
from primalscheme3.panel.panel_main import PanelRunModes, panelcreate
Expand Down Expand Up @@ -469,6 +470,8 @@ def cli():
f"ERROR: Output directory '{args.output}' already exists. Use --force to override"
)

pm = ProgressManager()

if args.func == schemecreate:
validate_scheme_create_args(args)
schemecreate(
Expand All @@ -493,6 +496,7 @@ def cli():
output_dir=args.output,
backtrack=args.backtrack,
circular=args.circular,
progress_manager=pm,
)
elif args.func == schemereplace:
validate_scheme_replace_args(args)
Expand Down
114 changes: 55 additions & 59 deletions primalscheme3/core/digestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import itertools
from collections import Counter
from enum import Enum
from multiprocessing import Pool
from typing import Callable, Union

import networkx as nx
Expand All @@ -11,9 +10,6 @@
# Submodules
from primaldimer_py import do_pools_interact_py # type: ignore

# Externals
from tqdm import tqdm

from primalscheme3.core.classes import FKmer, PrimerPair, RKmer
from primalscheme3.core.errors import (
ERROR_SET,
Expand All @@ -25,6 +21,7 @@
WalksTooFar,
)
from primalscheme3.core.get_window import get_r_window_FAST2
from primalscheme3.core.progress_tracker import ProgressManager
from primalscheme3.core.seq_functions import expand_ambs, get_most_common_base
from primalscheme3.core.thermo import (
THERMORESULT,
Expand Down Expand Up @@ -132,6 +129,7 @@ def generate_valid_primerpairs(
amplicon_size_max: int,
dimerscore: float,
msa_index: int,
progress_manager: ProgressManager,
disable_progress_bar: bool = False,
) -> list[PrimerPair]:
"""Generates valid primer pairs for a given set of forward and reverse kmers.
Expand All @@ -149,10 +147,9 @@ def generate_valid_primerpairs(
## Generate all primerpairs without checking
checked_pp = []

for fkmer in tqdm(
fkmers,
desc="Generating PrimerPairs",
disable=disable_progress_bar,
for fkmer in progress_manager.create_sub_progress(
iter=fkmers,
process="Generating PrimerPairs",
):
fkmer_start = min(fkmer.starts())
# Get all rkmers that would make a valid amplicon
Expand Down Expand Up @@ -459,7 +456,7 @@ def process_seqs(
return error

# Remove Ns if asked
if not ignore_n:
if ignore_n:
seq_counts.pop(DIGESTION_ERROR.CONTAINS_INVALID_BASE, None)

total_values = sum(seq_counts.values())
Expand Down Expand Up @@ -495,11 +492,10 @@ def mp_r_digest(

# Count how many times each sequence / error occurs
_start_col, seq_counts = r_digest_to_count((align_array, cfg, start_col, min_freq))

tmp_parsed_seqs = process_seqs(seq_counts, min_freq, ignore_n=cfg["ignore_n"])
if type(tmp_parsed_seqs) == DIGESTION_ERROR:
if isinstance(tmp_parsed_seqs, DIGESTION_ERROR):
return (start_col, tmp_parsed_seqs)
elif type(tmp_parsed_seqs) == dict:
elif isinstance(tmp_parsed_seqs, dict):
parsed_seqs = tmp_parsed_seqs
else:
raise ValueError("Unknown error occured")
Expand Down Expand Up @@ -725,6 +721,7 @@ def reduce_kmers(seqs: set[str], max_edit_dist: int = 1, end_3p: int = 6) -> set
def digest(
msa_array: np.ndarray,
cfg: dict,
progress_manager: ProgressManager,
indexes: tuple[list[int], list[int]] | None = None,
logger: None = None,
) -> tuple[list[FKmer], list[RKmer]]:
Expand Down Expand Up @@ -756,53 +753,52 @@ def digest(
else range(msa_array.shape[1] - cfg["primer_size_min"])
)

# Create the MP Pool
with Pool(cfg["n_cores"]) as p:
# Generate the FKmers via MP
fprimer_mp = p.map(
mp_f_digest,
((msa_array, cfg, end_col, cfg["minbasefreq"]) for end_col in findexes),
)

pass_fprimer_mp = [x for x in fprimer_mp if type(x) is FKmer and x.seqs]
pass_fprimer_mp.sort(key=lambda fkmer: fkmer.end)
# Digest the findexes
fkmers = []
for findex in progress_manager.create_sub_progress(
iter=findexes, process="FKMER Digestion"
):
fkmer = mp_f_digest((msa_array, cfg, findex, cfg["minbasefreq"]))

# Generate the RKmers via MP
rprimer_mp = p.map(
mp_r_digest,
((msa_array, cfg, start_col, cfg["minbasefreq"]) for start_col in rindexes),
)
pass_rprimer_mp = [x for x in rprimer_mp if type(x) is RKmer and x.seqs]
# mp_thermo_pass_rkmers = [x for x in rprimer_mp if x is not None]
pass_rprimer_mp.sort(key=lambda rkmer: rkmer.start)
if logger is not None:
if isinstance(fkmer, tuple):
logger.debug(
"FKmer: <red>{end_col}\t{error}</>",
end_col=fkmer[0],
error=fkmer[1].value,
)
else:
logger.debug(
"FKmer: <green>{end_col}</>: AllPass",
end_col=fkmer.end, # type: ignore
)

# Append valid FKmers
if isinstance(fkmer, FKmer) and fkmer.seqs:
fkmers.append(fkmer)

# Digest the rindexes
rkmers = []
for rindex in progress_manager.create_sub_progress(
iter=rindexes, process="Digesting RKmers"
):
rkmer = mp_r_digest((msa_array, cfg, rindex, cfg["minbasefreq"]))

# If a logger has been provided, dump the error stats
if logger is not None:
# Log the fkmer errors
for fkmer_result in fprimer_mp:
if type(fkmer_result) is tuple:
logger.debug(
"FKmer: <red>{end_col}\t{error}</>",
end_col=fkmer_result[0],
error=fkmer_result[1].value,
)
else:
logger.debug(
"FKmer: <green>{end_col}</>: AllPass",
end_col=fkmer_result.end, # type: ignore
)
# log the rkmer errors
for rkmer_result in rprimer_mp:
if type(rkmer_result) is tuple:
logger.debug(
"RKmer: <red>{start_col}\t{error}</>",
start_col=rkmer_result[0],
error=rkmer_result[1].value,
)
else:
logger.debug(
"RKmer: <green>{start_col}</>: AllPass",
start_col=rkmer_result.start, # type: ignore
)

return (pass_fprimer_mp, pass_rprimer_mp)
if isinstance(rkmer, tuple):
logger.debug(
"RKmer: <red>{start_col}\t{error}</>",
start_col=rkmer[0],
error=rkmer[1].value,
)
else:
logger.debug(
"RKmer: <green>{start_col}</>: AllPass",
start_col=rkmer.start, # type: ignore
)

# Append valid RKmers
if isinstance(rkmer, RKmer) and rkmer.seqs:
rkmers.append(rkmer)

return (fkmers, rkmers)
14 changes: 11 additions & 3 deletions primalscheme3/core/msa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pathlib
from uuid import uuid4

import sys
from uuid import uuid4

import numpy as np
from Bio import SeqIO
Expand Down Expand Up @@ -46,12 +45,19 @@ class MSA:
primerpairs: list[PrimerPair]

def __init__(
self, name: str, path: pathlib.Path, msa_index: int, mapping: str, logger=None
self,
name: str,
path: pathlib.Path,
msa_index: int,
mapping: str,
progress_manager,
logger=None,
) -> None:
self.name = name
self.path = str(path)
self.msa_index = msa_index
self.logger = logger
self.progress_manager = progress_manager

# Read in the MSA
records_index = SeqIO.index(self.path, "fasta")
Expand Down Expand Up @@ -99,6 +105,7 @@ def digest(
cfg=cfg,
indexes=indexes,
logger=self.logger,
progress_manager=self.progress_manager,
)
# remap the fkmer and rkmers if needed
if self._mapping_array is not None:
Expand Down Expand Up @@ -129,6 +136,7 @@ def generate_primerpairs(
amplicon_size_max=amplicon_size_max,
dimerscore=dimerscore,
msa_index=self.msa_index,
progress_manager=self.progress_manager,
)
# Update primerpairs to include the chrom_name and amplicon_prefix
for primerpair in self.primerpairs:
Expand Down
52 changes: 52 additions & 0 deletions primalscheme3/core/progress_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import time

from tqdm import tqdm


class ProgressTracker(tqdm):
    """A tqdm progress bar that also carries the name of the process it tracks."""

    def __init__(self, process, iterable, *args, **kwargs):
        """Store the process label, then initialise the underlying tqdm bar.

        Args:
            process: Label of the step being tracked; stored as ``self.process``.
            iterable: The iterable to wrap; handed to tqdm as its first argument.
            *args: Extra positional arguments forwarded unchanged to tqdm.
            **kwargs: Extra keyword arguments forwarded unchanged to tqdm.
        """
        # The label is assigned before tqdm's own initialisation runs.
        self.process = process
        super().__init__(iterable, *args, **kwargs)


class ProgressManager:
    """Keep a handle on the most recently created progress tracker.

    The manager creates ``ProgressTracker`` bars on request and remembers the
    latest one, so that its label, current count and total can be polled
    through ``process()``, ``n()`` and ``total()``.
    """

    _subprocess: None | ProgressTracker

    def __init__(self):
        """Start with no status and no active tracker."""
        self._status = None
        self._subprocess = None

    def n(self) -> int | None:
        """Return the latest tracker's ``n`` counter, or None if none was created."""
        # Compare with None explicitly: tqdm objects define __bool__, so a
        # tracker wrapping an empty iterable is falsy even though it exists.
        if self._subprocess is not None:
            return self._subprocess.n
        return None

    def total(self) -> int | None:
        """Return the latest tracker's ``total``, or None if none was created."""
        if self._subprocess is not None:
            return self._subprocess.total
        return None

    def process(self) -> str | None:
        """Return the latest tracker's process label, or None if none was created."""
        if self._subprocess is not None:
            return self._subprocess.process
        return None

    def create_sub_progress(self, iter, process, *args, **kwargs) -> ProgressTracker:
        """Create a progress tracker over ``iter`` and remember it as the active one.

        Args:
            iter: The iterable to wrap. The name shadows the builtin but is
                kept because callers pass it by keyword.
            process: Label for the step; stored on the tracker and used as the
                bar description unless the caller supplies one.
            *args: Extra positional arguments forwarded to the tracker after
                the iterable.
            **kwargs: Extra keyword arguments forwarded to the tracker.

        Returns:
            The newly created ``ProgressTracker``, which replaces any
            previously stored tracker.
        """
        # Default the description to the process label, but do not clash with
        # a description the caller already provided. tqdm's first positional
        # parameter after the iterable is the description, so leave it alone
        # when extra positionals are given.
        if not args:
            kwargs.setdefault("desc", process)
        # Pass process/iterable positionally so that extra *args land after
        # them instead of colliding with the ``process`` parameter.
        self._subprocess = ProgressTracker(process, iter, *args, **kwargs)
        return self._subprocess


if __name__ == "__main__":
    # Manual smoke test: poll the manager before, during and after a short loop.
    manager = ProgressManager()

    # No tracker has been created yet at this point.
    print(manager.process(), manager.n(), manager.total())

    demo_bar = manager.create_sub_progress(iter=range(10), process="test")
    for value in demo_bar:
        print(value)
        print(manager.process(), manager.n(), manager.total())
        # Short pause so the bar can be seen advancing.
        time.sleep(0.1)

    # Counter of the tracker once the loop has finished.
    print(manager.n())
10 changes: 8 additions & 2 deletions primalscheme3/scheme/scheme_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# Interaction checker
from primaldimer_py import do_pools_interact_py # type: ignore


from primalscheme3.__init__ import __version__
from primalscheme3.core.bedfiles import (
read_in_bedprimerpairs,
Expand All @@ -23,6 +22,7 @@
from primalscheme3.core.mapping import generate_consensus, generate_reference
from primalscheme3.core.mismatches import MatchDB
from primalscheme3.core.msa import MSA
from primalscheme3.core.progress_tracker import ProgressManager
from primalscheme3.scheme.classes import Scheme, SchemeReturn


Expand Down Expand Up @@ -85,7 +85,9 @@ def schemereplace(
prefix, ampliconnumber = primername.split("_")[:2]
primerstem = f"{ampliconnumber}_{prefix}"
except ValueError:
raise ValueError(f"ERROR: {primername} cannot be parsed using _ as delim")
raise ValueError(
f"ERROR: {primername} cannot be parsed using _ as delim"
) from None

# Find primernumber from bedfile
wanted_pp = None
Expand All @@ -110,6 +112,8 @@ def schemereplace(
path=msapath,
msa_index=wanted_pp.msa_index,
mapping=cfg["mapping"],
logger=None,
progress_manager=None,
)
# Check the hashes match
with open(msa.path, "rb") as f:
Expand Down Expand Up @@ -244,6 +248,7 @@ def schemecreate(
circular: bool,
backtrack: bool,
ignore_n: bool,
progress_manager: ProgressManager,
bedfile: pathlib.Path | None = None,
force: bool = False,
mapping: str = "first",
Expand Down Expand Up @@ -385,6 +390,7 @@ def schemecreate(
msa_index=msa_index,
mapping=cfg["mapping"],
logger=logger,
progress_manager=progress_manager,
)

if "/" in msa._chrom_name:
Expand Down

0 comments on commit 6bcc6d5

Please sign in to comment.