Skip to content

Commit 822d3a7

Browse files
authored
Merge pull request #1 from CAAI/versioning
Version control
2 parents cafba61 + f35b6ec commit 822d3a7

File tree

4 files changed

+59
-10
lines changed

4 files changed

+59
-10
lines changed

README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,22 @@ Install only to your user. Go to your virtual environment. Run:
77
```
88
git clone https://github.com/CAAI/rh-torch.git && cd rh-torch
99
pip install .
10-
```
10+
```
11+
12+
## HOW TO CONTRIBUTE
13+
14+
Create your edits in a different branch, decicated to a few specific things. We prefer many minor edits over one huge. Once everything is well-documented and tested, perform a pull request for your edits to be made available in the main branch. See steps here:
15+
```
16+
git branch awesome-addition
17+
git checkout awesome-addition
18+
# do your changes
19+
git commit -a -m 'your changes'
20+
git push
21+
gh pr create --title "your title" --body "longer description of what you did"
22+
```
23+
If you wish, you can add ```--assignee <github_name>``` to ping specific persons for looking at your pull request.
24+
25+
One someone accepted your pull request (after reviewing the changes), it will be part of the main branch.
26+
27+
### Important before accepting pull requests
28+
Before you accept a pull request, please update the version.py with an incremented number and a description of what has changed. This number is logged as part of the training config file that is autogenerated.

rhtorch/config_utils.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from datetime import datetime
55
from pathlib import Path
66
import torch
7+
from rhtorch.version import __version__
8+
import socket
79

810
loss_map = {'MeanAbsoluteError': 'mae',
911
'MeanSquaredError': 'mse',
@@ -24,10 +26,9 @@ def load_model_config(rootdir, arguments):
2426
with open(config_file) as file:
2527
config = yaml.load(file, Loader=yaml.RoundTripLoader)
2628

27-
batch_size = config['batch_size'] * torch.cuda.device_count()
2829
data_shape = 'x'.join(map(str, config['data_shape']))
2930
base_name = f"{config['module']}_{config['version_name']}_{config['data_generator']}"
30-
dat_name = f"bz{batch_size}_{data_shape}"
31+
dat_name = f"bz{config['batch_size']}_{data_shape}"
3132
full_name = f"{base_name}_{dat_name}_k{arguments.kfold}_e{config['epoch']}"
3233

3334
# check for data folder
@@ -39,22 +40,31 @@ def load_model_config(rootdir, arguments):
3940
raise FileNotFoundError("Data path not found. Define relative to the project directory or as absolute path in config file")
4041

4142
# additional info from args and miscellaneous to save in config
42-
config['build date'] = datetime.now().strftime("%Y-%m-%d %H.%M.%S")
43+
config['build_date'] = datetime.now().strftime("%Y-%m-%d %H.%M.%S")
4344
config['model_name'] = full_name
4445
config['project_dir'] = str(rootdir)
4546
config['data_folder'] = str(data_folder)
4647
config['config_file'] = str(config_file)
4748
config['k_fold'] = arguments.kfold
48-
config['precision'] = arguments.precision
49+
if 'precision' not in config:
50+
config['precision'] = 32
51+
config['GPUs'] = torch.cuda.device_count()
52+
config['global_batch_size'] = config['batch_size'] * config['GPUs']
53+
config['rhtorch_version'] = __version__
54+
config['hostname'] = socket.gethostname()
4955
if 'acc_grad_batches' not in config:
5056
config['acc_grad_batches'] = 1
5157

5258
return config
5359

5460

55-
def copy_model_config(path, config):
61+
def copy_model_config(path, config, append_timestamp=False):
5662
model_name = config['model_name']
57-
config_file = path.joinpath(f"config_{model_name}.yaml")
63+
if append_timestamp:
64+
timestamp = config['build_date'].replace(' ','_')
65+
config_file = path.joinpath(f"config_{model_name}_{timestamp}.yaml")
66+
else:
67+
config_file = path.joinpath(f"config_{model_name}.yaml")
5868
config.yaml_set_start_comment(f'Config file for {model_name}')
5969
with open(config_file, 'w') as file:
6070
yaml.dump(config, file, Dumper=yaml.RoundTripDumper)

rhtorch/torch_training.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from rhtorch.callbacks import plotting
1515
from rhtorch.config_utils import load_model_config, copy_model_config
1616

17-
1817
def main():
1918
import argparse
2019

@@ -25,7 +24,6 @@ def main():
2524
parser.add_argument("-c", "--config", help="Config file else than 'config.yaml' in project directory (input dir)", type=str, default='config.yaml')
2625
parser.add_argument("-k", "--kfold", help="K-value for selecting train/test split subset. Default k=0", type=int, default=0)
2726
parser.add_argument("-t", "--test", help="Test run for 1 patient", action="store_true", default=False)
28-
# parser.add_argument("-p", "--precision", help="Torch precision. Default 32", type=int, default=32)
2927

3028
args = parser.parse_args()
3129
project_dir = Path(args.input)
@@ -117,6 +115,11 @@ def main():
117115
)
118116
callbacks.append(checkpoint_callback)
119117

118+
# Save the config prior to training the model - one for each time the script is started
119+
if not is_test:
120+
copy_model_config(model_path, configs, append_timestamp=True)
121+
print("Saved config prior to model training")
122+
120123
# set the trainer and fit
121124
trainer = pl.Trainer(max_epochs=configs['epoch'],
122125
logger=wandb_logger,
@@ -140,7 +143,7 @@ def main():
140143
torch.save(model.state_dict(), output_file)
141144
copy_model_config(model_path, configs)
142145
print("Saved model and config file to disk")
143-
146+
144147

145148
if __name__ == "__main__":
146149
main()

rhtorch/version.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Fri May 21 08:58:01 2021
5+
6+
@author: claes
7+
"""
8+
9+
__version__ = '0.0.3'
10+
11+
"""
12+
13+
VERSIONING (UPDATED WHEN PR ARE MERGED INTO MASTER BRANCH)
14+
0.0.1 # Added repository (CL 18-05-2021)
15+
0.0.2 # Cleaned up main, moved to torchmetrics in modules (RD 20-05-2021)
16+
0.0.3 # Added version control to config-logfiles
17+
18+
"""

0 commit comments

Comments
 (0)