@@ -349,9 +349,9 @@ def on_train_epoch_start(self):
        # Reset training torchmetrics
        self.task._reset_torchmetrics("train")

-    def training_epoch_end(self, outputs):
+    def on_train_epoch_end(self):
        # Log training torchmetrics
-        super().training_epoch_end(outputs)
+        super().on_train_epoch_end()
        self.log_dict(
            {f"train/{k}": v for k, v in self.task.get_torchmetrics("train").items()},
            on_step=False,
@@ -367,9 +367,9 @@ def on_validation_epoch_start(self):
        for name in self.val_loader_names:
            self.task._reset_torchmetrics(name)

-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
        # Log all validation torchmetrics
-        super().validation_epoch_end(outputs)
+        super().on_validation_epoch_end()
        for name in self.val_loader_names:
            self.log_dict(
                {f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
@@ -386,9 +386,9 @@ def on_test_epoch_start(self):
        for name in self.test_loader_names:
            self.task._reset_torchmetrics(name)

-    def test_epoch_end(self, outputs):
+    def on_test_epoch_end(self):
        # Log all test torchmetrics
-        super().test_epoch_end(outputs)
+        super().on_test_epoch_end()
        for name in self.test_loader_names:
            self.log_dict(
                {f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
@@ -411,7 +411,7 @@ def training_step(self, batch, batch_idx):
            loss_epoch,
            on_step=True,
            on_epoch=False,
-            prog_bar=False,
+            prog_bar=True,
            add_dataloader_idx=False,
            sync_dist=True,
        )
@@ -666,6 +666,23 @@ def create_trainer(config):
        # Stage params are resolution and epochs, pretty print
        print(f"\tStage {i}: {e['resolution']} @ {e['epochs']} epochs")

+    # Additional ModelCheckpoint callback for preemption
+    if config.tolerance.id is not None:
+        pass
+        # if 'model_checkpoint' in config.callbacks.keys():
+        #     callback_args = config.callbacks['model_checkpoint']
+        #     callback_args._name_ = 'model_checkpoint'  # For the registry
+        #     # Save last two checkpoints to be extra fault tolerant
+        #     callback_args.save_top_k = 2
+        #     callback_args.monitor = 'trainer/epoch'
+        #     callback_args.mode = 'max'
+        #     callback_args.save_last = False
+        #     callback_args.filename = 'last'
+        #     # callback_args.save_on_train_epoch_end = True  # this is False for the other checkpoint callback
+        #     ckpt_callback = utils.instantiate(registry.callbacks, callback_args)
+        #     # ckpt_callback.CHECKPOINT_NAME_LAST = 'last_'  # now we have two last checkpoints, last.ckpt and last_.ckpt
+        #     callbacks.append(ckpt_callback)
+
    trainer = pl.Trainer(
        logger=logger,
        callbacks=callbacks,
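The commented-out block above sketches a second `ModelCheckpoint` wired through this repo's callback registry. For reference, a hedged sketch of the equivalent in plain PyTorch Lightning (assumes a `trainer/epoch` metric is logged, as the `monitor` setting implies; the `dirpath` is a placeholder, not the repo's actual wiring):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Fault-tolerance checkpoint: track the newest epochs and keep two copies,
# so a write interrupted by preemption never leaves us without a valid file.
preemption_ckpt = ModelCheckpoint(
    dirpath="checkpoints",        # placeholder output directory
    filename="last",
    save_top_k=2,
    monitor="trainer/epoch",      # assumes this metric is logged each epoch
    mode="max",
    save_last=False,
)
# callbacks.append(preemption_ckpt)
```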
@@ -681,6 +698,31 @@ def train(config):
    trainer = create_trainer(config)
    model = SequenceLightningModule(config)

+    # Load pretrained_model if specified
+    if config.train.get("pretrained_model_path", None) is not None:
+        # PTL style. Note, method returns a new model object, and need to pass config.
+        model = SequenceLightningModule.load_from_checkpoint(
+            config.train.pretrained_model_path,
+            config=config,
+            strict=config.train.pretrained_model_strict_load,
+        )
+        print("Loaded pretrained model from", config.train.pretrained_model_path)
+
+        # Added by KS for pre-training
+        # [22-07-21 AG] refactored, untested
+        if config.train.get("ignore_pretrained_layers", False):
+            pretrained_dict = model.state_dict()  # `model` now holds the pretrained weights
+            scratch_dict = SequenceLightningModule(config).state_dict()  # freshly initialized weights
+            for k, v in scratch_dict.items():
+                for ignore_layer in config.train.ignore_pretrained_layers:
+                    if ignore_layer in k:
+                        pretrained_dict[k] = v
+            model.load_state_dict(pretrained_dict)
+        if config.train.get("pretrained_freeze_encoder", False):
+            for name, param in model.named_parameters():
+                if "decoder" not in name: param.requires_grad = False
+
+
    # Run initial validation epoch (useful for debugging, finetuning)
    if config.train.validate_at_start:
        print("Running validation before training")
@@ -693,6 +735,82 @@ def train(config):
    if config.train.test:
        trainer.test(model)

+
+
+def preemption_setup(config):
+    if config.tolerance.id is None:
+        return config
+
+    # Create path ./logdir/id/ to store information for resumption
+    resume_dir = os.path.join(get_original_cwd(), config.tolerance.logdir, str(config.tolerance.id))
+
+    if os.path.exists(resume_dir):
+        print(f"Resuming from {resume_dir}")
+
+        # Load path to the last checkpoint
+        with open(os.path.join(resume_dir, "hydra.txt"), "r") as f:
+            hydra_paths = list(f.readlines())
+
+        # Look at the previous runs in reverse order
+        checkpoint_path = None
+        for hydra_path in reversed(hydra_paths):
+            hydra_path = hydra_path.rstrip('\n')
+
+            # Get the paths to the last.ckpt and last_.ckpt files
+            last_path = os.path.join(hydra_path, "checkpoints", "last.ckpt")
+            # last__path = os.path.join(hydra_path, "checkpoints", "last_.ckpt")
+            # last_exists, last__exists = os.path.exists(last_path), os.path.exists(last__path)
+
+            # if not last_exists or not last__exists:
+            #     # This run doesn't have both checkpoints, so skip it
+            #     print(f"\tSkipping {hydra_path}, not suitable for resuming (last_exists = {last_exists}, last__exists = {last__exists})")
+            #     continue
+
+            # # Read timestamp when checkpoints were modified
+            # # We want to load the _earlier_ checkpoint, since that is guaranteed to be uncorrupted
+            # last_timestamp = os.path.getmtime(last_path)
+            # last__timestamp = os.path.getmtime(last__path)
+            # print("\t\tlast_timestamp =", last_timestamp)
+            # print("\t\tlast__timestamp =", last__timestamp)
+
+            # if last_timestamp < last__timestamp:
+            #     checkpoint_path = last_path
+            # else:
+            #     checkpoint_path = last__path
+            # checkpoint_path = last_path
+            # config.train.ckpt = checkpoint_path
+
+            if os.path.exists(last_path):
+                print("\tFound checkpoint at", last_path)
+                config.train.ckpt = last_path
+                # HACK TODO
+                config.train.pretrained_model_path = None
+                config.train.pretrained_model_state_hook._name_ = None
+                # config.train.pretrained_model_reinit_hook._name_ = None
+                break
+
+        # If we didn't find a checkpoint
+        if checkpoint_path is None:
+            print("\tNo suitable checkpoint found, starting from scratch")
+
+        # Set wandb run id to resume
+        if os.path.exists(os.path.join(hydra_path, 'wandb')):
+            run_info = [e for e in os.listdir(os.path.join(hydra_path, 'wandb')) if e.startswith('run-')][0]
+            run_id = run_info.split('-')[-1]
+            try:
+                config.wandb.id = run_id
+            except AttributeError:
+                pass
+
+    os.makedirs(resume_dir, exist_ok=True)
+
+    # Store path to Hydra output folder
+    with open(os.path.join(resume_dir, 'hydra.txt'), 'a') as f:
+        f.write(os.getcwd() + '\n')
+
+    return config
+
+
@hydra.main(config_path="configs", config_name="config.yaml")
def main(config: OmegaConf):

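`preemption_setup` only records paths and mutates the config: it points `config.train.ckpt` at the newest `last.ckpt` it can find and reuses the saved wandb run id. A hedged sketch of how those two values are typically consumed downstream (plain PyTorch Lightning and wandb; the path and id are placeholders, not this repo's config plumbing):

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

ckpt_path = "outputs/2022-07-21/12-00-00/checkpoints/last.ckpt"  # placeholder path
logger = WandbLogger(id="abc123", resume="allow")  # continue the previous wandb run

trainer = pl.Trainer(logger=logger)
# Restores model weights, optimizer state, and the epoch/step counters:
# trainer.fit(model, ckpt_path=ckpt_path)
```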
@@ -705,6 +823,8 @@ def main(config: OmegaConf):
    # Pretty print config using Rich library
    utils.train.print_config(config, resolve=True)

+    config = preemption_setup(config)
+
    train(config)
