@@ -1,70 +1,78 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 import ruamel.yaml as yaml
 from datetime import datetime
 from pathlib import Path
 import torch
 from rhtorch.version import __version__
 import socket
 
-loss_map = {'MeanAbsoluteError': 'mae',
-            'MeanSquaredError': 'mse',
-            'huber_loss': 'huber',
-            'BCEWithLogitsLoss': 'BCE'}
+class UserConfig:
+    def __init__(self, rootdir, arguments=None):
+        self.rootdir = rootdir
+        self.config_file = self.is_path(arguments.config)
+        self.args = arguments
+
+        # load default configs
+        default_config_file = Path(__file__).parent.joinpath('default_config.yaml')
+        with open(default_config_file) as dcf:
+            self.default_params = yaml.load(dcf, Loader=yaml.Loader)
+
+        # load user config file
+        with open(self.config_file) as cf:
+            self.hparams = yaml.load(cf, Loader=yaml.RoundTripLoader)
+
+        # merge the two dicts
+        self.merge_dicts()
+
+        # sanity check on data_folder provided by user
+        self.data_path = self.is_path(self.hparams['data_folder'])
+
+        # make model name
+        self.fill_additional_info()
+        self.create_model_name()
+
+    def is_path(self, path):
+        # check for path - assuming an absolute path was given
+        filepath = Path(path)
+        if not filepath.exists():
+            # assume the path was given relative to rootdir
+            filepath = self.rootdir.joinpath(filepath)
+            if not filepath.exists():
+                raise FileNotFoundError(f"{path} not found. Define it relative to the project directory or as an absolute path in the config file / passed argument.")
+
+        return filepath
 
+    def merge_dicts(self):
+        """Add to the user params dictionary any key missing from the default params."""
+
+        for k, v in self.default_params.items():
+            if k not in self.hparams:
+                self.hparams[k] = v
+        # TODO: make sure irrelevant defaults are not copied over, e.g. GAN parameters when the model is an AE
 
-def load_model_config(rootdir, arguments):
+    def fill_additional_info(self):
+        # additional info from args and miscellaneous to save in config
+        self.hparams['build_date'] = datetime.now().strftime("%Y%m%d-%H%M%S")
+        self.hparams['project_dir'] = str(self.rootdir)
+        self.hparams['data_folder'] = str(self.data_path)
+        self.hparams['config_file'] = str(self.config_file)
+        self.hparams['k_fold'] = self.args.kfold
+        self.hparams['GPUs'] = torch.cuda.device_count()
+        self.hparams['global_batch_size'] = self.hparams['batch_size'] * self.hparams['GPUs']
+        self.hparams['rhtorch_version'] = __version__
+        self.hparams['hostname'] = socket.gethostname()
 
-    # check for config_file
-    config_file = Path(arguments.config)
-    if not config_file.exists():
-        config_file = rootdir.joinpath(config_file)
-        if not config_file.exists():
-            raise FileNotFoundError("Config file not found. Define relative to project directory or as absolute path in config file")
-
-    # read the config file
-    with open(config_file) as file:
-        config = yaml.load(file, Loader=yaml.RoundTripLoader)
+    def create_model_name(self):
 
-    data_shape = 'x'.join(map(str, config['data_shape']))
-    base_name = f"{config['module']}_{config['version_name']}_{config['data_generator']}"
-    dat_name = f"bz{config['batch_size']}_{data_shape}"
-    full_name = f"{base_name}_{dat_name}_k{arguments.kfold}_e{config['epoch']}"
-
-    # check for data folder
-    data_folder = Path(config['data_folder'])
-    if not data_folder.exists():
-        # try relative to project dir - in this case overwrite config
-        data_folder = rootdir.joinpath(config['data_folder'])
-        if not data_folder.exists():
-            raise FileNotFoundError("Data path not found. Define relative to the project directory or as absolute path in config file")
+        data_shape = 'x'.join(map(str, self.hparams['data_shape']))
+        base_name = f"{self.hparams['module']}_{self.hparams['version_name']}_{self.hparams['data_generator']}"
+        dat_name = f"bz{self.hparams['batch_size']}_{data_shape}"
+        self.hparams['model_name'] = f"{base_name}_{dat_name}_k{self.args.kfold}_e{self.hparams['epoch']}"
 
-    # additional info from args and miscellaneous to save in config
-    config['build_date'] = datetime.now().strftime("%Y-%m-%d %H.%M.%S")
-    config['model_name'] = full_name
-    config['project_dir'] = str(rootdir)
-    config['data_folder'] = str(data_folder)
-    config['config_file'] = str(config_file)
-    config['k_fold'] = arguments.kfold
-    if 'precision' not in config:
-        config['precision'] = 32
-    config['GPUs'] = torch.cuda.device_count()
-    config['global_batch_size'] = config['batch_size'] * config['GPUs']
-    config['rhtorch_version'] = __version__
-    config['hostname'] = socket.gethostname()
-    if 'acc_grad_batches' not in config:
-        config['acc_grad_batches'] = 1
-
-    return config
-
-
-def copy_model_config(path, config, append_timestamp=False):
-    model_name = config['model_name']
-    if append_timestamp:
-        timestamp = config['build_date'].replace(' ','_')
-        config_file = path.joinpath(f"config_{model_name}_{timestamp}.yaml")
-    else:
-        config_file = path.joinpath(f"config_{model_name}.yaml")
-    config.yaml_set_start_comment(f'Config file for {model_name}')
-    with open(config_file, 'w') as file:
-        yaml.dump(config, file, Dumper=yaml.RoundTripDumper)
+    def save_copy(self, output_dir, append_timestamp=False):
+        model_name = self.hparams['model_name']
+        timestamp = f"_{self.hparams['build_date']}" if append_timestamp else ""
+        save_config_file_name = f"config_{model_name}{timestamp}"
+        config_file = output_dir.joinpath(save_config_file_name + ".yaml")
+        self.hparams.yaml_set_start_comment(f'Config file for {model_name}')
+        with open(config_file, 'w') as file:
+            yaml.dump(self.hparams, file, Dumper=yaml.RoundTripDumper)
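
For context, a minimal sketch of how the refactored UserConfig class might be driven from a training script. The argparse flag names and the rhtorch.config_utils import path are assumptions for illustration; the class itself only needs an object exposing config and kfold attributes.

import argparse
from pathlib import Path

from rhtorch.config_utils import UserConfig  # import path assumed for illustration

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", default="config.yaml")  # flag names assumed
parser.add_argument("-k", "--kfold", type=int, default=0)
args = parser.parse_args()

# rootdir is the project directory; is_path() resolves relative paths against it
user_config = UserConfig(Path.cwd(), arguments=args)
print(user_config.hparams["model_name"])

# save the merged config, appending the build date to the filename
user_config.save_copy(Path.cwd(), append_timestamp=True)

Note that loading the user file with yaml.RoundTripLoader returns a ruamel CommentedMap, which is what makes yaml_set_start_comment available when save_copy writes the merged config back out.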