Skip to content

Commit dc21dfe

Browse files
wanghan-iapcmHan Wang
andauthored
Skip train when the init model is provided (#116)
At iteration 0, it is not necessary to train the model if a init_model is provided. Co-authored-by: Han Wang <[email protected]>
1 parent 21bad63 commit dc21dfe

File tree

4 files changed

+48
-20
lines changed

4 files changed

+48
-20
lines changed

dpgen2/entrypoint/args.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ def dp_train_args():
1919
doc_numb_models = "Number of models trained for evaluating the model deviation"
2020
doc_config = "Configuration of training"
2121
doc_template_script = "File names of the template training script. It can be a `List[Dict]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `Dict`, the models share the same template training script. "
22+
doc_init_models_paths = "the paths to initial models"
2223

2324
return [
2425
Argument("config", dict, RunDPTrain.training_args(), optional=True, default=RunDPTrain.normalize_config({}), doc=doc_numb_models),
2526
Argument("numb_models", int, optional=True, default=4, doc=doc_numb_models),
2627
Argument("template_script", [list,str], optional=False, doc=doc_template_script),
28+
Argument("init_models_paths", list, optional=True, doc=doc_init_models_paths, alias=['training_iter0_model_path']),
2729
]
2830

2931
def variant_train():

dpgen2/entrypoint/submit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def workflow_concurrent_learning(
303303
collect_data_config = normalize_step_dict(config.get('collect_data_config', default_config)) if old_style else config['step_configs']['collect_data_config']
304304
cl_step_config = normalize_step_dict(config.get('cl_step_config', default_config)) if old_style else config['step_configs']['cl_step_config']
305305
upload_python_packages = config.get('upload_python_packages', None)
306-
init_models_paths = config.get('training_iter0_model_path', None) if old_style else config['train'].get('training_iter0_model_path', None)
306+
init_models_paths = config.get('training_iter0_model_path', None) if old_style else config['train'].get('init_models_paths', None)
307307
if upload_python_packages is not None and isinstance(upload_python_packages, str):
308308
upload_python_packages = [upload_python_packages]
309309
if upload_python_packages is not None:

dpgen2/op/run_dp_train.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os, json, dpdata, glob
1+
import os, json, dpdata, glob, shutil
22
from pathlib import Path
33
from dpgen2.utils.run_command import run_command
44
from dpgen2.utils.chdir import set_directory
@@ -125,6 +125,14 @@ def execute(
125125
train_dict = RunDPTrain.write_other_to_input_script(
126126
train_dict, config, do_init_model, major_version)
127127

128+
if RunDPTrain.skip_training(work_dir, train_dict, init_model, iter_data):
129+
return OPIO({
130+
"script" : work_dir / train_script_name,
131+
"model" : work_dir / "frozen_model.pb",
132+
"lcurve" : work_dir / "lcurve.out",
133+
"log" : work_dir / "train.log",
134+
})
135+
128136
with set_directory(work_dir):
129137
# open log
130138
fplog = open('train.log', 'w')
@@ -224,6 +232,30 @@ def write_other_to_input_script(
224232
raise RuntimeError('unsupported DeePMD-kit major version', major_version)
225233
return odict
226234

235+
@staticmethod
236+
def skip_training(
237+
work_dir,
238+
train_dict,
239+
init_model,
240+
iter_data,
241+
):
242+
# we have init model and no iter data, skip training
243+
if (init_model is not None) and \
244+
(iter_data is None or len(iter_data) == 0) :
245+
with set_directory(work_dir):
246+
with open(train_script_name, 'w') as fp:
247+
json.dump(train_dict, fp, indent=4)
248+
Path('train.log').write_text(
249+
f'We have init model {init_model} and '
250+
f'no iteration training data. '
251+
f'The training is skipped.\n'
252+
)
253+
Path('lcurve.out').touch()
254+
shutil.copy(init_model, 'frozen_model.pb')
255+
return True
256+
else:
257+
return False
258+
227259
@staticmethod
228260
def decide_init_model(
229261
config,

tests/op/test_run_dp_train.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def setUp(self):
577577

578578

579579
def tearDown(self):
580-
for ii in ['init', self.task_path, self.task_name, 'foo' ]:
580+
for ii in ['init', self.task_path, self.task_name, 'foo']:
581581
if Path(ii).exists():
582582
shutil.rmtree(str(ii))
583583

@@ -592,10 +592,7 @@ def test_update_input_dict_v2_empty_list(self):
592592
self.assertDictEqual(odict, self.expected_odict_v2)
593593

594594

595-
@patch('dpgen2.op.run_dp_train.run_command')
596-
def test_exec_v2_empty_list(self, mocked_run):
597-
mocked_run.side_effect = [ (0, 'foo\n', ''), (0, 'bar\n', '') ]
598-
595+
def test_exec_v2_empty_list(self):
599596
config = self.config.copy()
600597
config['init_model_policy'] = 'no'
601598

@@ -606,6 +603,9 @@ def test_exec_v2_empty_list(self, mocked_run):
606603
task_name = self.task_name
607604
work_dir = Path(task_name)
608605

606+
self.init_model = self.init_model.absolute()
607+
self.init_model.write_text('this is init model')
608+
609609
ptrain = RunDPTrain()
610610
out = ptrain.execute(
611611
OPIO({
@@ -621,26 +621,20 @@ def test_exec_v2_empty_list(self, mocked_run):
621621
self.assertEqual(out['model'], work_dir/'frozen_model.pb')
622622
self.assertEqual(out['lcurve'], work_dir/'lcurve.out')
623623
self.assertEqual(out['log'], work_dir/'train.log')
624-
625-
calls = [
626-
call(['dp', 'train', train_script_name]),
627-
call(['dp', 'freeze', '-o', 'frozen_model.pb']),
628-
]
629-
mocked_run.assert_has_calls(calls)
630-
624+
631625
self.assertTrue(work_dir.is_dir())
632626
self.assertTrue(out['log'].is_file())
633627
self.assertEqual(out['log'].read_text(),
634-
'#=================== train std out ===================\n'
635-
'foo\n'
636-
'#=================== train std err ===================\n'
637-
'#=================== freeze std out ===================\n'
638-
'bar\n'
639-
'#=================== freeze std err ===================\n'
628+
f'We have init model {self.init_model} and '
629+
f'no iteration training data. '
630+
f'The training is skipped.\n'
640631
)
641632
with open(out['script']) as fp:
642633
jdata = json.load(fp)
643634
self.assertDictEqual(jdata, self.expected_odict_v2)
635+
self.assertEqual(Path(out['model']).read_text(), "this is init model")
636+
637+
os.remove(self.init_model)
644638

645639

646640
@patch('dpgen2.op.run_dp_train.run_command')

0 commit comments

Comments
 (0)