Incorporate an evaluate function into model controller & a number of other refactoring steps towards a training and evaluation pipeline. #25
Changes from all commits
Diff: example training configuration (YAML):

@@ -1,25 +1,40 @@
 ---
+model: {
+    non_strand_specific: {
+        use_module: True,
+        mode: mean
+    },
+    import_model_from: models.deepsea,
+    class: DeepSEA
+}
 sampler: !obj:selene.samplers.IntervalsSampler {
+    genome: /scratch/data_hg/male.hg19.fasta,
+    genomic_features: /scratch/data_hg/sorted_sv_aggregate.bed.gz,
+    distinct_features: /scratch/data_hg/distinct_features.txt,
+    sample_from_regions: /scratch/data_hg/TFs_coords_only.txt,
     test_holdout: [8, 9],
     validation_holdout: [6, 7],
     random_seed: 127,
     sequence_length: 1001,
     center_bin_to_predict: 201,
-    default_threshold: 0.5,
+    feature_thresholds: 0.5,
     mode: "train"
 }
 model_controller: !obj:selene.ModelController {
     batch_size: 64,
     max_steps: 500000,
-    report_metrics_every_n_steps: 16000,
-    n_validation_samples: 3200,
+    report_stats_every_n_steps: 16000,
+    n_validation_samples: 32000,
     optional_args: {
         cpu_n_threads: 32,
         use_cuda: True,
         data_parallel: False
     },
-    logging_verbosity: 2
-}
+    checkpoint: {
+        resume: False
+    }
+}
+output_dir: /tigress/TROYANSKAYA/kathy/example_outputs
+evaluate_on_test: True
 ...
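The !obj: tags mark entries that selene's configuration loader turns into deferred constructors rather than plain dictionaries, which is what lets selene.py (diffed below) inspect and re-bind their keyword arguments before anything is built. A minimal sketch of that consumption path, assuming only the load_path/instantiate behavior visible in the selene.py diff below; the file name example_train.yml is hypothetical:

    from selene.utils import load_path, instantiate

    # Parse the YAML without constructing objects; each `!obj:` entry stays
    # a deferred constructor whose keyword arguments can still be changed.
    configs = load_path("example_train.yml", instantiate=False)

    sampler_info = configs["sampler"]    # deferred IntervalsSampler
    sampler = instantiate(sampler_info)  # actually construct it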
This file was deleted.
Diff: models/non_strand_specific_module.py (new file):

@@ -0,0 +1,44 @@
import torch
from torch.nn.modules import Module


def flip(x, dim):
    """Reverses the elements in a given dimension `dim` of the Tensor.

    source: https://github.com/pytorch/pytorch/issues/229
    """
    xsize = x.size()
    dim = x.dim() + dim if dim < 0 else dim
    x = x.contiguous()
    x = x.view(-1, *xsize[dim:])
    x = x.view(
        x.size(0), x.size(1), -1)[:, getattr(
            torch.arange(x.size(1) - 1, -1, -1),
            ('cpu', 'cuda')[x.is_cuda])().long(), :]
    return x.view(xsize)


class NonStrandSpecific(Module):
    def __init__(self, model, mode="mean"):
        super(NonStrandSpecific, self).__init__()

        self.model = model

        if mode != "mean" and mode != "max":
            raise ValueError(f"Mode should be one of 'mean' or 'max' but was "
                             f"{mode!r}.")
        self.mode = mode

    def forward(self, input):
        reverse_input = flip(flip(input, 1), 2)

        output = self.model.forward(input)
        output_from_rev = self.model.forward(reverse_input)
        if self.mode == "mean":
            return (output + output_from_rev) / 2
        else:
            return torch.max(output, output_from_rev)

Review comment on reverse_input = flip(flip(input, 1), 2): This is based on the assumption that the sequence is encoded in such a way that we can just "flip" the indices in the matrix and get the reverse sequence. I'll need to document that encoding.
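The review comment above is worth pinning down, because forward derives the reverse strand purely by flipping tensor indices. A small sanity check of that assumption, written with the modern torch.flip (equivalent to flip(flip(x, 1), 2) here) and assuming an (A, C, G, T) channel order in which complementary bases sit at mirrored channel indices; that ordering is an assumption, not something this PR states:

    import torch

    # One-hot "ACT" as (batch, position, channel) with channel order A, C, G, T.
    seq = torch.tensor([[[1., 0., 0., 0.],   # A
                         [0., 1., 0., 0.],   # C
                         [0., 0., 0., 1.]]]) # T
    seq = seq.transpose(1, 2)  # -> (batch, channel, position), as conv models expect

    # Flipping channels maps A<->T and C<->G (complement); flipping positions
    # reverses the sequence. Together: the reverse complement, "AGT".
    rev_comp = torch.flip(seq, dims=(1, 2))
    print(rev_comp.transpose(1, 2))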
Diff: selene.py (module docstring):

@@ -6,126 +6,61 @@
     Saves model to a user-specified output file.

 Usage:
-    selene.py <import-module> <model-class-name> <lr>
-        <paths-yml> <train-model-yml>
-        [-s | --stdout] [--verbosity=<level>]
+    selene.py <lr> <config-yml>
     selene.py -h | --help

 Options:
     -h --help               Show this screen.

-    <import-module>         Import the module containing the model
-    <model-class-name>      Must be a model class in the imported module
     <lr>                    Choose the optimizer's learning rate
-    <paths-yml>             Input data and output filepaths
-    <train-model-yml>       Model-specific parameters
-    -s --stdout             Will also output logging information to stdout
-                            [default: False]
-    --verbosity=<level>     Logging verbosity level (0=WARN, 1=INFO, 2=DEBUG)
-                            [default: 1]
+    <config-yml>            Model-specific parameters
 """

Review comment on <config-yml>: Improve the documentation here.
Diff: selene.py (continued; imports and run setup):

 import importlib
-import logging
-import os
-from time import strftime, time

 from docopt import docopt
 import torch

-from selene.model_train import ModelController
-from selene.samplers import IntervalsSampler
-from selene.utils import initialize_logger, read_yaml_file
-from selene.utils import load, load_path, instantiate
+from selene.utils import load_path, instantiate

 if __name__ == "__main__":
     arguments = docopt(
         __doc__,
         version="1.0")
-    import_model_from = arguments["<import-module>"]
-    model_class_name = arguments["<model-class-name>"]
-    use_module = importlib.import_module(import_model_from)
-    model_class = getattr(use_module, model_class_name)
-
     lr = float(arguments["<lr>"])

-    paths = read_yaml_file(
-        arguments["<paths-yml>"])
-
-    train_model = load_path(arguments["<train-model-yml>"], instantiate=False)
-
-    ##################################################
-    # PATHS
-    ##################################################
-    dir_path = paths["dir_path"]
-    files = paths["files"]
-    genome_fasta = os.path.join(
-        dir_path, files["genome"])
-    genomic_features = os.path.join(
-        dir_path, files["genomic_features"])
-    coords_only = os.path.join(
-        dir_path, files["sample_from_regions"])
-    distinct_features = os.path.join(
-        dir_path, files["distinct_features"])
-
-    output_dir = paths["output_dir"]
-    os.makedirs(output_dir, exist_ok=True)
-
-    current_run_output_dir = os.path.join(
-        output_dir, strftime("%Y-%m-%d-%H-%M-%S"))
-    os.makedirs(current_run_output_dir)
+    configs = load_path(arguments["<config-yml>"], instantiate=False)

     ##################################################
     # TRAIN MODEL PARAMETERS
     ##################################################
-    sampler_info = train_model["sampler"]
-    model_controller_info = train_model["model_controller"]
+    model_info = configs["model"]
+    sampler_info = configs["sampler"]
+    model_controller_info = configs["model_controller"]

-    ##################################################
-    # OTHER ARGS
-    ##################################################
-    to_stdout = arguments["--stdout"]
-    verbosity_level = int(arguments["--verbosity"])
-
-    initialize_logger(
-        os.path.join(current_run_output_dir, "{0}.log".format(__name__)),
-        verbosity=verbosity_level,
-        stdout_handler=to_stdout)
-    logger = logging.getLogger("selene")
-
-    t_i = time()
-    feature_thresholds = None
-    if "specific_feature_thresholds" in sampler_info.keywords:
-        feature_thresholds = sampler_info.pop("specific_feature_thresholds")
-    else:
-        feature_thresholds = None
-    if "default_threshold" in sampler_info.keywords:
-        if feature_thresholds:
-            feature_thresholds["default"] = sampler_info.pop("default_threshold")
-        else:
-            feature_thresholds = sampler_info.pop("default_threshold")
-
-    if feature_thresholds:
-        sampler_info.bind(feature_thresholds=feature_thresholds)
-    sampler_info.bind(genome=genome_fasta,
-                      query_feature_data=genomic_features,
-                      distinct_features=distinct_features,
-                      intervals_file=coords_only)
     sampler = instantiate(sampler_info)

-    t_i_model = time()
     torch.manual_seed(1337)
     torch.cuda.manual_seed_all(1337)

+    import_model_from = model_info["import_model_from"]
+    model_class_name = model_info["class"]
+    use_module = importlib.import_module(import_model_from)
+    model_class = getattr(use_module, model_class_name)
+
     model = model_class(sampler.sequence_length, sampler.n_features)
     print(model)

+    if model_info["non_strand_specific"]["use_module"]:
+        from models.non_strand_specific_module import NonStrandSpecific
+        model = NonStrandSpecific(
+            model, mode=model_info["non_strand_specific"]["mode"])
+
     checkpoint_info = model_controller_info.pop("checkpoint")
     checkpoint_resume = checkpoint_info.pop("resume")
     checkpoint = None
     if checkpoint_resume:
         checkpoint_file = checkpoint_info.pop("model_file")
-        logger.info("Resuming training from checkpoint {0}.".format(
+        print("Resuming training from checkpoint {0}.".format(
             checkpoint_file))
         checkpoint = torch.load(checkpoint_file)
         model.load_state_dict(checkpoint["state_dict"])
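The resume branch above reads exactly one field, "state_dict", from the loaded checkpoint (plus a model_file path supplied under the config's checkpoint block). A minimal sketch of writing a file this branch could consume; any additional keys selene's ModelController actually saves are not shown in this diff, and the output path is hypothetical:

    import torch

    # Enough for `model.load_state_dict(checkpoint["state_dict"])` above.
    torch.save({"state_dict": model.state_dict()}, "example_checkpoint.pth.tar")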
Diff: selene.py (continued; training and evaluation):

@@ -135,45 +70,26 @@
     criterion = use_module.criterion()
     optimizer_class, optimizer_args = use_module.get_optimizer(lr)

-    t_f_model = time()
-    logger.debug(
-        "Finished initializing the {0} model from module {1}: {2} s".format(
-            model.__class__.__name__,
-            import_model_from,
-            t_f_model - t_i_model))
-
-    logger.info(model)
-    logger.info(optimizer_args)
-
-    if feature_thresholds:
-        sampler_info.bind(feature_thresholds=feature_thresholds)
-    sampler_info.bind(genome=genome_fasta,
-                      query_feature_data=genomic_features,
-                      distinct_features=distinct_features,
-                      intervals_file=coords_only)
-    sampler = instantiate(sampler_info)
-
     batch_size = model_controller_info.keywords["batch_size"]  # Would love to find a better way.
     max_steps = model_controller_info.keywords["max_steps"]
-    report_metrics_every_n_steps = \
-        model_controller_info.keywords["report_metrics_every_n_steps"]
+    report_stats_every_n_steps = \
+        model_controller_info.keywords["report_stats_every_n_steps"]
     n_validation_samples = model_controller_info.keywords["n_validation_samples"]

     model_controller_info.bind(model=model,
                                data_sampler=sampler,
                                loss_criterion=criterion,
                                optimizer_class=optimizer_class,
-                               optimizer_args=optimizer_args,
-                               output_dir=current_run_output_dir)
+                               optimizer_args=optimizer_args)
     if "optional_args" in model_controller_info.keywords:
         optional_args = model_controller_info.pop("optional_args")
         model_controller_info.bind(**optional_args)
     runner = instantiate(model_controller_info)

-    logger.info("Training model: {0} steps, {1} batch size.".format(
+    print("Training model: {0} steps, {1} batch size.".format(
         max_steps, batch_size))
     runner.train_and_validate()

-    t_f = time()
-    logger.info("./train_model.py completed in {0} s.".format(t_f - t_i))
+    if configs["evaluate_on_test"]:
+        runner.evaluate()
+    if configs["save_datasets"]:
+        runner.write_datasets_to_file()
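Note that both end-of-run flags are read unconditionally from the top level of the config, but the example configuration earlier in this diff only sets evaluate_on_test; as written, a config without a save_datasets key would raise a KeyError on the last line. Presumably the full example config also carries something like (value hypothetical):

    save_datasets: True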
Diff: selene/__init__.py:

@@ -1,2 +1,2 @@
-__all__ = ["sequences", "targets", "samplers", "utils"]
+__all__ = ["predict", "sequences", "targets", "samplers", "utils"]
 from .model_train import ModelController
This file was deleted.
Review comment: check that the proper handling for this input is in IntervalsSampler (or OnlineSampler).