# imports for the file header elided above this diff fragment (it starts at
# file line 5); these three are assumptions implied by the code below:
# pathlib for Path, ruamel.yaml for the round-trip loader, torch for the GPU count
from pathlib import Path
from ruamel import yaml
import torch

from rhtorch.version import __version__
import socket


class UserConfig:
    def __init__(self, rootdir, arguments=None):
        self.rootdir = rootdir
        self.config_file = self.is_path(arguments.config)
        self.args = arguments

        # load default configs
        default_config_file = Path(__file__).parent.joinpath('default.config')
        with open(default_config_file) as dcf:
            self.default_params = yaml.load(dcf, Loader=yaml.Loader)

        # load user config file
        with open(self.config_file) as cf:
            self.hparams = yaml.load(cf, Loader=yaml.RoundTripLoader)

        # merge the two dicts
        self.merge_dicts()

        # sanity check on data_folder provided by user
        self.data_path = self.is_path(self.hparams['data_folder'])

        # make model name
        self.fill_additional_info()
        self.create_model_name()

    def is_path(self, path):
        # check for path - assuming absolute path was given
        filepath = Path(path)
        if not filepath.exists():
            # assuming path was given relative to rootdir
            filepath = self.rootdir.joinpath(filepath)
            if not filepath.exists():
                raise FileNotFoundError(
                    f"{path} not found. Define it relative to the project "
                    "directory or as an absolute path, in the config file "
                    "or when passing arguments.")
        return filepath
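
    # For illustration (paths invented): is_path('/data/project/train') returns
    # the Path as-is when it exists, is_path('train') falls back to
    # rootdir/'train', and a path that resolves to neither raises FileNotFoundError.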

    def merge_dicts(self):
        """ adds to the user params dictionary any missing key from the default params """
        for key, value in self.default_params.items():
            # copy from default if value is not None/0/False and key not already in user config
            if value and key not in self.hparams:
                self.hparams[key] = value
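
    # Hypothetical illustration of merge_dicts (keys and values invented):
    #   default_params: {'batch_size': 16, 'precision': 32, 'augment': False}
    #   user hparams:   {'batch_size': 4}
    #   after merging:  {'batch_size': 4, 'precision': 32}
    # 'augment' is not copied because False is falsy, so that default is skipped.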

    def fill_additional_info(self):
        # additional info from args and miscellaneous to save in config
        # ... (a few lines elided between the diff hunks)
        self.hparams['config_file'] = str(self.config_file)
        self.hparams['k_fold'] = self.args.kfold
        self.hparams['GPUs'] = torch.cuda.device_count()
        self.hparams['global_batch_size'] = self.hparams['batch_size'] * \
            self.hparams['GPUs']
        self.hparams['rhtorch_version'] = __version__
        self.hparams['hostname'] = socket.gethostname()

    def create_model_name(self):
        data_shape = 'x'.join(map(str, self.hparams['data_shape']))
        base_name = f"{self.hparams['module']}_{self.hparams['version_name']}_{self.hparams['data_generator']}"
        dat_name = f"bz{self.hparams['batch_size']}_{data_shape}"
        self.hparams['model_name'] = f"{base_name}_{dat_name}_k{self.args.kfold}_e{self.hparams['epoch']}"
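
    # Example of the resulting name, with invented hparams (module='LitAutoEncoder',
    # version_name='v1', data_generator='DataModule', batch_size=4,
    # data_shape=[16, 128, 128], kfold=0, epoch=100):
    #   LitAutoEncoder_v1_DataModule_bz4_16x128x128_k0_e100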

    def save_copy(self, output_dir, append_timestamp=False):
        model_name = self.hparams['model_name']
        timestamp = f"_{self.hparams['build_date']}" if append_timestamp else ""
        save_config_file_name = f"config_{model_name}{timestamp}"
        config_file = output_dir.joinpath(save_config_file_name + ".yaml")
        self.hparams.yaml_set_start_comment(f'Config file for {model_name}')
        with open(config_file, 'w') as file:
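            # assumed body (the diff is truncated at the line above): a
            # round-trip dump preserves the comments kept by RoundTripLoader
            yaml.dump(self.hparams, file, Dumper=yaml.RoundTripDumper)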
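

# Minimal usage sketch. The flag names '--config' and '--kfold' are assumptions
# inferred from the attributes read above (arguments.config, arguments.kfold);
# the project's real CLI may differ.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', default='config.yaml')
    parser.add_argument('-k', '--kfold', type=int, default=0)
    args = parser.parse_args()

    # rootdir must be a pathlib.Path: is_path() calls .joinpath() on it
    project_dir = Path.cwd()
    user_config = UserConfig(project_dir, arguments=args)
    print(user_config.hparams['model_name'])
    user_config.save_copy(project_dir)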