Skip to content

Commit a164f78

Browse files
committed
fix #8
1 parent 4e394b0 commit a164f78

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

espfit/app/train.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def report_loss(self, epoch, loss_dict):
349349

350350
log_file_path = os.path.join(self.output_directory_path, 'reporter.log')
351351
df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T
352-
df_new = df_new.mul(100) # Multiple each loss component by 100
352+
df_new = df_new.mul(100) # Multiple each loss component by 100. Is this large enough?
353353
df_new.insert(0, 'epoch', epoch)
354354

355355
if os.path.exists(log_file_path):
@@ -536,7 +536,6 @@ def train_sampler(self, sampler_patience=800, neff_threshold=0.2, sampler_weight
536536
sampler = SamplerReweight.samplers[sampler_index]
537537
loss += sampler_loss * sampler_weight
538538
loss_dict[f'{sampler.target_name}'] = sampler_loss.item()
539-
540539
loss.backward()
541540
loss_dict['neff'] = neff_min
542541

@@ -626,7 +625,7 @@ def _save_local_model(self, epoch, net_copy):
626625
_logger.info(f'Save ckpt{epoch}.pt as temporary espaloma model (net.pt)')
627626
self._save_checkpoint(epoch)
628627
local_model = os.path.join(self.output_directory_path, f"ckpt{epoch}.pt")
629-
self.save_model(net=net_copy, best_model=local_model, model_name=f"net.pt", output_directory_path=self.output_directory_path)
628+
self.save_model(net=net_copy, checkpoint_file=local_model, output_model=f"net.pt", output_directory_path=self.output_directory_path)
630629

631630

632631
def _setup_local_samplers(self, epoch, net_copy, debug):

0 commit comments

Comments
 (0)