@@ -3,6 +3,9 @@
 import logging
 import argparse
 import warnings
+import shutil
+import difflib
+import yaml

 # This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch.
 # Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
@@ -29,6 +32,8 @@
     root="./",
     tensorboard=False,
     wandb=False,
+    wandb_watch=False,
+    wandb_watch_kwargs={},
     model_builders=[
         "SimpleIrrepsConfig",
         "EnergyModel",
@@ -46,7 +51,7 @@
     equivariance_test=False,
     grad_anomaly_mode=False,
     gpu_oom_offload=False,
-    append=False,
+    append=True,
     warn_unused=False,
     _jit_bailout_depth=2,  # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
     # Quote from eelison in PyTorch slack:
@@ -68,32 +73,61 @@


 def main(args=None, running_as_script: bool = True):
-    config = parse_command_line(args)
+    config, path_to_config, override_options = parse_command_line(args)

     if running_as_script:
         set_up_script_logger(config.get("log", None), config.verbose)

-    found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth")
+    train_dir = f"{config.root}/{config.run_name}"
+    found_restart_file = exists(f"{train_dir}/trainer.pth")
     if found_restart_file and not config.append:
         raise RuntimeError(
-            f"Training instance exists at {config.root}/{config.run_name}; "
+            f"Training instance exists at {train_dir}; "
             "either set append to True or use a different root or runname"
         )
-    elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"):
+    elif not found_restart_file and isdir(train_dir):
         # output directory exists but no ``trainer.pth`` file, suggesting previous run crash during
         # first training epoch (usually due to memory):
         warnings.warn(
-            f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model "
+            f"Previous run folder at {train_dir} exists, but a saved model "
             f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will "
             f"be started."
         )
-        rmtree(f"{config.root}/{config.run_name}")
+        rmtree(train_dir)

-    # for fresh new train
-    if not found_restart_file:
+    if not found_restart_file:  # fresh start
+        # update config with override parameters for setting up train-dir
+        config.update(override_options)
         trainer = fresh_start(config)
-    else:
-        trainer = restart(config)
+        # copy original config to training directory
+        shutil.copyfile(path_to_config, f"{train_dir}/original_config.yaml")
+    else:  # restart
+        # perform string matching for original config and restart config
+        # throw error if they are different
+        with (
+            open(f"{train_dir}/original_config.yaml") as orig_f,
+            open(path_to_config) as current_f,
+        ):
+            diffs = [
+                x
+                for x in difflib.Differ().compare(
+                    orig_f.readlines(), current_f.readlines()
+                )
+                if x[0] in ("+", "-")
+            ]
+        if diffs:
+            raise RuntimeError(
+                f"Config {path_to_config} used for restart differs from original config for training run in {train_dir}.\n"
+                + "The following differences were found:\n\n"
+                + "".join(diffs)
+                + "\n"
+                + "If you intend to override the original config parameters, use the --override flag. For example, use\n"
+                + f'`nequip-train {path_to_config} --override "max_epochs: 42"`\n'
+                + 'on the command line to override the config parameter "max_epochs"\n'
+                + "BE WARNED that use of the --override flag is not protected by consistency checks performed by NequIP."
+            )
+        else:
+            trainer = restart(config, override_options)

     # Train
     trainer.save()
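Illustrative sketch of the consistency check above (not part of the patch; file contents are placeholders): `difflib.Differ().compare` yields every line prefixed with "  " (unchanged), "- "/"+ " (present in only one input), or "? " (intraline hints), so filtering for "+"/"-" keeps exactly the lines where the restart config disagrees with the saved `original_config.yaml`.

    import difflib

    # hypothetical config contents, stored vs. supplied at restart
    original_cfg = ["root: results\n", "run_name: example\n", "max_epochs: 100\n"]
    restart_cfg = ["root: results\n", "run_name: example\n", "max_epochs: 200\n"]

    # keep only the "+"/"-" lines; anything surviving means the configs differ
    diffs = [
        x
        for x in difflib.Differ().compare(original_cfg, restart_cfg)
        if x[0] in ("+", "-")
    ]
    print("".join(diffs))  # prints the "- max_epochs: 100" / "+ max_epochs: 200" pair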
@@ -157,6 +191,12 @@ def parse_command_line(args=None):
         help="Warn instead of error when the config contains unused keys",
         action="store_true",
     )
+    parser.add_argument(
+        "--override",
+        help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. Unless you know why you need to, do not use this option.",
+        type=str,
+        default=None,
+    )
     args = parser.parse_args(args=args)

     config = Config.from_file(args.config, defaults=default_config)
@@ -169,10 +209,26 @@ def parse_command_line(args=None):
     ):
         config[flag] = getattr(args, flag) or config[flag]

-    return config
+    # Set override options before _set_global_options so that things like allow_tf32 are correctly handled
+    if args.override is not None:
+        override_options = yaml.load(args.override, Loader=yaml.Loader)
+        assert isinstance(
+            override_options, dict
+        ), "--override's YAML string must define a dictionary of top-level options"
+        overridden_keys = set(config.keys()).intersection(override_options.keys())
+        set_keys = set(override_options.keys()) - set(overridden_keys)
+        logging.info(
+            f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}"
+        )
+        del overridden_keys, set_keys
+    else:
+        override_options = {}
+
+    return config, args.config, override_options


 def fresh_start(config):
+
     # we use add_to_config cause it's a fresh start and need to record it
     check_code_version(config, add_to_config=True)
     _set_global_options(config)
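A minimal sketch, assuming nothing beyond PyYAML and a plain dict (not the repository's Config class), of what the new --override path does: the command-line string is parsed with yaml.load into a top-level mapping and merged over the stored config, so overridden keys win and new keys are added.

    import yaml

    # stored training config (illustrative values only)
    config = {"root": "results", "run_name": "example", "max_epochs": 100}

    # e.g. invoked as: nequip-train config.yaml --override "max_epochs: 42"
    override_string = "max_epochs: 42"
    override_options = yaml.load(override_string, Loader=yaml.Loader)
    assert isinstance(override_options, dict), "--override must be a YAML mapping"

    config.update(override_options)  # existing keys are overridden, new keys are added
    print(config["max_epochs"])  # 42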
@@ -267,7 +323,7 @@ def _unused_check():
     return trainer


-def restart(config):
+def restart(config, override_options):
     # load the dictionary
     restart_file = f"{config.root}/{config.run_name}/trainer.pth"
     dictionary = load_file(
@@ -276,20 +332,6 @@ def restart(config):
         enforced_format="torch",
     )

-    # compare dictionary to config and update stop condition related arguments
-    for k in config.keys():
-        if config[k] != dictionary.get(k, ""):
-            if k == "max_epochs":
-                dictionary[k] = config[k]
-                logging.info(f'Update "{k}" to {dictionary[k]}')
-            elif k.startswith("early_stop"):
-                dictionary[k] = config[k]
-                logging.info(f'Update "{k}" to {dictionary[k]}')
-            elif isinstance(config[k], type(dictionary.get(k, ""))):
-                raise ValueError(
-                    f'Key "{k}" is different in config and the result trainer.pth file. Please double check'
-                )
-
     # note, "trainer.pth"/dictionary also store code versions,
     # which will not be stored in config and thus not checked here
     check_code_version(config)
@@ -299,6 +341,10 @@ def restart(config):

     config = Config(dictionary, exclude_keys=["state_dict", "progress"])

+    # override configs loaded from save
+    dictionary.update(override_options)
+    config.update(override_options)
+
     # dtype, etc.
     _set_global_options(config)
