@@ -33,8 +33,9 @@
 import torch
 
 from selene.model_train import ModelController
-from selene.sampler import IntervalsSampler
+from selene.samplers import IntervalsSampler
 from selene.utils import initialize_logger, read_yaml_file
+from selene.utils import load, load_path, instantiate
 
 if __name__ == "__main__":
     arguments = docopt(
@@ -49,8 +50,9 @@
 
     paths = read_yaml_file(
         arguments["<paths-yml>"])
-    train_model = read_yaml_file(
-        arguments["<train-model-yml>"])
+
+    train_model = load_path(arguments["<train-model-yml>"], instantiate=False)
+
 
     ##################################################
     # PATHS
@@ -93,28 +95,23 @@
 
     t_i = time()
     feature_thresholds = None
-    if "specific_feature_thresholds" in sampler_info:
-        feature_thresholds = sampler_info["specific_feature_thresholds"]
-        del sampler_info["specific_feature_thresholds"]
+    if "specific_feature_thresholds" in sampler_info.keywords:
+        feature_thresholds = sampler_info.pop("specific_feature_thresholds")
     else:
         feature_thresholds = None
-    if "default_threshold" in sampler_info:
+    if "default_threshold" in sampler_info.keywords:
         if feature_thresholds:
-            feature_thresholds["default"] = \
-                sampler_info["default_threshold"]
+            feature_thresholds["default"] = sampler_info.pop("default_threshold")
         else:
-            feature_thresholds = sampler_info["default_threshold"]
-        del sampler_info["default_threshold"]
+            feature_thresholds = sampler_info.pop("default_threshold")
 
     if feature_thresholds:
-        sampler_info["feature_thresholds"] = feature_thresholds
-
-    sampler = IntervalsSampler(
-        genome_fasta,
-        genomic_features,
-        distinct_features,
-        coords_only,
-        **sampler_info)
+        sampler_info.bind(feature_thresholds=feature_thresholds)
+    sampler_info.bind(genome=genome_fasta,
+                      query_feature_data=genomic_features,
+                      distinct_features=distinct_features,
+                      intervals_file=coords_only)
+    sampler = instantiate(sampler_info)
 
     t_i_model = time()
     torch.manual_seed(1337)
@@ -123,16 +120,17 @@
     model = model_class(sampler.sequence_length, sampler.n_features)
     print(model)
 
-    checkpoint_info = model_controller_info["checkpoint"]
-    checkpoint_resume = checkpoint_info["resume"]
+    checkpoint_info = model_controller_info.pop("checkpoint")
+    checkpoint_resume = checkpoint_info.pop("resume")
     checkpoint = None
     if checkpoint_resume:
-        checkpoint_file = checkpoint_info["model_file"]
+        checkpoint_file = checkpoint_info.pop("model_file")
         logger.info("Resuming training from checkpoint {0}.".format(
             checkpoint_file))
         checkpoint = torch.load(checkpoint_file)
         model.load_state_dict(checkpoint["state_dict"])
         model.eval()
+        model_controller_info.bind(checkpoint_resume=checkpoint)
 
     criterion = use_module.criterion()
     optimizer_class, optimizer_args = use_module.get_optimizer(lr)
@@ -147,19 +145,31 @@
     logger.info(model)
     logger.info(optimizer_args)
 
-    batch_size = model_controller_info["batch_size"]
-    max_steps = model_controller_info["max_steps"]
+
+    if feature_thresholds:
+        sampler_info.bind(feature_thresholds=feature_thresholds)
+    sampler_info.bind(genome=genome_fasta,
+                      query_feature_data=genomic_features,
+                      distinct_features=distinct_features,
+                      intervals_file=coords_only)
+    sampler = instantiate(sampler_info)
+
+    batch_size = model_controller_info.keywords["batch_size"]  # Would love to find a better way.
+    max_steps = model_controller_info.keywords["max_steps"]
     report_metrics_every_n_steps = \
-        model_controller_info["report_metrics_every_n_steps"]
-    n_validation_samples = model_controller_info["n_validation_samples"]
-
-    runner = ModelController(
-        model, sampler, criterion, optimizer_class, optimizer_args,
-        batch_size, max_steps, report_metrics_every_n_steps,
-        current_run_output_dir,
-        n_validation_samples,
-        checkpoint_resume=checkpoint,
-        **model_controller_info["optional_args"])
+        model_controller_info.keywords["report_metrics_every_n_steps"]
+    n_validation_samples = model_controller_info.keywords["n_validation_samples"]
+
+    model_controller_info.bind(model=model,
+                               data_sampler=sampler,
+                               loss_criterion=criterion,
+                               optimizer_class=optimizer_class,
+                               optimizer_args=optimizer_args,
+                               output_dir=current_run_output_dir)
+    if "optional_args" in model_controller_info.keywords:
+        optional_args = model_controller_info.pop("optional_args")
+        model_controller_info.bind(**optional_args)
+    runner = instantiate(model_controller_info)
 
     logger.info("Training model: {0} steps, {1} batch size.".format(
         max_steps, batch_size))
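
For readers unfamiliar with the config helpers this diff switches to: load_path(path, instantiate=False), .keywords, .pop(), .bind(), and instantiate() follow a deferred-construction pattern, in which the parsed config is held as a proxy that records a target class plus its keyword arguments, runtime arguments are bound afterwards, and the object is only built once everything is attached. The sketch below illustrates that pattern only; the Proxy class, its method signatures, and the ExampleSampler stub are assumptions made for illustration, not selene's actual implementation.

# Minimal sketch of a deferred-construction ("proxy") config object, assuming an
# interface like the one used in the diff (.keywords, .pop(), .bind(), and a
# module-level instantiate()). Illustrative only; not selene's implementation.

class Proxy:
    def __init__(self, callable_, **keywords):
        self.callable_ = callable_      # class or function to construct later
        self.keywords = dict(keywords)  # keyword arguments collected so far

    def bind(self, **kwargs):
        """Attach additional keyword arguments before construction."""
        self.keywords.update(kwargs)

    def pop(self, key, default=None):
        """Remove and return a keyword argument handled outside the constructor."""
        return self.keywords.pop(key, default)


def instantiate(proxy):
    """Build the object described by the proxy from its bound keyword arguments."""
    return proxy.callable_(**proxy.keywords)


# Stand-in for a sampler class; the real IntervalsSampler takes genome, feature,
# and interval inputs as shown in the diff above.
class ExampleSampler:
    def __init__(self, genome, query_feature_data, distinct_features,
                 intervals_file, feature_thresholds=None):
        self.genome = genome
        self.feature_thresholds = feature_thresholds


# Usage mirroring the training script: pop a value that needs special handling,
# bind runtime arguments, then instantiate.
sampler_info = Proxy(ExampleSampler, feature_thresholds=0.5)
threshold = sampler_info.pop("feature_thresholds")
sampler_info.bind(genome="hg19.fa",
                  query_feature_data="features.bed.gz",
                  distinct_features=["CTCF"],
                  intervals_file="intervals.bed",
                  feature_thresholds=threshold)
sampler = instantiate(sampler_info)

Deferring construction this way is what lets the script strip config-only keys (such as default_threshold) and inject runtime objects (the model, sampler, criterion, optimizer) before any constructor runs, which is why the diff replaces direct dictionary access on sampler_info and model_controller_info with pop/bind calls followed by instantiate().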
|
0 commit comments