1-5 | 1-5 | (unchanged lines not shown)
6 | 6 | Saves model to a user-specified output file.
7 | 7 |
8 | 8 | Usage:
9 | | -    selene.py <import-module> <model-class-name> <lr>
10 | | -        <paths-yml> <train-model-yml>
11 | | -        [-s | --stdout] [--verbosity=<level>]
| 9 | +    selene.py <lr> <config-yml>
12 | 10 |     selene.py -h | --help
13 | 11 |
14 | 12 | Options:
15 | 13 |     -h --help               Show this screen.
16 | 14 |
17 | | -    <import-module>         Import the module containing the model
18 | | -    <model-class-name>      Must be a model class in the imported module
19 | 15 |     <lr>                    Choose the optimizer's learning rate
20 | | -    <paths-yml>             Input data and output filepaths
21 | | -    <train-model-yml>       Model-specific parameters
22 | | -    -s --stdout             Will also output logging information to stdout
23 | | -                            [default: False]
24 | | -    --verbosity=<level>     Logging verbosity level (0=WARN, 1=INFO, 2=DEBUG)
25 | | -                            [default: 1]
| 16 | +    <config-yml>            YAML configuration for the model, sampler, and model controller
26 | 17 | """
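Under the new interface a run needs only two positional arguments, e.g. `python selene.py 0.01 train_config.yml` (the learning rate and filename here are placeholders); everything that used to be split across CLI options and two YAML files now lives in the single config file sketched at the end of this diff.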
27 | 18 | import importlib |
28 | | -import logging |
29 | | -import os |
30 | | -from time import strftime, time |
31 | 19 |
32 | 20 | from docopt import docopt |
33 | 21 | import torch |
34 | 22 |
35 | | -from selene.model_train import ModelController |
36 | | -from selene.samplers import IntervalsSampler |
37 | | -from selene.utils import initialize_logger, read_yaml_file |
38 | | -from selene.utils import load, load_path, instantiate |
| 23 | +from selene.utils import load_path, instantiate |
39 | 24 |
40 | 25 | if __name__ == "__main__": |
41 | 26 | arguments = docopt( |
42 | 27 | __doc__, |
43 | 28 | version="1.0") |
44 | | - import_model_from = arguments["<import-module>"] |
45 | | - model_class_name = arguments["<model-class-name>"] |
46 | | - use_module = importlib.import_module(import_model_from) |
47 | | - model_class = getattr(use_module, model_class_name) |
48 | | - |
49 | 29 | lr = float(arguments["<lr>"]) |
50 | 30 |
51 | | - paths = read_yaml_file( |
52 | | - arguments["<paths-yml>"]) |
53 | | - |
54 | | - train_model = load_path(arguments["<train-model-yml>"], instantiate=False) |
55 | | - |
56 | | - |
57 | | - ################################################## |
58 | | - # PATHS |
59 | | - ################################################## |
60 | | - dir_path = paths["dir_path"] |
61 | | - files = paths["files"] |
62 | | - genome_fasta = os.path.join( |
63 | | - dir_path, files["genome"]) |
64 | | - genomic_features = os.path.join( |
65 | | - dir_path, files["genomic_features"]) |
66 | | - coords_only = os.path.join( |
67 | | - dir_path, files["sample_from_regions"]) |
68 | | - distinct_features = os.path.join( |
69 | | - dir_path, files["distinct_features"]) |
70 | | - |
71 | | - output_dir = paths["output_dir"] |
72 | | - os.makedirs(output_dir, exist_ok=True) |
73 | | - |
74 | | - current_run_output_dir = os.path.join( |
75 | | - output_dir, strftime("%Y-%m-%d-%H-%M-%S")) |
76 | | - os.makedirs(current_run_output_dir) |
| 31 | + configs = load_path(arguments["<config-yml>"], instantiate=False) |
77 | 32 |
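The single `load_path(arguments["<config-yml>"], instantiate=False)` call replaces all of the removed path plumbing. Judging only from how the result is used below (`.keywords` lookups, `.bind(...)`, `.pop(...)`, and a final `instantiate(...)`), the loader appears to return deferred-construction proxies. A minimal sketch of that pattern, assuming the real `selene.utils` implementation differs in detail:

```python
# Minimal sketch of a deferred-construction proxy, assuming this is
# roughly what selene.utils.load_path yields when instantiate=False.
# The real selene implementation may differ.
class Proxy(object):
    def __init__(self, cls, keywords):
        self.cls = cls            # class named in the YAML
        self.keywords = keywords  # constructor kwargs parsed from the YAML

    def bind(self, **kwargs):
        # Add or override constructor arguments after parsing, as this
        # script does with model=..., data_sampler=..., and so on.
        self.keywords.update(kwargs)

    def pop(self, key):
        # Remove an entry (e.g. "checkpoint") so it is not passed
        # to the constructor.
        return self.keywords.pop(key)


def instantiate(proxy):
    # Construct the object only once every argument has been bound.
    return proxy.cls(**proxy.keywords)
```

Deferring construction this way lets the script inject runtime-only arguments (the built model, the sampler) into objects that were declared in YAML.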
78 | 33 | ################################################## |
79 | 34 | # TRAIN MODEL PARAMETERS |
80 | 35 | ################################################## |
81 | | - sampler_info = train_model["sampler"] |
82 | | - model_controller_info = train_model["model_controller"] |
| 36 | + model_info = configs["model"] |
| 37 | + sampler_info = configs["sampler"] |
| 38 | + model_controller_info = configs["model_controller"] |
83 | 39 |
84 | | - ################################################## |
85 | | - # OTHER ARGS |
86 | | - ################################################## |
87 | | - to_stdout = arguments["--stdout"] |
88 | | - verbosity_level = int(arguments["--verbosity"]) |
89 | | - |
90 | | - initialize_logger( |
91 | | - os.path.join(current_run_output_dir, "{0}.log".format(__name__)), |
92 | | - verbosity=verbosity_level, |
93 | | - stdout_handler=to_stdout) |
94 | | - logger = logging.getLogger("selene") |
95 | | - |
96 | | - t_i = time() |
97 | | - feature_thresholds = None |
98 | | - if "specific_feature_thresholds" in sampler_info.keywords: |
99 | | - feature_thresholds = sampler_info.pop("specific_feature_thresholds") |
100 | | - else: |
101 | | - feature_thresholds = None |
102 | | - if "default_threshold" in sampler_info.keywords: |
103 | | - if feature_thresholds: |
104 | | - feature_thresholds["default"] = sampler_info.pop("default_threshold") |
105 | | - else: |
106 | | - feature_thresholds = sampler_info.pop("default_threshold") |
107 | | - |
108 | | - if feature_thresholds: |
109 | | - sampler_info.bind(feature_thresholds=feature_thresholds) |
110 | | - sampler_info.bind(genome=genome_fasta, |
111 | | - query_feature_data=genomic_features, |
112 | | - distinct_features=distinct_features, |
113 | | - intervals_file=coords_only) |
114 | 40 | sampler = instantiate(sampler_info) |
115 | 41 |
116 | | - t_i_model = time() |
117 | 42 | torch.manual_seed(1337) |
118 | 43 | torch.cuda.manual_seed_all(1337) |
119 | 44 |
| 45 | + import_model_from = model_info["import_model_from"] |
| 46 | + model_class_name = model_info["class"] |
| 47 | + use_module = importlib.import_module(import_model_from) |
| 48 | + model_class = getattr(use_module, model_class_name) |
| 49 | + |
120 | 50 | model = model_class(sampler.sequence_length, sampler.n_features) |
121 | 51 | print(model) |
122 | 52 |
| 53 | + if model_info["non_strand_specific"]["use_module"]: |
| 54 | + from models.non_strand_specific_module import NonStrandSpecific |
| 55 | + model = NonStrandSpecific( |
| 56 | + model, mode=model_info["non_strand_specific"]["mode"]) |
| 57 | + |
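`NonStrandSpecific` comes from a local `models` package that this diff does not show. A hypothetical sketch of what such a wrapper could do, assuming one-hot `(N, 4, L)` input in A, C, G, T channel order (so flipping both axes yields the reverse complement): run the wrapped model on a sequence and on its reverse complement, then combine the two predictions per the configured `mode`:

```python
# Hypothetical sketch only; the real models.non_strand_specific_module
# may be implemented differently.
import torch
import torch.nn as nn


class NonStrandSpecificSketch(nn.Module):
    def __init__(self, model, mode="mean"):
        super(NonStrandSpecificSketch, self).__init__()
        if mode not in ("mean", "max"):
            raise ValueError(
                "mode must be 'mean' or 'max', got {0}".format(mode))
        self.model = model
        self.mode = mode

    def forward(self, x):
        # With one-hot (N, 4, L) input and complementary bases in
        # mirrored channel order (A/T, C/G), flipping the channel and
        # position axes produces the reverse complement.
        reverse_complement = torch.flip(x, (1, 2))
        fwd = self.model(x)
        rev = self.model(reverse_complement)
        if self.mode == "mean":
            return (fwd + rev) / 2
        return torch.max(fwd, rev)
```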
123 | 58 | checkpoint_info = model_controller_info.pop("checkpoint") |
124 | 59 | checkpoint_resume = checkpoint_info.pop("resume") |
125 | 60 | checkpoint = None |
126 | 61 | if checkpoint_resume: |
127 | 62 | checkpoint_file = checkpoint_info.pop("model_file") |
128 | | - logger.info("Resuming training from checkpoint {0}.".format( |
| 63 | + print("Resuming training from checkpoint {0}.".format( |
129 | 64 | checkpoint_file)) |
130 | 65 | checkpoint = torch.load(checkpoint_file) |
131 | 66 | model.load_state_dict(checkpoint["state_dict"]) |
132-134 | 67-69 | (unchanged lines not shown)
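The resume branch only requires that `torch.load(checkpoint_file)` return a dict containing a `"state_dict"` entry. A hypothetical sketch of writing a compatible checkpoint; the `step` field is invented bookkeeping, not something this script reads:

```python
import torch


def save_checkpoint(model, step, path):
    # "state_dict" is the only key the resume logic above requires;
    # "step" is a hypothetical extra field for illustration.
    torch.save({"step": step, "state_dict": model.state_dict()}, path)
```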
135 | 70 | criterion = use_module.criterion() |
136 | 71 | optimizer_class, optimizer_args = use_module.get_optimizer(lr) |
137 | 72 |
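Together with the dynamic import above, these two calls pin down the contract for the user's model module: a class constructed as `model_class(sampler.sequence_length, sampler.n_features)`, plus module-level `criterion()` and `get_optimizer(lr)` functions. A hypothetical module satisfying that contract; the names and architecture are invented for illustration, assuming one-hot DNA input:

```python
# Hypothetical model module for <import_model_from>; everything here is
# illustrative, assuming one-hot (N, 4, sequence_length) input.
import torch
import torch.nn as nn


class ExampleModel(nn.Module):
    def __init__(self, sequence_length, n_features):
        # Constructed by the script as
        # model_class(sampler.sequence_length, sampler.n_features);
        # this toy network happens not to need sequence_length.
        super(ExampleModel, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(4, 32, kernel_size=8),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1))
        self.classifier = nn.Linear(32, n_features)

    def forward(self, x):
        out = self.conv(x).squeeze(-1)               # (N, 32)
        return torch.sigmoid(self.classifier(out))  # one score per feature


def criterion():
    # Multi-label loss over the per-feature probabilities.
    return nn.BCELoss()


def get_optimizer(lr):
    # Returned as (class, kwargs); the controller is expected to
    # construct it with the model's parameters later.
    return torch.optim.SGD, {"lr": lr, "momentum": 0.9}
```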
138 | | - t_f_model = time() |
139 | | - logger.debug( |
140 | | - "Finished initializing the {0} model from module {1}: {2} s".format( |
141 | | - model.__class__.__name__, |
142 | | - import_model_from, |
143 | | - t_f_model - t_i_model)) |
144 | | - |
145 | | - logger.info(model) |
146 | | - logger.info(optimizer_args) |
147 | | - |
148 | | - |
149 | | - if feature_thresholds: |
150 | | - sampler_info.bind(feature_thresholds=feature_thresholds) |
151 | | - sampler_info.bind(genome=genome_fasta, |
152 | | - query_feature_data=genomic_features, |
153 | | - distinct_features=distinct_features, |
154 | | - intervals_file=coords_only) |
155 | | - sampler = instantiate(sampler_info) |
156 | | - |
157 | 73 | batch_size = model_controller_info.keywords["batch_size"] # Would love to find a better way. |
158 | 74 | max_steps = model_controller_info.keywords["max_steps"] |
159 | | - report_metrics_every_n_steps = \ |
160 | | - model_controller_info.keywords["report_metrics_every_n_steps"] |
| 75 | + report_stats_every_n_steps = \ |
| 76 | + model_controller_info.keywords["report_stats_every_n_steps"] |
161 | 77 | n_validation_samples = model_controller_info.keywords["n_validation_samples"] |
162 | 78 |
163 | 79 | model_controller_info.bind(model=model, |
164 | 80 | data_sampler=sampler, |
165 | 81 | loss_criterion=criterion, |
166 | 82 | optimizer_class=optimizer_class, |
167 | | - optimizer_args=optimizer_args, |
168 | | - output_dir=current_run_output_dir) |
| 83 | + optimizer_args=optimizer_args) |
169 | 84 | if "optional_args" in model_controller_info.keywords: |
170 | 85 | optional_args = model_controller_info.pop("optional_args") |
171 | 86 | model_controller_info.bind(**optional_args) |
172 | 87 | runner = instantiate(model_controller_info) |
173 | 88 |
174 | | - logger.info("Training model: {0} steps, {1} batch size.".format( |
| 89 | + print("Training model: {0} steps, {1} batch size.".format( |
175 | 90 | max_steps, batch_size)) |
176 | 91 | runner.train_and_validate() |
177 | | - |
178 | | - t_f = time() |
179 | | - logger.info("./train_model.py completed in {0} s.".format(t_f - t_i)) |
| 92 | + if configs["evaluate_on_test"]: |
| 93 | + runner.evaluate() |
| 94 | + if configs["save_datasets"]: |
| 95 | + runner.write_datasets_to_file() |
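For reference, a hypothetical `<config-yml>` assembled from every key this script reads. The sampler and controller class paths come from the imports removed above; the `!obj:` constructor-tag syntax is an assumption about `selene.utils.load_path`, all filenames and numbers are placeholders, and an `optional_args` mapping may also appear under `model_controller`:

```yaml
# Hypothetical config sketched from the keys read by selene.py;
# every path and value below is a placeholder.
model:
    import_model_from: models.example_model    # hypothetical module
    class: ExampleModel                        # hypothetical class name
    non_strand_specific:
        use_module: True
        mode: mean                             # or max
sampler: !obj:selene.samplers.IntervalsSampler {
    genome: /path/to/genome.fa,
    query_feature_data: /path/to/features.bed.gz,
    distinct_features: /path/to/distinct_features.txt,
    intervals_file: /path/to/intervals.bed
}
model_controller: !obj:selene.model_train.ModelController {
    batch_size: 64,
    max_steps: 8000,
    report_stats_every_n_steps: 1000,
    n_validation_samples: 3200,
    checkpoint: {resume: False}
}
evaluate_on_test: True
save_datasets: True
```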