
Commit 1c6e856

Merge pull request #24 from evancofer/master
Configs can now use new user content
2 parents cd8770b + 42f4580

File tree

7 files changed (+493, -70 lines)


config_examples/parameters.yml

Lines changed: 21 additions & 21 deletions
@@ -1,25 +1,25 @@
 ---
-sampler:
-    test_holdout:
-        - 8
-        - 9
-    validation_holdout:
-        - 6
-        - 7
-    random_seed: 127
-    sequence_length: 1001
-    center_bin_to_predict: 201
-    default_threshold: 0.5
-    mode: train
-model_controller:
-    batch_size: 64
-    max_steps: 500000
-    report_metrics_every_n_steps: 16000
-    n_validation_samples: 3200
-    optional_args:
-        cpu_n_threads: 32
-        use_cuda: True
+sampler: !obj:selene.samplers.IntervalsSampler {
+    test_holdout: [8, 9],
+    validation_holdout: [6, 7],
+    random_seed: 127,
+    sequence_length: 1001,
+    center_bin_to_predict: 201,
+    default_threshold: 0.5,
+    mode: "train"
+}
+model_controller: !obj:selene.ModelController {
+    batch_size: 64,
+    max_steps: 500000,
+    report_metrics_every_n_steps: 16000,
+    n_validation_samples: 3200,
+    optional_args: {
+        cpu_n_threads: 32,
+        use_cuda: True,
         data_parallel: False
-    checkpoint:
+    },
+    checkpoint: {
         resume: False
+    }
+}
 ...
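The key change here: plain YAML sections become "!obj:" tags naming the class each mapping configures, so the loader can turn a config block directly into a constructor call. Below is a minimal sketch of how such a tag can be handled with PyYAML's multi-constructor hook, assuming a Pylearn2-style convention; selene's actual loader lives in selene/utils/config.py and may resolve and defer construction differently.

# Sketch only: resolving a "!obj:" tag with PyYAML.
import importlib
from functools import partial

import yaml


def _obj_constructor(loader, tag_suffix, node):
    # tag_suffix is the dotted path after "!obj:", e.g.
    # "selene.samplers.IntervalsSampler".
    module_name, _, class_name = tag_suffix.rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    kwargs = loader.construct_mapping(node, deep=True)
    # Defer construction so the caller can bind more arguments later.
    return partial(cls, **kwargs)


yaml.add_multi_constructor("!obj:", _obj_constructor, Loader=yaml.SafeLoader)

with open("config_examples/parameters.yml") as fh:
    config = yaml.load(fh, Loader=yaml.SafeLoader)
# config["sampler"] now holds a deferred IntervalsSampler constructor.

Returning a partial rather than calling the class immediately mirrors the instantiate=False behavior that selene.py relies on below.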

selene.py

Lines changed: 44 additions & 34 deletions
@@ -33,8 +33,9 @@
 import torch
 
 from selene.model_train import ModelController
-from selene.sampler import IntervalsSampler
+from selene.samplers import IntervalsSampler
 from selene.utils import initialize_logger, read_yaml_file
+from selene.utils import load, load_path, instantiate
 
 if __name__ == "__main__":
     arguments = docopt(
@@ -49,8 +50,9 @@
 
     paths = read_yaml_file(
         arguments["<paths-yml>"])
-    train_model = read_yaml_file(
-        arguments["<train-model-yml>"])
+
+    train_model = load_path(arguments["<train-model-yml>"], instantiate=False)
+
 
     ##################################################
     # PATHS
@@ -93,28 +95,23 @@
 
     t_i = time()
     feature_thresholds = None
-    if "specific_feature_thresholds" in sampler_info:
-        feature_thresholds = sampler_info["specific_feature_thresholds"]
-        del sampler_info["specific_feature_thresholds"]
+    if "specific_feature_thresholds" in sampler_info.keywords:
+        feature_thresholds = sampler_info.pop("specific_feature_thresholds")
     else:
         feature_thresholds = None
-    if "default_threshold" in sampler_info:
+    if "default_threshold" in sampler_info.keywords:
         if feature_thresholds:
-            feature_thresholds["default"] = \
-                sampler_info["default_threshold"]
+            feature_thresholds["default"] = sampler_info.pop("default_threshold")
         else:
-            feature_thresholds = sampler_info["default_threshold"]
-        del sampler_info["default_threshold"]
+            feature_thresholds = sampler_info.pop("default_threshold")
 
     if feature_thresholds:
-        sampler_info["feature_thresholds"] = feature_thresholds
-
-    sampler = IntervalsSampler(
-        genome_fasta,
-        genomic_features,
-        distinct_features,
-        coords_only,
-        **sampler_info)
+        sampler_info.bind(feature_thresholds=feature_thresholds)
+    sampler_info.bind(genome=genome_fasta,
+                      query_feature_data=genomic_features,
+                      distinct_features=distinct_features,
+                      intervals_file=coords_only)
+    sampler = instantiate(sampler_info)
 
     t_i_model = time()
     torch.manual_seed(1337)
@@ -123,16 +120,17 @@
     model = model_class(sampler.sequence_length, sampler.n_features)
     print(model)
 
-    checkpoint_info = model_controller_info["checkpoint"]
-    checkpoint_resume = checkpoint_info["resume"]
+    checkpoint_info = model_controller_info.pop("checkpoint")
+    checkpoint_resume = checkpoint_info.pop("resume")
     checkpoint = None
     if checkpoint_resume:
-        checkpoint_file = checkpoint_info["model_file"]
+        checkpoint_file = checkpoint_info.pop("model_file")
         logger.info("Resuming training from checkpoint {0}.".format(
             checkpoint_file))
         checkpoint = torch.load(checkpoint_file)
         model.load_state_dict(checkpoint["state_dict"])
         model.eval()
+    model_controller_info.bind(checkpoint_resume=checkpoint)
 
     criterion = use_module.criterion()
     optimizer_class, optimizer_args = use_module.get_optimizer(lr)
@@ -147,19 +145,31 @@
     logger.info(model)
     logger.info(optimizer_args)
 
-    batch_size = model_controller_info["batch_size"]
-    max_steps = model_controller_info["max_steps"]
+
+    if feature_thresholds:
+        sampler_info.bind(feature_thresholds=feature_thresholds)
+    sampler_info.bind(genome=genome_fasta,
+                      query_feature_data=genomic_features,
+                      distinct_features=distinct_features,
+                      intervals_file=coords_only)
+    sampler = instantiate(sampler_info)
+
+    batch_size = model_controller_info.keywords["batch_size"]  # Would love to find a better way.
+    max_steps = model_controller_info.keywords["max_steps"]
     report_metrics_every_n_steps = \
-        model_controller_info["report_metrics_every_n_steps"]
-    n_validation_samples = model_controller_info["n_validation_samples"]
-
-    runner = ModelController(
-        model, sampler, criterion, optimizer_class, optimizer_args,
-        batch_size, max_steps, report_metrics_every_n_steps,
-        current_run_output_dir,
-        n_validation_samples,
-        checkpoint_resume=checkpoint,
-        **model_controller_info["optional_args"])
+        model_controller_info.keywords["report_metrics_every_n_steps"]
+    n_validation_samples = model_controller_info.keywords["n_validation_samples"]
+
+    model_controller_info.bind(model=model,
+                               data_sampler=sampler,
+                               loss_criterion=criterion,
+                               optimizer_class=optimizer_class,
+                               optimizer_args=optimizer_args,
+                               output_dir=current_run_output_dir)
+    if "optional_args" in model_controller_info.keywords:
+        optional_args = model_controller_info.pop("optional_args")
+        model_controller_info.bind(**optional_args)
+    runner = instantiate(model_controller_info)
 
     logger.info("Training model: {0} steps, {1} batch size.".format(
         max_steps, batch_size))
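The script no longer indexes plain dicts: load_path(..., instantiate=False) evidently returns a proxy that records a class and its keyword arguments, exposing a .keywords mapping plus bind() and pop() for adjusting those arguments, with instantiate() performing the deferred call. Here is a rough sketch of such a proxy, assuming it wraps a callable and a kwargs dict; the real definition is in selene/utils/config.py and its internals may differ.

# Hypothetical stand-in for the proxy returned by
# load_path(..., instantiate=False).
class BoundConstructor:
    def __init__(self, callable_, **keywords):
        self.callable_ = callable_
        self.keywords = keywords  # constructor kwargs, mutable until instantiation

    def bind(self, **kwargs):
        # Add or overwrite keyword arguments prior to instantiation.
        self.keywords.update(kwargs)

    def pop(self, key, *default):
        # Remove and return a keyword argument (used for values consumed by
        # selene.py itself, e.g. "default_threshold").
        return self.keywords.pop(key, *default)


def instantiate(bound):
    # Perform the deferred call with everything bound so far.
    return bound.callable_(**bound.keywords)

With this shape, sampler_info.pop("default_threshold") keeps the key out of the eventual IntervalsSampler call, and instantiate(sampler_info) is the single point where construction finally happens.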

selene/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1 +1,2 @@
-__all__ = ["sequences", "targets", "samplers", "model_controller"]
+__all__ = ["sequences", "targets", "samplers", "utils"]
+from .model_train import ModelController

selene/samplers/online_sampler.py

Lines changed: 0 additions & 12 deletions
@@ -114,15 +114,3 @@ def get_feature_from_index(self, feature_index):
 
     def get_sequence_from_encoding(self, encoding):
        return self.genome.encoding_to_sequence(encoding)
-
-    # @abstractmethod
-    # def sample(self, batch_size):
-    #     raise NotImplementedError
-    #
-    # @abstractmethod
-    # def get_data_and_targets(self, mode, batch_size, n_samples):
-    #     raise NotImplementedError
-    #
-    # @abstractmethod
-    # def get_validation_set(self, batch_size, n_samples=None):
-    #     raise NotImplementedError

selene/targets/target.py

Lines changed: 1 addition & 1 deletion
@@ -12,6 +12,6 @@ class Target(metaclass=ABCMeta):
     @abstractmethod
     def get_feature_data(self, *args, **kwargs):
         """
-        Gets feature data for some input coordinate.
+        Retrieve the feature data for some coordinate.
         """
         raise NotImplementedError

selene/utils/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1 +1,3 @@
-from .utils import *
+from .utils import initialize_logger, read_yaml_file
+from .config import load, load_path, instantiate
+
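selene/utils/__init__.py now re-exports the config machinery alongside the existing helpers. Taken together with the selene.py changes above, the intended flow looks roughly like this; the file names and feature values below are made up for illustration, and only the argument names are taken from the diff.

from selene.utils import load_path, instantiate

# Parse the config without constructing anything yet.
train_model = load_path("config_examples/parameters.yml", instantiate=False)

sampler_info = train_model["sampler"]        # deferred IntervalsSampler
sampler_info.bind(genome="hg19.fa",          # hypothetical inputs
                  query_feature_data="features.h5",
                  distinct_features=["CTCF"],
                  intervals_file="intervals.bed")
sampler = instantiate(sampler_info)          # construction happens here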
