Skip to content

Commit 671269e

Browse files
authored
Merge pull request #25 from FunctionLab/train-and-eval
Incorporate an evaluate function into model controller & a number of other refactoring steps towards a training and evaluation pipeline.
2 parents 1c6e856 + d738481 commit 671269e

27 files changed

+935
-434
lines changed

config_examples/parameters.yml

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,40 @@
11
---
2+
model: {
3+
non_strand_specific: {
4+
use_module: True,
5+
mode: mean
6+
},
7+
import_model_from: models.deepsea
8+
class: DeepSEA
9+
}
210
sampler: !obj:selene.samplers.IntervalsSampler {
11+
genome: /scratch/data_hg/male.hg19.fasta,
12+
genomic_features: /scratch/data_hg/sorted_sv_aggregate.bed.gz,
13+
distinct_features: /scratch/data_hg/distinct_features.txt,
14+
sample_from_regions: /scratch/data_hg/TFs_coords_only.txt,
315
test_holdout: [8, 9],
416
validation_holdout: [6, 7],
517
random_seed: 127,
618
sequence_length: 1001,
719
center_bin_to_predict: 201,
8-
default_threshold: 0.5,
20+
feature_thresholds: 0.5,
921
mode: "train"
10-
}
22+
}
1123
model_controller: !obj:selene.ModelController {
1224
batch_size: 64,
1325
max_steps: 500000,
14-
report_metrics_every_n_steps: 16000,
15-
n_validation_samples: 3200,
26+
report_stats_every_n_steps: 16000,
27+
n_validation_samples: 32000,
1628
optional_args: {
1729
cpu_n_threads: 32,
1830
use_cuda: True,
1931
data_parallel: False
20-
},
32+
logging_verbosity: 2
33+
},
2134
checkpoint: {
2235
resume: False
23-
}
24-
}
36+
},
37+
output_dir: /tigress/TROYANSKAYA/kathy/example_outputs
38+
}
39+
evaluate_on_test: True
2540
...

config_examples/paths.yml

Lines changed: 0 additions & 9 deletions
This file was deleted.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
from torch.nn.modules import Module
3+
4+
5+
def flip(x, dim):
    """Reverse the order of the elements along dimension `dim` of a tensor.

    Adapted from https://github.com/pytorch/pytorch/issues/229

    Parameters
    ----------
    x : torch.Tensor
        The tensor to reverse. Empty and non-contiguous tensors are accepted.
    dim : int
        The dimension to reverse. A negative value counts back from the
        last dimension.

    Returns
    -------
    torch.Tensor
        A new tensor with the same shape as `x` whose entries along `dim`
        appear in reverse order.
    """
    dim = x.dim() + dim if dim < 0 else dim
    # Descending positions n-1, ..., 0 along the dimension being reversed.
    index = torch.arange(x.size(dim) - 1, -1, -1).long()
    if x.is_cuda:
        # Keep the index on the same GPU as `x`; a bare `.cuda()` would use
        # the current default device, which may be a different one.
        index = index.cuda(x.get_device())
    # Full slices for every dimension before `dim`, then the reversed index.
    # Indexing directly (instead of reshaping with `view(..., -1)`) also
    # works when `x` has zero elements.
    return x[(slice(None),) * dim + (index,)]
19+
20+
21+
class NonStrandSpecific(Module):
    """Wrap a model so its output combines both orientations of the input.

    The wrapped model is evaluated on the input and on a copy of the input
    reversed along dimensions 1 and 2; the two outputs are then combined
    element-wise.

    Parameters
    ----------
    model : torch.nn.Module
        The model to wrap.
    mode : {"mean", "max"}, optional
        How to combine the two outputs. Default is "mean".

    Raises
    ------
    ValueError
        If `mode` is neither "mean" nor "max".
    """

    def __init__(self, model, mode="mean"):
        super(NonStrandSpecific, self).__init__()

        if mode not in ("mean", "max"):
            # Single f-string so that the offending value is actually
            # interpolated into the message.
            raise ValueError(
                f"Mode should be one of 'mean' or 'max' but was {mode!r}.")

        self.model = model
        self.mode = mode

    def forward(self, input):
        """Return the combined prediction for `input` and its reversed copy.

        Parameters
        ----------
        input : torch.Tensor
            Batch passed to the wrapped model. Must have at least 3
            dimensions, since dimensions 1 and 2 are reversed.
            NOTE(review): presumably dimensions 1 and 2 are the sequence
            and base axes, so reversing both yields the reverse
            complement; confirm against the sampler's encoding.

        Returns
        -------
        torch.Tensor
            The element-wise mean or maximum of the two model outputs,
            depending on `self.mode`.
        """
        reverse_input = flip(flip(input, 1), 2)

        # Call the module rather than `.forward` so registered hooks run.
        output = self.model(input)
        output_from_rev = self.model(reverse_input)

        if self.mode == "mean":
            return (output + output_from_rev) / 2
        return torch.max(output, output_from_rev)
44+

selene.py

Lines changed: 26 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -6,126 +6,61 @@
66
Saves model to a user-specified output file.
77
88
Usage:
9-
selene.py <import-module> <model-class-name> <lr>
10-
<paths-yml> <train-model-yml>
11-
[-s | --stdout] [--verbosity=<level>]
9+
selene.py <lr> <config-yml>
1210
selene.py -h | --help
1311
1412
Options:
1513
-h --help Show this screen.
1614
17-
<import-module> Import the module containing the model
18-
<model-class-name> Must be a model class in the imported module
1915
<lr> Choose the optimizer's learning rate
20-
<paths-yml> Input data and output filepaths
21-
<train-model-yml> Model-specific parameters
22-
-s --stdout Will also output logging information to stdout
23-
[default: False]
24-
--verbosity=<level> Logging verbosity level (0=WARN, 1=INFO, 2=DEBUG)
25-
[default: 1]
16+
<config-yml> Model-specific parameters
2617
"""
2718
import importlib
28-
import logging
29-
import os
30-
from time import strftime, time
3119

3220
from docopt import docopt
3321
import torch
3422

35-
from selene.model_train import ModelController
36-
from selene.samplers import IntervalsSampler
37-
from selene.utils import initialize_logger, read_yaml_file
38-
from selene.utils import load, load_path, instantiate
23+
from selene.utils import load_path, instantiate
3924

4025
if __name__ == "__main__":
4126
arguments = docopt(
4227
__doc__,
4328
version="1.0")
44-
import_model_from = arguments["<import-module>"]
45-
model_class_name = arguments["<model-class-name>"]
46-
use_module = importlib.import_module(import_model_from)
47-
model_class = getattr(use_module, model_class_name)
48-
4929
lr = float(arguments["<lr>"])
5030

51-
paths = read_yaml_file(
52-
arguments["<paths-yml>"])
53-
54-
train_model = load_path(arguments["<train-model-yml>"], instantiate=False)
55-
56-
57-
##################################################
58-
# PATHS
59-
##################################################
60-
dir_path = paths["dir_path"]
61-
files = paths["files"]
62-
genome_fasta = os.path.join(
63-
dir_path, files["genome"])
64-
genomic_features = os.path.join(
65-
dir_path, files["genomic_features"])
66-
coords_only = os.path.join(
67-
dir_path, files["sample_from_regions"])
68-
distinct_features = os.path.join(
69-
dir_path, files["distinct_features"])
70-
71-
output_dir = paths["output_dir"]
72-
os.makedirs(output_dir, exist_ok=True)
73-
74-
current_run_output_dir = os.path.join(
75-
output_dir, strftime("%Y-%m-%d-%H-%M-%S"))
76-
os.makedirs(current_run_output_dir)
31+
configs = load_path(arguments["<config-yml>"], instantiate=False)
7732

7833
##################################################
7934
# TRAIN MODEL PARAMETERS
8035
##################################################
81-
sampler_info = train_model["sampler"]
82-
model_controller_info = train_model["model_controller"]
36+
model_info = configs["model"]
37+
sampler_info = configs["sampler"]
38+
model_controller_info = configs["model_controller"]
8339

84-
##################################################
85-
# OTHER ARGS
86-
##################################################
87-
to_stdout = arguments["--stdout"]
88-
verbosity_level = int(arguments["--verbosity"])
89-
90-
initialize_logger(
91-
os.path.join(current_run_output_dir, "{0}.log".format(__name__)),
92-
verbosity=verbosity_level,
93-
stdout_handler=to_stdout)
94-
logger = logging.getLogger("selene")
95-
96-
t_i = time()
97-
feature_thresholds = None
98-
if "specific_feature_thresholds" in sampler_info.keywords:
99-
feature_thresholds = sampler_info.pop("specific_feature_thresholds")
100-
else:
101-
feature_thresholds = None
102-
if "default_threshold" in sampler_info.keywords:
103-
if feature_thresholds:
104-
feature_thresholds["default"] = sampler_info.pop("default_threshold")
105-
else:
106-
feature_thresholds = sampler_info.pop("default_threshold")
107-
108-
if feature_thresholds:
109-
sampler_info.bind(feature_thresholds=feature_thresholds)
110-
sampler_info.bind(genome=genome_fasta,
111-
query_feature_data=genomic_features,
112-
distinct_features=distinct_features,
113-
intervals_file=coords_only)
11440
sampler = instantiate(sampler_info)
11541

116-
t_i_model = time()
11742
torch.manual_seed(1337)
11843
torch.cuda.manual_seed_all(1337)
11944

45+
import_model_from = model_info["import_model_from"]
46+
model_class_name = model_info["class"]
47+
use_module = importlib.import_module(import_model_from)
48+
model_class = getattr(use_module, model_class_name)
49+
12050
model = model_class(sampler.sequence_length, sampler.n_features)
12151
print(model)
12252

53+
if model_info["non_strand_specific"]["use_module"]:
54+
from models.non_strand_specific_module import NonStrandSpecific
55+
model = NonStrandSpecific(
56+
model, mode=model_info["non_strand_specific"]["mode"])
57+
12358
checkpoint_info = model_controller_info.pop("checkpoint")
12459
checkpoint_resume = checkpoint_info.pop("resume")
12560
checkpoint = None
12661
if checkpoint_resume:
12762
checkpoint_file = checkpoint_info.pop("model_file")
128-
logger.info("Resuming training from checkpoint {0}.".format(
63+
print("Resuming training from checkpoint {0}.".format(
12964
checkpoint_file))
13065
checkpoint = torch.load(checkpoint_file)
13166
model.load_state_dict(checkpoint["state_dict"])
@@ -135,45 +70,26 @@
13570
criterion = use_module.criterion()
13671
optimizer_class, optimizer_args = use_module.get_optimizer(lr)
13772

138-
t_f_model = time()
139-
logger.debug(
140-
"Finished initializing the {0} model from module {1}: {2} s".format(
141-
model.__class__.__name__,
142-
import_model_from,
143-
t_f_model - t_i_model))
144-
145-
logger.info(model)
146-
logger.info(optimizer_args)
147-
148-
149-
if feature_thresholds:
150-
sampler_info.bind(feature_thresholds=feature_thresholds)
151-
sampler_info.bind(genome=genome_fasta,
152-
query_feature_data=genomic_features,
153-
distinct_features=distinct_features,
154-
intervals_file=coords_only)
155-
sampler = instantiate(sampler_info)
156-
15773
batch_size = model_controller_info.keywords["batch_size"] # Would love to find a better way.
15874
max_steps = model_controller_info.keywords["max_steps"]
159-
report_metrics_every_n_steps = \
160-
model_controller_info.keywords["report_metrics_every_n_steps"]
75+
report_stats_every_n_steps = \
76+
model_controller_info.keywords["report_stats_every_n_steps"]
16177
n_validation_samples = model_controller_info.keywords["n_validation_samples"]
16278

16379
model_controller_info.bind(model=model,
16480
data_sampler=sampler,
16581
loss_criterion=criterion,
16682
optimizer_class=optimizer_class,
167-
optimizer_args=optimizer_args,
168-
output_dir=current_run_output_dir)
83+
optimizer_args=optimizer_args)
16984
if "optional_args" in model_controller_info.keywords:
17085
optional_args = model_controller_info.pop("optional_args")
17186
model_controller_info.bind(**optional_args)
17287
runner = instantiate(model_controller_info)
17388

174-
logger.info("Training model: {0} steps, {1} batch size.".format(
89+
print("Training model: {0} steps, {1} batch size.".format(
17590
max_steps, batch_size))
17691
runner.train_and_validate()
177-
178-
t_f = time()
179-
logger.info("./train_model.py completed in {0} s.".format(t_f - t_i))
92+
if configs["evaluate_on_test"]:
93+
runner.evaluate()
94+
if configs["save_datasets"]:
95+
runner.write_datasets_to_file()

selene/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__all__ = ["sequences", "targets", "samplers", "utils"]
1+
__all__ = ["predict", "sequences", "targets", "samplers", "utils"]
22
from .model_train import ModelController

selene/model_predict.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)