diff --git a/config_examples/parameters.yml b/config_examples/parameters.yml
index 187f5009..70cca577 100644
--- a/config_examples/parameters.yml
+++ b/config_examples/parameters.yml
@@ -1,25 +1,41 @@
 ---
+model: {
+    non_strand_specific: {
+        use_module: True,
+        mode: mean
+    },
+    import_model_from: models.deepsea,
+    class: DeepSEA
+}
 sampler: !obj:selene.samplers.IntervalsSampler {
+    genome: /scratch/data_hg/male.hg19.fasta,
+    genomic_features: /scratch/data_hg/sorted_sv_aggregate.bed.gz,
+    distinct_features: /scratch/data_hg/distinct_features.txt,
+    sample_from_regions: /scratch/data_hg/TFs_coords_only.txt,
     test_holdout: [8, 9],
     validation_holdout: [6, 7],
     random_seed: 127,
     sequence_length: 1001,
     center_bin_to_predict: 201,
-    default_threshold: 0.5,
+    feature_thresholds: 0.5,
     mode: "train"
-    }
+}
 model_controller: !obj:selene.ModelController {
     batch_size: 64,
     max_steps: 500000,
-    report_metrics_every_n_steps: 16000,
-    n_validation_samples: 3200,
+    report_stats_every_n_steps: 16000,
+    n_validation_samples: 32000,
     optional_args: {
         cpu_n_threads: 32,
         use_cuda: True,
-        data_parallel: False
-    },
+        data_parallel: False,
+        logging_verbosity: 2
+    },
     checkpoint: {
         resume: False
-    }
-    }
+    },
+    output_dir: /tigress/TROYANSKAYA/kathy/example_outputs
+}
+evaluate_on_test: True
+save_datasets: True
 ...
diff --git a/config_examples/paths.yml b/config_examples/paths.yml
deleted file mode 100644
index 8f2a88c1..00000000
--- a/config_examples/paths.yml
+++ /dev/null
@@ -1,9 +0,0 @@
----
-dir_path: /scratch/data_hg/
-files:
-  genome: male.hg19.fasta
-  genomic_features: sorted_sv_aggregate.bed.gz
-  sample_from_regions: TFs_coords_only.txt
-  distinct_features: distinct_features.txt
-output_dir: /tigress/TROYANSKAYA/kathy/example_outputs
-...
diff --git a/models/non_strand_specific_module.py b/models/non_strand_specific_module.py
new file mode 100644
index 00000000..1cffd650
--- /dev/null
+++ b/models/non_strand_specific_module.py
@@ -0,0 +1,44 @@
+import torch
+from torch.nn.modules import Module
+
+
+def flip(x, dim):
+    """Reverses the elements in a given dimension `dim` of the Tensor.
+
+    source: https://github.com/pytorch/pytorch/issues/229
+    """
+    xsize = x.size()
+    dim = x.dim() + dim if dim < 0 else dim
+    x = x.contiguous()
+    x = x.view(-1, *xsize[dim:])
+    x = x.view(
+        x.size(0), x.size(1), -1)[:, getattr(
+            torch.arange(x.size(1)-1, -1, -1),
+            ('cpu', 'cuda')[x.is_cuda])().long(), :]
+    return x.view(xsize)
+
+
+class NonStrandSpecific(Module):
+    def __init__(self, model, mode="mean"):
+        super(NonStrandSpecific, self).__init__()
+
+        self.model = model
+
+        if mode != "mean" and mode != "max":
+            raise ValueError(
+                f"Mode should be one of 'mean' or 'max' but was {mode!r}.")
+        self.mode = mode
+
+    def forward(self, input):
+
+        reverse_input = flip(
+            flip(input, 1), 2)
+
+        output = self.model.forward(input)
+        output_from_rev = self.model.forward(
+            reverse_input)
+        if self.mode == "mean":
+            return (output + output_from_rev) / 2
+        else:
+            return torch.max(output, output_from_rev)
+
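Reviewer note: a minimal sketch of how the new `NonStrandSpecific` wrapper might be used. `TinyModel` and all shapes here are hypothetical stand-ins, not part of this diff; the diff itself only wires the wrapper up through the `model:` block of `parameters.yml`.

```python
# Hypothetical usage sketch -- TinyModel is a stand-in, not part of this diff.
import torch
import torch.nn as nn

from models.non_strand_specific_module import NonStrandSpecific


class TinyModel(nn.Module):
    """Stand-in for DeepSEA: maps a one-hot encoded batch to feature scores."""
    def __init__(self, sequence_length, n_features):
        super(TinyModel, self).__init__()
        self.fc = nn.Linear(4 * sequence_length, n_features)

    def forward(self, x):
        return torch.sigmoid(self.fc(x.contiguous().view(x.size(0), -1)))


model = NonStrandSpecific(TinyModel(1001, 919), mode="mean")
# model(batch) now averages the forward-strand and reverse-complement
# predictions; mode="max" would take the element-wise maximum instead.
```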
diff --git a/selene.py b/selene.py
index e4de0668..572f4009 100644
--- a/selene.py
+++ b/selene.py
@@ -6,126 +6,61 @@ Saves model to a user-specified output file.
 
 Usage:
-    selene.py <import_module> <model_class> <lr>
-        <paths_yml> <train_model_yml>
-        [-s | --stdout] [--verbosity=<level>]
+    selene.py <lr> <config_yml>
     selene.py -h | --help
 
 Options:
     -h --help               Show this screen.
-    <import_module>         Import the module containing the model
-    <model_class>           Must be a model class in the imported module
     <lr>                    Choose the optimizer's learning rate
-    <paths_yml>             Input data and output filepaths
-    <train_model_yml>       Model-specific parameters
-    -s --stdout             Will also output logging information to stdout
-                            [default: False]
-    --verbosity=<level>     Logging verbosity level (0=WARN, 1=INFO, 2=DEBUG)
-                            [default: 1]
+    <config_yml>            Model-specific parameters
 """
 import importlib
-import logging
-import os
-from time import strftime, time
 
 from docopt import docopt
 import torch
 
-from selene.model_train import ModelController
-from selene.samplers import IntervalsSampler
-from selene.utils import initialize_logger, read_yaml_file
-from selene.utils import load, load_path, instantiate
+from selene.utils import load_path, instantiate
 
 
 if __name__ == "__main__":
     arguments = docopt(
         __doc__,
         version="1.0")
-    import_model_from = arguments["<import_module>"]
-    model_class_name = arguments["<model_class>"]
-    use_module = importlib.import_module(import_model_from)
-    model_class = getattr(use_module, model_class_name)
     lr = float(arguments["<lr>"])
-    paths = read_yaml_file(
-        arguments["<paths_yml>"])
-
-    train_model = load_path(arguments["<train_model_yml>"], instantiate=False)
-
-
-    ##################################################
-    # PATHS
-    ##################################################
-    dir_path = paths["dir_path"]
-    files = paths["files"]
-    genome_fasta = os.path.join(
-        dir_path, files["genome"])
-    genomic_features = os.path.join(
-        dir_path, files["genomic_features"])
-    coords_only = os.path.join(
-        dir_path, files["sample_from_regions"])
-    distinct_features = os.path.join(
-        dir_path, files["distinct_features"])
-
-    output_dir = paths["output_dir"]
-    os.makedirs(output_dir, exist_ok=True)
-
-    current_run_output_dir = os.path.join(
-        output_dir, strftime("%Y-%m-%d-%H-%M-%S"))
-    os.makedirs(current_run_output_dir)
+    configs = load_path(arguments["<config_yml>"], instantiate=False)
 
     ##################################################
     # TRAIN MODEL PARAMETERS
     ##################################################
-    sampler_info = train_model["sampler"]
-    model_controller_info = train_model["model_controller"]
+    model_info = configs["model"]
+    sampler_info = configs["sampler"]
+    model_controller_info = configs["model_controller"]
 
-    ##################################################
-    # OTHER ARGS
-    ##################################################
-    to_stdout = arguments["--stdout"]
-    verbosity_level = int(arguments["--verbosity"])
-
-    initialize_logger(
-        os.path.join(current_run_output_dir, "{0}.log".format(__name__)),
-        verbosity=verbosity_level,
-        stdout_handler=to_stdout)
-    logger = logging.getLogger("selene")
-
-    t_i = time()
-    feature_thresholds = None
-    if "specific_feature_thresholds" in sampler_info.keywords:
-        feature_thresholds = sampler_info.pop("specific_feature_thresholds")
-    else:
-        feature_thresholds = None
-    if "default_threshold" in sampler_info.keywords:
-        if feature_thresholds:
-            feature_thresholds["default"] = sampler_info.pop("default_threshold")
-        else:
-            feature_thresholds = sampler_info.pop("default_threshold")
-
-    if feature_thresholds:
-        sampler_info.bind(feature_thresholds=feature_thresholds)
-    sampler_info.bind(genome=genome_fasta,
-                      query_feature_data=genomic_features,
-                      distinct_features=distinct_features,
-                      intervals_file=coords_only)
     sampler = instantiate(sampler_info)
 
-    t_i_model = time()
     torch.manual_seed(1337)
     torch.cuda.manual_seed_all(1337)
 
+    import_model_from = model_info["import_model_from"]
+    model_class_name = model_info["class"]
+    use_module = importlib.import_module(import_model_from)
+    model_class = 
getattr(use_module, model_class_name) + model = model_class(sampler.sequence_length, sampler.n_features) print(model) + if model_info["non_strand_specific"]["use_module"]: + from models.non_strand_specific_module import NonStrandSpecific + model = NonStrandSpecific( + model, mode=model_info["non_strand_specific"]["mode"]) + checkpoint_info = model_controller_info.pop("checkpoint") checkpoint_resume = checkpoint_info.pop("resume") checkpoint = None if checkpoint_resume: checkpoint_file = checkpoint_info.pop("model_file") - logger.info("Resuming training from checkpoint {0}.".format( + print("Resuming training from checkpoint {0}.".format( checkpoint_file)) checkpoint = torch.load(checkpoint_file) model.load_state_dict(checkpoint["state_dict"]) @@ -135,45 +70,26 @@ criterion = use_module.criterion() optimizer_class, optimizer_args = use_module.get_optimizer(lr) - t_f_model = time() - logger.debug( - "Finished initializing the {0} model from module {1}: {2} s".format( - model.__class__.__name__, - import_model_from, - t_f_model - t_i_model)) - - logger.info(model) - logger.info(optimizer_args) - - - if feature_thresholds: - sampler_info.bind(feature_thresholds=feature_thresholds) - sampler_info.bind(genome=genome_fasta, - query_feature_data=genomic_features, - distinct_features=distinct_features, - intervals_file=coords_only) - sampler = instantiate(sampler_info) - batch_size = model_controller_info.keywords["batch_size"] # Would love to find a better way. max_steps = model_controller_info.keywords["max_steps"] - report_metrics_every_n_steps = \ - model_controller_info.keywords["report_metrics_every_n_steps"] + report_stats_every_n_steps = \ + model_controller_info.keywords["report_stats_every_n_steps"] n_validation_samples = model_controller_info.keywords["n_validation_samples"] model_controller_info.bind(model=model, data_sampler=sampler, loss_criterion=criterion, optimizer_class=optimizer_class, - optimizer_args=optimizer_args, - output_dir=current_run_output_dir) + optimizer_args=optimizer_args) if "optional_args" in model_controller_info.keywords: optional_args = model_controller_info.pop("optional_args") model_controller_info.bind(**optional_args) runner = instantiate(model_controller_info) - logger.info("Training model: {0} steps, {1} batch size.".format( + print("Training model: {0} steps, {1} batch size.".format( max_steps, batch_size)) runner.train_and_validate() - - t_f = time() - logger.info("./train_model.py completed in {0} s.".format(t_f - t_i)) + if configs["evaluate_on_test"]: + runner.evaluate() + if configs["save_datasets"]: + runner.write_datasets_to_file() diff --git a/selene/__init__.py b/selene/__init__.py index 9bad4880..276c3194 100644 --- a/selene/__init__.py +++ b/selene/__init__.py @@ -1,2 +1,2 @@ -__all__ = ["sequences", "targets", "samplers", "utils"] +__all__ = ["predict", "sequences", "targets", "samplers", "utils"] from .model_train import ModelController diff --git a/selene/model_predict.py b/selene/model_predict.py deleted file mode 100644 index 8620ff18..00000000 --- a/selene/model_predict.py +++ /dev/null @@ -1,36 +0,0 @@ -"""TODO: nothing in this file works right now. Please do not use/review -at this time. 
-""" -import numpy as np -import torch -from torch.autograde import Variable - - -def evaluate_test_sampler(model, - sampler, - batch_size, - features_list, - output_file): - pass - - -def predict_on_encoded_sequences(model, - sequences_mat, - batch_size=64): - predictions = [] - n_examples, _, _ = sequences_mat.shape - for _ in range(0, n_examples, batch_size): - predictions.append(model.forward(sequences_mat[n_examples, :, :])) - return np.vstack(predictions) - -def evaluate_test_sequences(model, - sequences_file, - batch_size, - features_list, - output_file): - pass - - - - - diff --git a/selene/model_train.py b/selene/model_train.py index 6f74c0ac..efa8f749 100644 --- a/selene/model_train.py +++ b/selene/model_train.py @@ -1,54 +1,70 @@ -"""Execute the necessary steps to train the model -""" import logging import math import os import shutil -import sys -from time import time +from time import strftime, time import numpy as np -from sklearn.metrics import roc_auc_score import torch import torch.nn as nn from torch.autograd import Variable from torch.optim.lr_scheduler import ReduceLROnPlateau +from .utils import initialize_logger +from .utils import PerformanceMetrics logger = logging.getLogger("selene") -def initialize_logger(out_filepath, verbosity=1, stdout_handler=False): - """This function can only be called successfully once. - If the logger has already been initialized with handlers, - the function exits. Otherwise, it proceeds to set the - logger configurations. - """ - logger = logging.getLogger("selene") - # check if logger has already been initialized - if len(logger.handlers): - return - - if verbosity == 0: - logger.setLevel(logging.WARN) - elif verbosity == 1: - logger.setLevel(logging.INFO) - elif verbosity == 2: - logger.setLevel(logging.DEBUG) - - formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - - file_handle = logging.FileHandler(out_filepath) - file_handle.setFormatter(formatter) - logger.addHandler(file_handle) - - if stdout_handler: - stream_handle = logging.StreamHandler(sys.stdout) - stream_handle.setFormatter(formatter) - logger.addHandler(stream_handle) - - class ModelController(object): + """Methods to train and validate a PyTorch model. + + Parameters + ---------- + model : torch.nn.Module + data_sampler : Sampler + loss_criterion : torch.nn._Loss + optimizer_class : + optimizer_args : dict + batch_size : int + Specify the batch size to process examples. Should be a power of 2. + max_steps : int + report_stats_every_n_steps : int + output_dir : str + save_checkpoint_every_n_steps : int|None, optional + Default is 1000. If None, set to the same value as + `report_stats_every_n_steps` + n_validation_samples : int|None, optional + n_test_samples : int|None, optional + cpu_n_threads : int, optional + Default is 32. + use_cuda : bool, optional + Default is False. Specify whether CUDA is available for torch + to use during training. + data_parallel : bool, optional + Default is False. Specify whether multiple GPUs are available + for torch to use during training. + checkpoint_resume : torch.save object, optional + Default is None. If `checkpoint_resume` is not None, assumes + the input is a model saved via `torch.save` that can be + loaded to resume training. 
+ + Attributes + ---------- + model : torch.nn.Module + sampler : Sampler + criterion : torch.nn._Loss + optimizer : torch.optim + batch_size : int + max_steps : int + nth_step_report_stats : int + nth_step_save_checkpoint : int + use_cuda : bool + data_parallel : bool + output_dir : str + training_loss : list(float) + nth_step_stats : dict + """ def __init__(self, model, @@ -58,46 +74,17 @@ def __init__(self, optimizer_args, batch_size, max_steps, - report_metrics_every_n_steps, + report_stats_every_n_steps, output_dir, - n_validation_samples, - save_checkpoint=1000, + save_checkpoint_every_n_steps=1000, + report_gt_feature_n_positives=10, + n_validation_samples=None, + n_test_samples=None, cpu_n_threads=32, use_cuda=False, data_parallel=False, + logging_verbosity=2, checkpoint_resume=None): - """Methods to train and validate a PyTorch model. - - Parameters - ---------- - model : torch.nn.Module - sampler : Sampler - loss_criterion : torch.nn._Loss - optimizer_args : dict - batch_size : int - Specify the batch size to process examples. Should be a power of 2. - use_cuda : bool, optional - Default is False. Specify whether CUDA is available for torch - to use during training. - data_parallel : bool, optional - Default is False. Specify whether multiple GPUs are available - for torch to use during training. - checkpoint_resume : torch.save object, optional - Default is None. If `checkpoint_resume` is not None, assumes - the input is a model saved via `torch.save` that can be - loaded to resume training. - - Attributes - ---------- - model : torch.nn.Module - sampler : Sampler - criterion : torch.nn._Loss - optimizer : torch.optim - batch_size : batch_size - use_cuda : bool - data_parallel : bool - prefix_outputs : str - """ self.model = model self.sampler = data_sampler self.criterion = loss_criterion @@ -106,15 +93,27 @@ def __init__(self, self.batch_size = batch_size self.max_steps = max_steps - self.nth_step_report_metrics = report_metrics_every_n_steps - self.save_checkpoint = save_checkpoint + self.nth_step_report_stats = report_stats_every_n_steps + self.nth_step_save_checkpoint = None + if not save_checkpoint_every_n_steps: + self.nth_step_save_checkpoint = report_stats_every_n_steps + else: + self.nth_step_save_checkpoint = save_checkpoint_every_n_steps torch.set_num_threads(cpu_n_threads) self.use_cuda = use_cuda self.data_parallel = data_parallel - self.output_dir = output_dir - #initialize_logger(os.path.join(output_dir, "{0}.log".format(__name__))) + + os.makedirs(output_dir, exist_ok=True) + current_run_output_dir = os.path.join( + output_dir, strftime("%Y-%m-%d-%H-%M-%S")) + os.makedirs(current_run_output_dir) + self.output_dir = current_run_output_dir + + initialize_logger( + os.path.join(self.output_dir, f"{__name__}.log"), + verbosity=logging_verbosity) if self.data_parallel: self.model = nn.DataParallel(model) @@ -125,34 +124,39 @@ def __init__(self, self.criterion.cuda() logger.debug("Set modules to use CUDA") - self._create_validation_set(n_validation_samples) + self._create_validation_set(n_samples=n_validation_samples) + self._validation_metrics = PerformanceMetrics( + self.sampler.get_feature_from_index, + report_gt_feature_n_positives=report_gt_feature_n_positives) - self.start_step = 0 - self.min_loss = float("inf") + if "test" in self.sampler.modes: + self._create_test_set(n_samples=n_test_samples) + self._test_metrics = PerformanceMetrics( + self.sampler.get_feature_from_index, + report_gt_feature_n_positives=report_gt_feature_n_positives) + + self._start_step = 0 
+ self._min_loss = float("inf") if checkpoint_resume is not None: - self.start_step = checkpoint_resume["step"] - self.min_loss = checkpoint_resume["min_loss"] + self._start_step = checkpoint_resume["step"] + self._min_loss = checkpoint_resume["min_loss"] self.optimizer.load_state_dict( checkpoint_resume["optimizer"]) logger.info( ("Resuming from checkpoint: " "step {0}, min loss {1}").format( - self.start_step, self.min_loss)) + self._start_step, self._min_loss)) - self.training_loss = [] - self.nth_step_stats = { - "validation_loss": [], - "auc": [] + self.losses = { + "training": [], + "validation": [], } - def _create_validation_set(self, n_validation_samples): - """Used in `__init__`. - """ + def _create_validation_set(self, n_samples=None): t_i = time() self._validation_data, self._all_validation_targets = \ self.sampler.get_validation_set( - self.batch_size, n_samples=n_validation_samples) - print(len(self._validation_data), len(self._all_validation_targets)) + self.batch_size, n_samples=n_samples) t_f = time() logger.info(("{0} s to load {1} validation examples ({2} validation " "batches) to evaluate after each training step.").format( @@ -160,10 +164,19 @@ def _create_validation_set(self, n_validation_samples): len(self._validation_data) * self.batch_size, len(self._validation_data))) + def _create_test_set(self, n_samples=None): + t_i = time() + self._test_data, self._all_test_targets = \ + self.sampler.get_test_set( + self.batch_size, n_samples=n_samples) + t_f = time() + logger.info(("{0} s to load {1} test examples ({2} test batches) " + "to evaluate after all training steps.").format( + t_f - t_i, + len(self._test_data) * self.batch_size, + len(self._test_data))) + def _get_batch(self): - """Sample `self.batch_size` times. Return inputs and targets as a - batch. - """ t_i_sampling = time() batch_sequences, batch_targets = self.sampler.sample( batch_size=self.batch_size) @@ -175,26 +188,26 @@ def _get_batch(self): return (batch_sequences, batch_targets) def train_and_validate(self): - """The training and validation process. - """ logger.info( ("[TRAIN] max_steps: {0}, batch_size: {1}").format( self.max_steps, self.batch_size)) - min_loss = self.min_loss + min_loss = self._min_loss scheduler = ReduceLROnPlateau( self.optimizer, 'max', patience=16, verbose=True, factor=0.8) - for step in range(self.start_step, self.max_steps): + for step in range(self._start_step, self.max_steps): train_loss = self.train() - self.training_loss.append(train_loss) + self.losses["training"].append(train_loss) # @TODO: if step and step % ... 
-            if step % self.nth_step_report_metrics == 0:
-                validation_loss, auc = self.validate()
-                self.nth_step_stats["validation_loss"].append(validation_loss)
-                self.nth_step_stats["auc"].append(auc)
-                scheduler.step(math.ceil(auc * 1000.0) / 1000.0)
+            if step % self.nth_step_report_stats == 0:
+                valid_scores = self.validate()
+                validation_loss = valid_scores["loss"]
+                self.losses["validation"].append(validation_loss)
+                scheduler.step(
+                    math.ceil(valid_scores["roc_auc"] * 1000.0) / 1000.0)
 
                 is_best = validation_loss < min_loss
                 min_loss = min(validation_loss, min_loss)
@@ -205,11 +218,11 @@
                         "min_loss": min_loss,
                         "optimizer": self.optimizer.state_dict()}, is_best)
                 logger.info(
-                    ("[METRICS] step={0}: "
+                    ("[STATS] step={0}: "
                      "Training loss: {1}, validation loss: {2}.").format(
                          step, train_loss, validation_loss))
 
-            if step % self.save_checkpoint == 0:
+            if step % self.nth_step_save_checkpoint == 0:
                 self._save_checkpoint({
                     "step": step,
                     "arch": self.model.__class__.__name__,
@@ -218,8 +231,6 @@
                     "optimizer": self.optimizer.state_dict()}, False)
 
     def train(self):
-        """Create and process a training batch of positive/negative examples.
-        """
         self.model.train()
         self.sampler.set_mode("train")
         inputs, targets = self._get_batch()
@@ -241,26 +252,15 @@
         loss.backward()
         self.optimizer.step()
 
-        """
-        training_loss = None
-        def closure():
-            predictions = self.model(inputs.transpose(1, 2))
-            loss = self.criterion(predictions, targets)
-            loss.backward()
-            training_loss = loss.data[0]
-            return loss
-
-        self.optimizer.zero_grad()
-        self.optimizer.step(closure)
-        """
         return loss.data[0]
 
-    def validate(self):
+    def _evaluate_on_data(self, data_in_batches):
         self.model.eval()
 
-        validation_losses = []
-        collect_predictions = []
-        for (inputs, targets) in self._validation_data:
+        batch_losses = []
+        all_predictions = []
+
+        for (inputs, targets) in data_in_batches:
             inputs = torch.Tensor(inputs)
             targets = torch.Tensor(targets)
 
@@ -272,24 +272,52 @@
             targets = Variable(targets, volatile=True)
 
             predictions = self.model(inputs.transpose(1, 2))
-            validation_loss = self.criterion(
-                predictions, targets).data[0]
-
-            collect_predictions.append(predictions.data.cpu().numpy())
-            validation_losses.append(validation_loss)
-        all_predictions = np.vstack(collect_predictions)
-        #print(all_predictions.shape)
-        feature_aucs = []
-        for index, feature_preds in enumerate(all_predictions.T):
-            feature_targets = self._all_validation_targets[:, index]
-            if len(np.unique(feature_targets)) > 1:
-                auc = roc_auc_score(feature_targets, feature_preds)
-                feature_aucs.append(auc)
-        logger.debug("[METRICS] average AUC: {0}".format(np.average(feature_aucs)))
-        print("[VALIDATE] average AUC: {0}".format(np.average(feature_aucs)))
-
-        self.nth_step_stats["auc"].append(np.average(feature_aucs))
-        return np.average(validation_losses), np.average(feature_aucs)
+            loss = self.criterion(predictions, targets)
+
+            all_predictions.append(predictions.data.cpu().numpy())
+            batch_losses.append(loss.data[0])
+
+        all_predictions = np.vstack(all_predictions)
+        return np.average(batch_losses), all_predictions
+
+    def validate(self):
+        average_loss, all_predictions = self._evaluate_on_data(
+            self._validation_data)
+
+        average_scores = self._validation_metrics.update(
+            self._all_validation_targets, all_predictions)
+
+        for name, score in average_scores.items():
+            logger.debug(f"[STATS] average {name}: {score}")
+            print(f"[VALIDATE] 
average {name}: {score}") + + average_scores["loss"] = average_loss + return average_scores + + def evaluate(self): + average_loss, all_predictions = self._evaluate_on_data( + self._test_data) + + average_scores = self._test_metrics.update( + self._all_test_targets, all_predictions) + + for name, score in average_scores.items(): + logger.debug(f"[STATS] average {name}: {score}") + print(f"[TEST] average {name}: {score}") + + test_performance = os.path.join( + self.output_dir, "test_performance.txt") + feature_scores_dict = self._test_metrics.write_feature_scores_to_file( + test_performance) + + average_scores["loss"] = average_loss + return (average_scores, feature_scores_dict) + + def write_datasets_to_file(self): + data_dir = os.path.join( + self.output_dir, "data") + os.makedirs(data_dir, exist_ok=True) + self.sampler.save_datasets_to_file(data_dir) def _save_checkpoint(self, state, is_best, dir_path=None, diff --git a/selene/predict/__init__.py b/selene/predict/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/selene/predict/model_predict.py b/selene/predict/model_predict.py new file mode 100644 index 00000000..9baaa1c5 --- /dev/null +++ b/selene/predict/model_predict.py @@ -0,0 +1,172 @@ +import itertools + +import numpy as np +import torch +from torch.autograd import Variable + +from .predict_handlers import DiffScoreHandler, LogitScoreHandler, \ + WritePredictionsHandler +from ..sequences import Genome +from ..sequences import sequence_to_encoding + + +def predict(model, batch_sequences, use_cuda=False): + inputs = torch.Tensor(batch_sequences) + if use_cuda: + inputs = inputs.cuda() + inputs = Variable(inputs, volatile=True) + outputs = model.forward(inputs.transpose(1, 2)) + return outputs + + +def predict_on_encoded_sequences(model, + sequences, + batch_size=64, + use_cuda=False): + predictions = [] + n_examples, _, _ = sequences.shape + + for i in range(0, n_examples, batch_size): + start = i + end = i + batch_size + batch_sequences = sequences[start:end, :, :] + outputs = predict(model, batch_sequences, use_cuda=use_cuda) + predictions.append(outputs.data.cpu().numpy()) + return np.vstack(predictions) + + +def in_silico_mutagenesis_sequences(input_sequence, + mutate_n_bases=1): + """Creates a list containing each mutation that occurs from in silico + mutagenesis across the whole sequence. + + Parameters + ---------- + input_sequence : str + mutate_n_bases : int + + Returns + ------- + list + A list of all possible mutations. Each element in the list is + itself a list of tuples, e.g. [(0, 'T')] if we are only mutating + 1 base at a time. Each tuple is the position to mutate and the base + with which we are replacing the reference base. + + For a sequence of length 1000, mutating 1 base at a time means that + we return a list of length 3000. 
+ """ + sequence_alts = [] + for index, ref in enumerate(input_sequence): + alts = [] + for base in Genome.BASES_ARR: + if base == ref: + continue + alts.append(base) + sequence_alts.append(alts) + + all_mutated_sequences = [] + for indices in itertools.combinations( + range(len(input_sequence)), mutate_n_bases): + pos_mutations = [] + for i in indices: + pos_mutations.append(sequence_alts[i]) + for mutations in itertools.product(*pos_mutations): + all_mutated_sequences.append(list(zip(indices, mutations))) + return all_mutated_sequences + + +def _ism_sample_id(dna_sequence, mutation_information): + positions = [] + refs = [] + alts = [] + for (position, alt) in mutation_information: + positions.append(str(position)) + refs.append(dna_sequence[position]) + alts.append(alt) + return (';'.join(positions), ';'.join(refs), ';'.join(alts)) + + +def in_silico_mutagenesis_predict(model, + batch_size, + sequence, + mutations_list, + use_cuda=False, + reporters=[]): + current_sequence_encoding = sequence_to_encoding( + sequence, Genome.BASE_TO_INDEX) + for i in range(0, len(mutations_list), batch_size): + start = i + end = i + batch_size + + mutated_sequences = np.zeros( + (batch_size, *current_sequence_encoding.shape)) + + batch_ids = [] + for ix, mutation_info in enumerate(mutations_list[start:end]): + mutated_seq = mutate_sequence( + current_sequence_encoding, mutation_info) + mutated_sequences[ix, :, :] = mutated_seq + batch_ids.append(_ism_sample_id(sequence, mutation_info)) + outputs = predict( + model, mutated_sequences, use_cuda=use_cuda).data.cpu().numpy() + + for r in reporters: + r.handle_batch_predictions(outputs, batch_ids) + + for r in reporters: + r.write_to_file() + + +def _reverse_strand(dna_sequence): + reverse_bases = [Genome.COMPLEMENTARY_BASE[b] for b in dna_sequence[::-1]] + return ''.join(reverse_bases) + + +def mutate_sequence(dna_encoded_sequence, mutation_information): + mutated_seq = np.copy(dna_encoded_sequence) + for (position, alt) in mutation_information: + replace_base = Genome.BASE_TO_INDEX[alt] + mutated_seq[position, :] = 0 + mutated_seq[position, replace_base] = 1 + return mutated_seq + + +def in_silico_mutagenesis(model, + batch_size, + input_sequence, + features_list, + save_diffs, + mutate_n_bases=1, + use_cuda=False, + save_logits=None, + save_predictions=None): + mutated_sequences = in_silico_mutagenesis_sequences( + input_sequence, mutate_n_bases=1) + + current_sequence_encoding = sequence_to_encoding( + input_sequence, Genome.BASE_TO_INDEX) + + base_encoding = current_sequence_encoding.reshape( + (1, *current_sequence_encoding.shape)) + base_preds = predict( + model, base_encoding).data.cpu().numpy() + + reporters = [] + nonfeature_cols = ["pos", "ref", "alt"] + if save_diffs: + diff_handler = DiffScoreHandler( + base_preds, features_list, nonfeature_cols, save_diffs) + reporters.append(diff_handler) + if save_logits: + logit_handler = LogitScoreHandler( + base_preds, features_list, nonfeature_cols, save_logits) + reporters.append(logit_handler) + if save_predictions: + preds_handler = WritePredictionsHandler( + features_list, nonfeature_cols, save_predictions) + reporters.append(preds_handler) + + in_silico_mutagenesis_predict( + model, batch_size, input_sequence, mutated_sequences, + use_cuda=use_cuda, reporters=reporters) diff --git a/selene/predict/predict_handlers/__init__.py b/selene/predict/predict_handlers/__init__.py new file mode 100644 index 00000000..1da0ddc0 --- /dev/null +++ b/selene/predict/predict_handlers/__init__.py @@ -0,0 +1,4 @@ +from 
diff --git a/selene/predict/predict_handlers/__init__.py b/selene/predict/predict_handlers/__init__.py
new file mode 100644
index 00000000..1da0ddc0
--- /dev/null
+++ b/selene/predict/predict_handlers/__init__.py
@@ -0,0 +1,4 @@
+from .handler import PredictionsHandler
+from .diff_score_handler import DiffScoreHandler
+from .logit_score_handler import LogitScoreHandler
+from .write_predictions_handler import WritePredictionsHandler
diff --git a/selene/predict/predict_handlers/diff_score_handler.py b/selene/predict/predict_handlers/diff_score_handler.py
new file mode 100644
index 00000000..4f34524d
--- /dev/null
+++ b/selene/predict/predict_handlers/diff_score_handler.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+from .handler import _write_to_file, PredictionsHandler
+
+
+class DiffScoreHandler(PredictionsHandler):
+
+    def __init__(self,
+                 baseline_prediction,
+                 features_list,
+                 nonfeature_columns,
+                 out_filename):
+        self.baseline_prediction = baseline_prediction
+        self.column_names = nonfeature_columns + features_list
+        self.results = []
+        self.samples = []
+        self.out_filename = out_filename
+
+    def handle_batch_predictions(self,
+                                 batch_predictions,
+                                 batch_ids):
+        absolute_diffs = np.abs(self.baseline_prediction - batch_predictions)
+        self.results.append(absolute_diffs)
+        self.samples.append(batch_ids)
+        return absolute_diffs
+
+    def write_to_file(self):
+        self.results = np.vstack(self.results)
+        self.samples = np.vstack(self.samples)
+        _write_to_file(self.results,
+                       self.samples,
+                       self.column_names,
+                       self.out_filename)
diff --git a/selene/predict/predict_handlers/handler.py b/selene/predict/predict_handlers/handler.py
new file mode 100644
index 00000000..b1f70b2f
--- /dev/null
+++ b/selene/predict/predict_handlers/handler.py
@@ -0,0 +1,39 @@
+"""
+This module provides the abstract base class for handling model predictions.
+"""
+from abc import ABCMeta
+from abc import abstractmethod
+
+
+def _write_to_file(feature_predictions, info_cols, column_names, filename):
+    with open(filename, 'w+') as file_handle:
+        file_handle.write("{columns}\n".format(
+            columns='\t'.join(column_names)))
+        for info, preds in zip(info_cols, feature_predictions):
+            feature_str = '\t'.join(
+                probabilities_to_string(preds))
+            info_str = '\t'.join(info)
+            file_handle.write(f"{info_str}\t{feature_str}\n")
+
+
+def probabilities_to_string(probabilities):
+    return ["{:.2e}".format(p) for p in probabilities]
+
+
+class PredictionsHandler(metaclass=ABCMeta):
+    """
+    The base class for handling model predictions.
+    """
+    @abstractmethod
+    def handle_batch_predictions(self, *args, **kwargs):
+        """
+        Must be able to handle a batch of model predictions.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def write_to_file(self, *args, **kwargs):
+        """
+        Writes accumulated handler results to file.
+        """
+        raise NotImplementedError
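Reviewer note: since `PredictionsHandler` only demands the two abstract methods, new reporters can be slotted into `in_silico_mutagenesis_predict` without touching the prediction loop. A hypothetical sketch (the class below is not part of this diff):

```python
import numpy as np

from selene.predict.predict_handlers.handler import PredictionsHandler


class MaxScoreHandler(PredictionsHandler):
    """Hypothetical reporter: tracks the per-feature maximum across batches
    instead of storing every prediction."""

    def __init__(self, features_list, out_filename):
        self.features_list = features_list
        self.out_filename = out_filename
        self.max_scores = None

    def handle_batch_predictions(self, batch_predictions, batch_ids):
        batch_max = np.max(batch_predictions, axis=0)
        if self.max_scores is None:
            self.max_scores = batch_max
        else:
            self.max_scores = np.maximum(self.max_scores, batch_max)
        return self.max_scores

    def write_to_file(self):
        with open(self.out_filename, 'w+') as file_handle:
            for feature, score in zip(self.features_list, self.max_scores):
                file_handle.write(f"{feature}\t{score:.2e}\n")
```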
+ """ + raise NotImplementedError diff --git a/selene/predict/predict_handlers/logit_score_handler.py b/selene/predict/predict_handlers/logit_score_handler.py new file mode 100644 index 00000000..a0f4048f --- /dev/null +++ b/selene/predict/predict_handlers/logit_score_handler.py @@ -0,0 +1,34 @@ +import numpy as np +from scipy.special import logit + +from .handler import _write_to_file, PredictionsHandler + + +class LogitScoreHandler(PredictionsHandler): + + def __init__(self, + baseline_prediction, + features_list, + nonfeature_columns, + out_filename): + self.logit_baseline = logit(baseline_prediction) + self.column_names = nonfeature_columns + features_list + self.results = [] + self.samples = [] + self.out_filename = out_filename + + def handle_batch_predictions(self, + batch_predictions, + batch_ids): + absolute_logits = np.abs(self.logit_baseline - logit(batch_predictions)) + self.results.append(absolute_logits) + self.samples.append(batch_ids) + return absolute_logits + + def write_to_file(self): + self.results = np.vstack(self.results) + self.samples = np.vstack(self.samples) + _write_to_file(self.results, + self.samples, + self.column_names, + self.out_filename) diff --git a/selene/predict/predict_handlers/write_predictions_handler.py b/selene/predict/predict_handlers/write_predictions_handler.py new file mode 100644 index 00000000..f9aacb2e --- /dev/null +++ b/selene/predict/predict_handlers/write_predictions_handler.py @@ -0,0 +1,26 @@ +import numpy as np + +from .handler import _write_to_file, PredictionsHandler + +class WritePredictionsHandler(PredictionsHandler): + + def __init__(self, features_list, nonfeature_columns, out_filename): + self.column_names = nonfeature_columns + features_list + self.results = [] + self.samples = [] + self.out_filename = out_filename + + def handle_batch_predictions(self, + batch_predictions, + batch_ids): + self.results.append(batch_predictions) + self.samples.append(batch_ids) + return batch_predictions + + def write_to_file(self): + self.results = np.vstack(self.results) + self.samples = np.vstack(self.samples) + _write_to_file(self.results, + self.samples, + self.column_names, + self.out_filename) diff --git a/selene/predict/tests/test_model_predict.py b/selene/predict/tests/test_model_predict.py new file mode 100644 index 00000000..d3d35bec --- /dev/null +++ b/selene/predict/tests/test_model_predict.py @@ -0,0 +1,47 @@ +import unittest + +import numpy as np + +from selene.predict.model_predict import in_silico_mutagenesis_sequences, \ + in_silico_mutagenesis_predict + + +class TestModelPredict(unittest.TestCase): + + def setUp(self): + self.bases_arr = ['A', 'C', 'G', 'T'] + self.bases_encoding = {'A': 0, 'C': 1, 'G': 2, 'T': 3} + self.input_sequence = "ATCCG" + + def test_in_silico_muta_sequences_single(self): + observed = in_silico_mutagenesis_sequences("ATCCG") + expected = [ + (0, 'C'), (0, 'G'), (0, 'T'), + (1, 'A'), (1, 'C'), (1, 'G'), + (2, 'A'), (2, 'G'), (2, 'T'), + (3, 'A'), (3, 'G'), (3, 'T'), + (4, 'A'), (4, 'C'), (4, 'T')] + + expected_lists = [[e] for e in expected] + self.assertListEqual(observed, expected_lists) + + def test_in_silico_muta_sequences_double(self): + observed = in_silico_mutagenesis_sequences( + "ATC", mutate_n_bases=2) + expected = [ + [(0, 'C'), (1, 'A')], [(0, 'G'), (1, 'A')], [(0, 'T'), (1, 'A')], + [(0, 'C'), (1, 'C')], [(0, 'G'), (1, 'C')], [(0, 'T'), (1, 'C')], + [(0, 'C'), (1, 'G')], [(0, 'G'), (1, 'G')], [(0, 'T'), (1, 'G')], + + [(0, 'C'), (2, 'A')], [(0, 'G'), (2, 'A')], [(0, 'T'), (2, 'A')], + [(0, 
'C'), (2, 'G')], [(0, 'G'), (2, 'G')], [(0, 'T'), (2, 'G')],
+            [(0, 'C'), (2, 'T')], [(0, 'G'), (2, 'T')], [(0, 'T'), (2, 'T')],
+
+            [(1, 'A'), (2, 'A')], [(1, 'C'), (2, 'A')], [(1, 'G'), (2, 'A')],
+            [(1, 'A'), (2, 'G')], [(1, 'C'), (2, 'G')], [(1, 'G'), (2, 'G')],
+            [(1, 'A'), (2, 'T')], [(1, 'C'), (2, 'T')], [(1, 'G'), (2, 'T')],
+        ]
+        self.assertCountEqual(observed, expected)
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/selene/samplers/intervals_sampler.py b/selene/samplers/intervals_sampler.py
index d79979cb..e13f0560 100644
--- a/selene/samplers/intervals_sampler.py
+++ b/selene/samplers/intervals_sampler.py
@@ -1,7 +1,6 @@
 from collections import namedtuple
 import logging
 import random
-from time import time
 
 import numpy as np
 
@@ -77,7 +76,8 @@ def __init__(self,
                  sequence_length=1001,
                  center_bin_to_predict=201,
                  feature_thresholds=0.5,
-                 mode="train"):
+                 mode="train",
+                 save_datasets=["test"]):
         super(IntervalsSampler, self).__init__(
             genome,
             query_feature_data,
@@ -88,7 +88,8 @@ def __init__(self,
             sequence_length=sequence_length,
             center_bin_to_predict=center_bin_to_predict,
             feature_thresholds=feature_thresholds,
-            mode="train")
+            mode=mode,
+            save_datasets=save_datasets)
 
         self._sample_from_mode = {}
         self._randcache = {}
@@ -182,20 +183,43 @@ def _partition_dataset_chromosome(self, intervals_file):
             indices=indices, weights=weights)
 
     def _retrieve(self, chrom, position):
-        bin_start = position - self.bin_radius
-        bin_end = position + self.bin_radius + 1
+        bin_start = position - self._start_radius
+        bin_end = position + self._end_radius
         retrieved_targets = self.query_feature_data.get_feature_data(
             chrom, bin_start, bin_end)
         if np.sum(retrieved_targets) == 0:
+            logger.info("No features found in region surrounding "
+                        "chr{0} position {1}. Sampling again.".format(
+                            chrom, position))
             return None
 
         window_start = bin_start - self.surrounding_sequence_radius
         window_end = bin_end + self.surrounding_sequence_radius
         strand = self.STRAND_SIDES[random.randint(0, 1)]
-        retrieved_sequence = \
+        retrieved_seq = \
             self.genome.get_encoding_from_coords(
-                "chr{0}".format(chrom), window_start, window_end, strand)
-        return (retrieved_sequence, retrieved_targets)
+                f"chr{chrom}", window_start, window_end, strand)
+        if retrieved_seq.shape[0] == 0:
+            logger.info("Full sequence centered at chr{0} position {1} "
+                        "could not be retrieved. Sampling again.".format(
+                            chrom, position))
+            return None
+        elif np.sum(retrieved_seq) / float(retrieved_seq.shape[0]) < 0.60:
+            logger.info("Over 40% of the bases in the sequence centered "
+                        "at chr{0} position {1} are ambiguous ('N'). "
+                        "Sampling again.".format(chrom, position))
+            return None
+
+        if self.mode in self.save_datasets:
+            feature_indices = ';'.join(
+                [str(f) for f in np.nonzero(retrieved_targets)[0]])
+            self.save_datasets[self.mode].append(
+                [f"chr{chrom}",
+                 window_start,
+                 window_end,
+                 strand,
+                 feature_indices])
+        return (retrieved_seq, retrieved_targets)
 
     def _update_randcache(self, mode=None):
         if not mode:
@@ -230,21 +254,8 @@ def sample(self, batch_size=1):
 
             retrieve_output = self._retrieve(chrom, position)
             if not retrieve_output:
-                logger.info("No features found in region surrounding "
-                            "chr{0} position {1}. Sampling again.".format(
-                                chrom, position))
                 continue
 
             seq, seq_targets = retrieve_output
-            if seq.shape[0] == 0:
-                logger.info("Full sequence centered at chr{0} position {1} "
-                            "could not be retrieved. 
Sampling again.".format( - chrom, position)) - continue - elif np.sum(seq) / float(seq.shape[0]) < 0.60: - logger.info("Over 30% of the bases in the sequence centered " - "at chr{0} position {1} are ambiguous ('N'). " - "Sampling again.".format(chrom, position)) - continue sequences[n_samples_drawn, :, :] = seq targets[n_samples_drawn, :] = seq_targets n_samples_drawn += 1 @@ -263,7 +274,14 @@ def get_data_and_targets(self, mode, batch_size, n_samples): targets_mat = np.vstack(targets_mat) return sequences_and_targets, targets_mat - def get_validation_set(self, batch_size, n_samples=None): + def get_dataset_in_batches(self, mode, batch_size, n_samples=None): if not n_samples: - n_samples = len(self._sample_from_mode["validate"].indices) - return self.get_data_and_targets("validate", batch_size, n_samples) + n_samples = len(self._sample_from_mode[mode].indices) + return self.get_data_and_targets(mode, batch_size, n_samples) + + def get_validation_set(self, batch_size, n_samples=None): + return self.get_dataset_in_batches( + "validate", batch_size, n_samples=n_samples) + + def get_test_set(self, batch_size, n_samples=None): + return self.get_dataset_in_batches("test", batch_size, n_samples) diff --git a/selene/samplers/online_sampler.py b/selene/samplers/online_sampler.py index 6860e3de..9116cbf9 100644 --- a/selene/samplers/online_sampler.py +++ b/selene/samplers/online_sampler.py @@ -1,4 +1,5 @@ from abc import ABCMeta +import os from .sampler import Sampler from ..sequences import Genome @@ -19,17 +20,17 @@ def __init__(self, sequence_length=1001, center_bin_to_predict=201, feature_thresholds=0.5, - mode="train"): + mode="train", + save_datasets=["test"]): super(OnlineSampler, self).__init__( random_seed=random_seed ) - # @TODO: this could be more flexible. Sequence len and center bin - # len do not necessarily need to be odd numbers... - if sequence_length % 2 == 0 or center_bin_to_predict % 2 == 0: + + if (sequence_length + center_bin_to_predict) % 2 != 0: raise ValueError( - "Both the sequence length and the center bin length " - "should be odd numbers. Sequence length was {0} and " - "bin length was {1}.".format( + "Sequence length of {0} with a center bin length of {1} " + "is invalid. 
These 2 inputs should both be odd or both be "
+                "even.".format(
+                    sequence_length, center_bin_to_predict))
 
         surrounding_sequence_length = \
@@ -43,13 +44,8 @@ def __init__(self,
         # specifying a test holdout partition is optional
         if test_holdout:
             self.modes.append("test")
-            # @TODO: make sure that isinstance works in this
-            # situation
             if isinstance(validation_holdout, (list,)) and \
                     isinstance(test_holdout, (list,)):
-                #if type(validation_holdout) == type(list()) and \
-                #        type(test_holdout) == type(list()):
-                print("both are type list")
                 self.validation_holdout = [
                     str(c) for c in validation_holdout]
                 self.test_holdout = [str(c) for c in test_holdout]
@@ -68,8 +64,6 @@ def __init__(self,
         else:
             self.test_holdout = None
             if isinstance(validation_holdout, (list,)):
-                #if type(validation_holdout) == type(list()):
-                print("validation holdout is type list")
                 self.validation_holdout = [
                     str(c) for c in validation_holdout]
             else:
@@ -84,7 +78,12 @@ def __init__(self,
         self.surrounding_sequence_radius = int(
             surrounding_sequence_length / 2)
         self.sequence_length = sequence_length
-        self.bin_radius = int((center_bin_to_predict - 1) / 2)
+        self.bin_radius = int(center_bin_to_predict / 2)
+        self._start_radius = self.bin_radius
+        if center_bin_to_predict % 2 == 0:
+            self._end_radius = self.bin_radius
+        else:
+            self._end_radius = self.bin_radius + 1
 
         self.genome = Genome(genome)
 
@@ -98,6 +97,10 @@ def __init__(self,
             query_feature_data,
             self._features,
             feature_thresholds=feature_thresholds)
 
+        self.save_datasets = {}
+        for mode in save_datasets:
+            self.save_datasets[mode] = []
+
     def get_feature_from_index(self, feature_index):
         """Returns the feature corresponding to an index in the feature
         vector.
@@ -113,4 +116,19 @@ def get_feature_from_index(self, feature_index):
         return self.query_feature_data.index_feature_map[feature_index]
 
     def get_sequence_from_encoding(self, encoding):
+        """Gets the string sequence from its one-hot encoding.
+        """
         return self.genome.encoding_to_sequence(encoding)
+
+    def save_datasets_to_file(self, output_dir):
+        """This likely only works for validation and test right now.
+        Training data may be too big to store in a list in memory, so
+        it is a @TODO to be able to save training data coordinates
+        intermittently.
+        """
+        for mode, samples in self.save_datasets.items():
+            filepath = os.path.join(output_dir, f"{mode}_data.bed")
+            with open(filepath, 'w+') as file_handle:
+                for cols in samples:
+                    line = '\t'.join([str(c) for c in cols])
+                    file_handle.write(f"{line}\n")
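Reviewer note: a quick check of the new radius arithmetic above. Since the bin is taken as `[position - _start_radius, position + _end_radius)`, the two radii should always sum to `center_bin_to_predict` for either parity. The helper below merely mirrors the `OnlineSampler.__init__` logic for illustration.

```python
def radii(center_bin_to_predict):
    # mirrors the logic in OnlineSampler.__init__ above
    bin_radius = int(center_bin_to_predict / 2)
    start_radius = bin_radius
    if center_bin_to_predict % 2 == 0:
        end_radius = bin_radius
    else:
        end_radius = bin_radius + 1
    return start_radius, end_radius


assert radii(201) == (100, 101)  # odd bin: 100 + 101 = 201 positions
assert radii(200) == (100, 100)  # even bin: 100 + 100 = 200 positions
```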
+ """ + for mode, samples in self.save_datasets.items(): + filepath = os.path.join(output_dir, f"{mode}_data.bed") + with open(filepath, 'w+') as file_handle: + for cols in samples: + line ='\t'.join([str(c) for c in cols]) + file_handle.write(f"{line}\n") diff --git a/selene/sequences/__init__.py b/selene/sequences/__init__.py index e3c4c4b4..64f1a544 100644 --- a/selene/sequences/__init__.py +++ b/selene/sequences/__init__.py @@ -1,2 +1,4 @@ from .sequence import Sequence +from .sequence import sequence_to_encoding, encoding_to_sequence, \ + get_reverse_encoding from .genome import Genome diff --git a/selene/sequences/_genome.pyx b/selene/sequences/_sequence.pyx similarity index 86% rename from selene/sequences/_genome.pyx rename to selene/sequences/_sequence.pyx index af393568..c98fad73 100644 --- a/selene/sequences/_genome.pyx +++ b/selene/sequences/_sequence.pyx @@ -7,15 +7,13 @@ ctypedef np.float32_t FDTYPE_t @cython.boundscheck(False) # turn off bounds-checking for entire function @cython.wraparound(False) # turn off negative index wrapping for entire function -def _fast_sequence_to_encoding(str sequence, dict bases_encoding): +def _fast_sequence_to_encoding(str sequence, dict base_to_index): cdef int sequence_len = len(sequence) cdef np.ndarray[FDTYPE_t, ndim=2] encoding = np.zeros( (sequence_len, 4), dtype=np.float32) cdef int index cdef str base - sequence = str.upper(sequence) - for index in range(sequence_len): base = sequence[index] if base in bases_encoding: diff --git a/selene/sequences/genome.py b/selene/sequences/genome.py index 6dc36ec2..26a0747e 100644 --- a/selene/sequences/genome.py +++ b/selene/sequences/genome.py @@ -5,47 +5,9 @@ import numpy as np from pyfaidx import Fasta -from .sequence import Sequence -from ._genome import _fast_sequence_to_encoding +from .sequence import Sequence, sequence_to_encoding, encoding_to_sequence -def _sequence_to_encoding(sequence, bases_encoding): - """Converts an input sequence to its one hot encoding. - - Parameters - ---------- - sequence : str - The input sequence of length N. - bases_encoding : dict - each of ('A', 'C', 'G', 'T' or 'U') as keys -> index (0, 1, 2, 3), - specify the position to assign 1/0 when a given base exists/does not - exist at a given position in the sequence. - - Returns - ------- - numpy.ndarray, dtype=bool - The N-by-4 encoding of the sequence. 
- """ - return _fast_sequence_to_encoding(sequence, bases_encoding) - -def _get_base_index(encoding_row): - for index, val in enumerate(encoding_row): - if val == 0.25: - return -1 - elif val == 1: - return index - return -1 - -def _encoding_to_sequence(encoding, bases_arr): - sequence = [] - for row in encoding: - base_pos = _get_base_index(row) - if base_pos == -1: - sequence.append('N') - else: - sequence.append(bases_arr[base_pos]) - return "".join(sequence) - def _get_sequence_from_coords(len_chrs, genome_sequence, chrom, start, end, strand='+'): """Gets the genomic sequence given the chromosome, sequence start, @@ -85,8 +47,17 @@ def _get_sequence_from_coords(len_chrs, genome_sequence, class Genome(Sequence): BASES_ARR = np.array(['A', 'C', 'G', 'T']) - BASES_DICT = dict( - [(base, index) for index, base in enumerate(BASES_ARR)]) + INDEX_TO_BASE = { + 0: 'A', 1: 'C', 2: 'G', 3: 'T' + } + BASE_TO_INDEX = { + 'A': 0, 'C': 1, 'G': 2, 'T': 3, + 'a': 0, 'c': 1, 'g': 2, 't': 3, + } + COMPLEMENTARY_BASE = { + 'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'N': 'N', + 'a': 'T', 'c': 'G', 'g': 'C', 't': 'A', 'n': 'N' + } def __init__(self, fa_file): """Wrapper class around the pyfaix.Fasta class. @@ -229,7 +200,7 @@ def sequence_to_encoding(self, sequence): numpy.ndarray, dtype=float64 The N-by-4 encoding of the sequence. """ - return _sequence_to_encoding(sequence, self.BASES_DICT) + return sequence_to_encoding(sequence, self.BASE_TO_INDEX) def encoding_to_sequence(self, encoding): """Converts an input encoding to its DNA sequence. @@ -243,4 +214,4 @@ def encoding_to_sequence(self, encoding): ------- str """ - return _encoding_to_sequence(encoding, self.BASES_ARR) + return encoding_to_sequence(encoding, self.BASES_ARR) diff --git a/selene/sequences/sequence.py b/selene/sequences/sequence.py index bf3c0ff9..e3a3d228 100644 --- a/selene/sequences/sequence.py +++ b/selene/sequences/sequence.py @@ -4,6 +4,77 @@ from abc import ABCMeta from abc import abstractmethod +import numpy as np + +from ._sequence import _fast_sequence_to_encoding + +def sequence_to_encoding(sequence, base_to_index): + """Converts an input sequence to its one hot encoding. + + Parameters + ---------- + sequence : str + The input sequence of length N. + base_to_index : dict + each of ('A', 'C', 'G', 'T' or 'U') as keys -> index (0, 1, 2, 3), + specify the position to assign 1/0 when a given base exists/does not + exist at a given position in the sequence. + + Returns + ------- + np.ndarray, dtype=float32 + The N-by-4 encoding of the sequence. + """ + return _fast_sequence_to_encoding(sequence, base_to_index) + +def _get_base_index(encoding_row): + for index, val in enumerate(encoding_row): + if val == 0.25: + return -1 + elif val == 1: + return index + return -1 + +def encoding_to_sequence(encoding, bases_arr): + """Converts a sequence one hot encoding to its string + sequence. + + Parameters + ---------- + encoding : np.ndarray, dtype=float32 + bases_arr : list + each of ('A', 'C', 'G', 'T' or 'U') in the order that + corresponds to the correct columns for those bases in the encoding. 
+
+    Returns
+    -------
+    str
+    """
+    sequence = []
+    for row in encoding:
+        base_pos = _get_base_index(row)
+        if base_pos == -1:
+            sequence.append('N')
+        else:
+            sequence.append(bases_arr[base_pos])
+    return "".join(sequence)
+
+def get_reverse_encoding(encoding,
+                         bases_arr,
+                         base_to_index,
+                         complementary_base):
+    reverse_encoding = np.zeros(encoding.shape)
+    for index, row in enumerate(encoding):
+        base_pos = _get_base_index(row)
+        rev_index = encoding.shape[0] - index - 1
+        if base_pos == -1:
+            reverse_encoding[rev_index, :] = 0.25
+        else:
+            base = complementary_base[bases_arr[base_pos]]
+            complem_base_pos = base_to_index[base]
+            reverse_encoding[rev_index, complem_base_pos] = 1
+    return reverse_encoding
+
 
 class Sequence(metaclass=ABCMeta):
     """
diff --git a/selene/sequences/tests/test_genome.py b/selene/sequences/tests/test_genome.py
index a360c271..44d5f0db 100644
--- a/selene/sequences/tests/test_genome.py
+++ b/selene/sequences/tests/test_genome.py
@@ -2,8 +2,9 @@
 
 import numpy as np
 
-from selene.sequences.genome import _sequence_to_encoding, \
-    _encoding_to_sequence, _get_sequence_from_coords
+from selene.sequences.genome import _get_sequence_from_coords
+from selene.sequences.sequence import sequence_to_encoding, \
+    encoding_to_sequence
 
 
 class TestGenome(unittest.TestCase):
@@ -33,9 +34,9 @@ def _genome_sequence(self, chrom, start, end, strand):
             sequence = base_sequence_neg * repeat_base_seq
         return sequence[start:end]
 
-    def test__sequence_to_encoding(self):
+    def test_sequence_to_encoding(self):
         sequence = "ctgCGCAA"
-        observed = _sequence_to_encoding(sequence, self.bases_encoding)
+        observed = sequence_to_encoding(sequence, self.bases_encoding)
         expected = np.array([
             [0., 1., 0., 0.], [0., 0., 0., 1.],  # ct
             [0., 0., 1., 0.], [0., 1., 0., 0.],  # gC
@@ -44,9 +45,9 @@ def test__sequence_to_encoding(self):
         ])
         self.assertSequenceEqual(observed.tolist(), expected.tolist())
 
-    def test__sequence_to_encoding_unknown_bases(self):
+    def test_sequence_to_encoding_unknown_bases(self):
         sequence = "AnnUAtCa"
-        observed = _sequence_to_encoding(sequence, self.bases_encoding)
+        observed = sequence_to_encoding(sequence, self.bases_encoding)
         expected = np.array([
             [1., 0., 0., 0.], [.25, .25, .25, .25],  # An
             [.25, .25, .25, .25], [.25, .25, .25, .25],  # nU
@@ -55,21 +56,21 @@ def test__sequence_to_encoding_unknown_bases(self):
         ])
         self.assertSequenceEqual(observed.tolist(), expected.tolist())
 
-    def test__encoding_to_sequence(self):
+    def test_encoding_to_sequence(self):
         encoding = np.array([
             [1., 0., 0., 0.], [1., 0., 0., 0.],
             [0., 0., 0., 1.], [0., 0., 1., 0.],
             [0., 1., 0., 0.], [0., 0., 0., 1]])
-        observed = _encoding_to_sequence(encoding, self.bases_arr)
+        observed = encoding_to_sequence(encoding, self.bases_arr)
         expected = "AATGCT"
         self.assertEqual(observed, expected)
 
-    def test__encoding_to_sequence_unknown_bases(self):
+    def test_encoding_to_sequence_unknown_bases(self):
         encoding = np.array([
             [0., 0., 1., 0.], [0.25, 0.25, 0.25, 0.25],
             [1., 0., 0., 0.], [0., 0., 0., 1.],
             [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]])
-        observed = _encoding_to_sequence(encoding, self.bases_arr)
+        observed = encoding_to_sequence(encoding, self.bases_arr)
         expected = "GNATNN"
         self.assertEqual(observed, expected)
diff --git a/selene/targets/genomic_features.py b/selene/targets/genomic_features.py
index db7f76c2..cd448e12 100644
--- a/selene/targets/genomic_features.py
+++ b/selene/targets/genomic_features.py
@@ -48,6 +48,31 @@ def _get_feature_data(query_chrom, query_start, query_end,
     return _fast_get_feature_data(
query_start, query_end, thresholds, feature_index_map, rows) +def _define_feature_thresholds(feature_thresholds, features): + feature_thresholds_dict = {} + feature_thresholds_vec = np.zeros(len(features)) + if isinstance(feature_thresholds, float): + feature_thresholds_dict = dict.fromkeys(features, feature_thresholds) + feature_thresholds_vec += feature_thresholds + elif isinstance(feature_thresholds, dict): + # assign the default value to everything first + feature_thresholds_dict = dict.fromkeys( + features, feature_thresholds["default"]) + feature_thresholds_vec += feature_thresholds["default"] + for i, f in enumerate(features): + if f in feature_thresholds: + feature_thresholds_dict[f] = feature_thresholds[f] + feature_thresholds_vec[i] = feature_thresholds[f] + # this branch will not be accessed if you use a config.yml file to + # specify input parameters + elif isinstance(feature_thresholds, types.FunctionType): + for i, f in enumerate(features): + threshold = feature_thresholds(f) + feature_thresholds_dict[f] = threshold + feature_thresholds_vec[i] = threshold + feature_thresholds_vec = feature_thresholds_vec.astype(np.float32) + return feature_thresholds_dict, feature_thresholds_vec + class GenomicFeatures(Target): def __init__(self, dataset, features, feature_thresholds): @@ -101,25 +126,8 @@ def __init__(self, dataset, features, feature_thresholds): self.index_feature_map = dict(list(enumerate(features))) - self.feature_thresholds = {} - self._feature_thresholds_vec = np.zeros(self.n_features) - if isinstance(feature_thresholds, float): - for i, f in enumerate(features): - self.feature_thresholds[f] = feature_thresholds - self._feature_thresholds_vec[i] = feature_thresholds - elif isinstance(feature_thresholds, dict): - for i, f in enumerate(features): - if f in feature_thresholds: - self.feature_thresholds[f] = feature_thresholds[f] - self._feature_thresholds_vec[i] = feature_thresholds[f] - else: - self.feature_thresholds[f] = feature_thresholds["default"] - self._feature_thresholds_vec[i] = feature_thresholds["default"] - elif isinstance(feature_thresholds, types.FunctionType): - for i, f in enumerate(features): - self.feature_thresholds[f] = feature_thresholds(f) - self._feature_thresholds_vec[i] = feature_thresholds(f) - self._feature_thresholds_vec = self._feature_thresholds_vec.astype(np.float32) + self.feature_thresholds, self._feature_thresholds_vec = \ + _define_feature_thresholds(feature_thresholds, features) def _query_tabix(self, chrom, start, end): try: diff --git a/selene/utils/__init__.py b/selene/utils/__init__.py index 30d01c5e..8acaf7a0 100644 --- a/selene/utils/__init__.py +++ b/selene/utils/__init__.py @@ -1,3 +1,4 @@ -from .utils import initialize_logger, read_yaml_file +from .utils import initialize_logger +from .performance_metrics import PerformanceMetrics from .config import load, load_path, instantiate diff --git a/selene/utils/performance_metrics.py b/selene/utils/performance_metrics.py new file mode 100644 index 00000000..001395ee --- /dev/null +++ b/selene/utils/performance_metrics.py @@ -0,0 +1,87 @@ +from collections import defaultdict, namedtuple + +import numpy as np +from sklearn.metrics import average_precision_score, roc_auc_score + + +Metric = namedtuple("Metric", ["fn", "data"]) + + +def compute_score(targets, predictions, + compute_score_fn, + report_gt_feature_n_positives=10): + feature_scores = np.ones(targets.shape[1]) * -1 + for index, feature_preds in enumerate(predictions.T): + feature_targets = targets[:, index] + if 
len(np.unique(feature_targets)) > 1 and \
+                np.sum(feature_targets) > report_gt_feature_n_positives:
+            feature_scores[index] = compute_score_fn(
+                feature_targets, feature_preds)
+
+    valid_feature_scores = [s for s in feature_scores if s >= 0]
+    average_score = np.average(valid_feature_scores)
+    return average_score, feature_scores
+
+
+def get_feature_specific_scores(data, get_feature_from_ix_fn):
+    feature_score_dict = {}
+    for index, score in enumerate(data):
+        feature = get_feature_from_ix_fn(index)
+        if score >= 0:
+            feature_score_dict[feature] = score
+        else:
+            feature_score_dict[feature] = None
+    return feature_score_dict
+
+
+class PerformanceMetrics(object):
+    """Report metrics in addition to loss
+    """
+
+    def __init__(self,
+                 get_feature_from_ix_fn,
+                 report_gt_feature_n_positives=10):
+        self.skip_threshold = report_gt_feature_n_positives
+        self.feature_from_ix = get_feature_from_ix_fn
+        self.metrics = {
+            "roc_auc": Metric(fn=roc_auc_score, data=[]),
+            "average_precision": Metric(fn=average_precision_score, data=[])
+        }
+
+    def add_metric(self, name, metric_fn):
+        self.metrics[name] = Metric(fn=metric_fn, data=[])
+
+    def remove_metric(self, name):
+        data = self.metrics[name].data
+        del self.metrics[name]
+        return data
+
+    def update(self, targets, predictions):
+        metric_scores = {}
+        for name, metric in self.metrics.items():
+            avg_score, feature_scores = compute_score(
+                targets, predictions, metric.fn,
+                report_gt_feature_n_positives=self.skip_threshold)
+            metric.data.append(feature_scores)
+            metric_scores[name] = avg_score
+        return metric_scores
+
+    def write_feature_scores_to_file(self, output_file):
+        feature_scores = defaultdict(dict)
+        for name, metric in self.metrics.items():
+            feature_score_dict = get_feature_specific_scores(
+                metric.data[-1], self.feature_from_ix)
+            for feature, score in feature_score_dict.items():
+                feature_scores[feature][name] = score
+
+        metric_cols = [m for m in self.metrics.keys()]
+        cols = '\t'.join(["features"] + metric_cols)
+        with open(output_file, 'w+') as file_handle:
+            file_handle.write(f"{cols}\n")
+            for feature, metric_scores in sorted(feature_scores.items()):
+                metric_score_cols = '\t'.join(
+                    ["NA" if s is None else f"{s:.4f}"
+                     for s in metric_scores.values()])
+                file_handle.write(f"{feature}\t{metric_score_cols}\n")
+
+        return feature_scores
+
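Reviewer note: a sketch of the intended `PerformanceMetrics` flow. The arrays below are random placeholders, not real data, and the output file name is made up.

```python
import numpy as np
from selene.utils import PerformanceMetrics

metrics = PerformanceMetrics(
    lambda index: f"feature_{index}",
    report_gt_feature_n_positives=10)

targets = np.random.randint(0, 2, size=(256, 3))  # placeholder labels
predictions = np.random.uniform(size=(256, 3))    # placeholder scores
averages = metrics.update(targets, predictions)
# e.g. {"roc_auc": ~0.5, "average_precision": ~0.5} on random inputs

# per-feature scores from the latest update(), as a tab-delimited table;
# features skipped for having too few positives are written as "NA"
metrics.write_feature_scores_to_file("example_performance.txt")
```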
+ """ + variants = [] + with open(vcf_file, 'r') as file_handle: + for line in file_handle: + if "#CHROM" in line: + cols = line.strip().split('\t') + if cols[:5] != VCF_REQUIRED_COLS: + raise ValueError( + "First 5 columns in file {0} were {1}. " + "Expected columns: {2}".format( + vcf_file, cols[:5], VCF_REQUIRED_COLS)) + break + + for line in file_handle: + cols = line.strip().split('\t') + chrom = str(cols[0]) + pos = int(cols[1]) + ref = cols[3] + alt = cols[4] + variants.append((chrom, pos, ref, alt)) + return variants -def read_yaml_file(config_file): - with open(config_file, "r") as config_file: - try: - config_dict = yaml.load(config_file) - return config_dict - except yaml.YAMLError as exception: - sys.exit(exception) diff --git a/setup.py b/setup.py index f8dc8f88..e7309387 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,8 @@ from Cython.Build import cythonize genome_module = Extension( - "selene.sequences._genome", - ["selene/sequences/_genome.pyx"], + "selene.sequences._sequence", + ["selene/sequences/_sequence.pyx"], include_dirs=[np.get_include()]) genomic_features_module = Extension(