
Commit 976c9e9

[Infra] Update to pytorch-lightning 2.0
1 parent fffbeee

10 files changed (+135 additions, -18 deletions)

configs/experiment/s4nd/convnext/convnext_timm_tiny_imagenet.yaml
Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ trainer:
   max_epochs: 310
   precision: 16
   devices: 1
-  replace_sampler_ddp: ${eval:"${dataset.num_aug_repeats} == 0"} # only True if using RepeatAug
+  use_distributed_sampler: ${eval:"${dataset.num_aug_repeats} == 0"} # only True if using RepeatAug
   accumulate_grad_batches: ${eval:${train.global_batch_size} // ${.devices} // ${loader.batch_size}}
 
 train:
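
Note: pytorch-lightning 2.0 renamed the Trainer flag replace_sampler_ddp to use_distributed_sampler with unchanged semantics; when it resolves to False, Lightning leaves the DataLoader's own sampler (here, the RepeatAug sampler) in place instead of wrapping it in a DistributedSampler. A minimal sketch of the equivalent Trainer call, assuming the trainer section of this YAML is ultimately passed through to the constructor:

import pytorch_lightning as pl

# Sketch (assumed wiring): in PL 2.x the former replace_sampler_ddp flag
# is called use_distributed_sampler. True lets Lightning install a
# DistributedSampler; False keeps the DataLoader's own sampler as-is.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    use_distributed_sampler=False,  # illustrative value
)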

configs/experiment/s4nd/convnext/convnext_timm_tiny_inflate3d_s4nd_hmdb.yaml
Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ train:
   remove_test_loader_in_eval: true # null means we do use test loader
   global_batch_size: ${loader.batch_size} # effective batch size (handled with multiple gpus, and accumulate_grad_batches)
   pretrained_model_strict_load: False
-  replace_sampler_ddp: False # ${eval:"${trainer.devices} > 1"}
+  use_distributed_sampler: False # ${eval:"${trainer.devices} > 1"}
   pretrained_model_state_hook:
     _name_: convnext_timm_tiny_s4nd_2d_to_3d
 
configs/experiment/s4nd/convnext/convnext_timm_tiny_s4nd_imagenet.yaml
Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ trainer:
   max_epochs: 310
   precision: 16
   devices: 8
-  replace_sampler_ddp: ${eval:"${dataset.num_aug_repeats} == 0"} # only true if using RepeatAug
+  use_distributed_sampler: ${eval:"${dataset.num_aug_repeats} == 0"} # only true if using RepeatAug
   accumulate_grad_batches: ${eval:${train.global_batch_size} // ${.devices} // ${loader.batch_size}}
 
 train:

configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml
Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ trainer:
   max_epochs: 310
   precision: 16
   devices: 8
-  replace_sampler_ddp: ${eval:"${dataset.num_aug_repeats} == 0"} # only true if using RepeatAug
+  use_distributed_sampler: ${eval:"${dataset.num_aug_repeats} == 0"} # only true if using RepeatAug
 
 train:
   seed: 1112

configs/trainer/debug.yaml
Lines changed: 0 additions & 1 deletion

@@ -17,5 +17,4 @@ overfit_batches: 0
 limit_train_batches: 0.1
 limit_val_batches: 0.1
 limit_test_batches: 0.1
-track_grad_norm: -1
 terminate_on_nan: False

configs/trainer/default.yaml
Lines changed: 1 addition & 2 deletions

@@ -1,7 +1,7 @@
 # See Docs for full flags and descriptions
 # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-class-api
 accelerator: gpu
-strategy: null
+strategy: auto
 devices: 1
 accumulate_grad_batches: 1 # Gradient accumulation every n batches
 max_epochs: 200
@@ -11,4 +11,3 @@ log_every_n_steps: 10
 limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
 limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
 enable_model_summary: false # Can turn on if RichModelSummary is disabled
-track_grad_norm: -1 # Set to 2 to track norms of gradients
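
In PL 2.0, strategy accepts "auto" (the new default) instead of null, and track_grad_norm is no longer a Trainer argument, which is why it is deleted here and in the other trainer configs. A hedged sketch of how a flat trainer config like this is typically splatted into the 2.x Trainer (hypothetical wiring, not a quote from this repo's create_trainer):

import pytorch_lightning as pl
from omegaconf import OmegaConf

# Minimal sketch: build a Trainer from a flat trainer config. Keys that
# survived the 2.0 migration (accelerator, strategy, devices, ...) pass
# straight through; removed keys such as track_grad_norm would raise a
# TypeError, so they have to be dropped from the YAML.
cfg = OmegaConf.create({"accelerator": "cpu", "strategy": "auto", "devices": 1, "max_epochs": 1})
trainer = pl.Trainer(**OmegaConf.to_container(cfg, resolve=True))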

configs/trainer/lm.yaml
Lines changed: 1 addition & 2 deletions

@@ -7,11 +7,10 @@ gradient_clip_val: null # Gradient clipping
 log_every_n_steps: 10
 precision: 16
 enable_model_summary: false # Can turn on if RichModelSummary is disabled
-track_grad_norm: -1 # Set to 2 to track norms of gradients
 limit_train_batches: 1.0
 limit_val_batches: 1.0
 # We use the dataloader from Transformer-XL to ensure adjacent minibatches
 # are from text that are next to each other.
 # So that dataloader has to deal with DDP, and we don't want PL to handle
 # that.
-replace_sampler_ddp: False
+use_distributed_sampler: False

requirements.txt
Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ rich
 torchtext
 lit # Getting installation errors with torch 2.0 if this isn't installed
 # torchvision
-pytorch-lightning==1.9.3
+pytorch-lightning==2.0.4
 hydra-core
 omegaconf
 wandb

src/callbacks/norms.py
Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ def on_after_training_step(self, batch, batch_idx, trainer: pl.Trainer, pl_modul
 
     def on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         # example to inspect gradient information in tensorboard
-        if OmegaConf.select(trainer.hparams, 'trainer.track_grad_norms'): # TODO dot notation should work with omegaconf?
+        if OmegaConf.select(trainer.hparams, 'train.track_grad_norms'): # TODO dot notation should work with omegaconf?
             norms = {}
             for name, p in pl_module.named_parameters():
                 if p.grad is None:
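
Because track_grad_norm no longer exists on the 2.0 Trainer, gradient norms now have to be computed and logged by user code such as the callback above. A minimal sketch of the pattern the 2.x migration notes suggest (standalone example, not this repo's Norms callback), assuming the grad_norm utility exported from pytorch_lightning.utilities:

import pytorch_lightning as pl
from pytorch_lightning.utilities import grad_norm

class GradNormLogger(pl.Callback):
    # Sketch only: log per-layer 2-norms of the gradients right before each
    # optimizer step, roughly what Trainer(track_grad_norm=2) used to do.
    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        norms = grad_norm(pl_module, norm_type=2)  # dict of layer-wise norms
        pl_module.log_dict(norms)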

train.py
Lines changed: 127 additions & 7 deletions

@@ -349,9 +349,9 @@ def on_train_epoch_start(self):
         # Reset training torchmetrics
         self.task._reset_torchmetrics("train")
 
-    def training_epoch_end(self, outputs):
+    def on_train_epoch_end(self):
         # Log training torchmetrics
-        super().training_epoch_end(outputs)
+        super().on_train_epoch_end()
         self.log_dict(
             {f"train/{k}": v for k, v in self.task.get_torchmetrics("train").items()},
             on_step=False,
@@ -367,9 +367,9 @@ def on_validation_epoch_start(self):
         for name in self.val_loader_names:
             self.task._reset_torchmetrics(name)
 
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
         # Log all validation torchmetrics
-        super().validation_epoch_end(outputs)
+        super().on_validation_epoch_end()
         for name in self.val_loader_names:
             self.log_dict(
                 {f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
@@ -386,9 +386,9 @@ def on_test_epoch_start(self):
         for name in self.test_loader_names:
            self.task._reset_torchmetrics(name)
 
-    def test_epoch_end(self, outputs):
+    def on_test_epoch_end(self):
         # Log all test torchmetrics
-        super().test_epoch_end(outputs)
+        super().on_test_epoch_end()
         for name in self.test_loader_names:
             self.log_dict(
                 {f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
@@ -411,7 +411,7 @@ def training_step(self, batch, batch_idx):
             loss_epoch,
             on_step=True,
             on_epoch=False,
-            prog_bar=False,
+            prog_bar=True,
             add_dataloader_idx=False,
             sync_dist=True,
         )
@@ -666,6 +666,23 @@ def create_trainer(config):
         # Stage params are resolution and epochs, pretty print
         print(f"\tStage {i}: {e['resolution']} @ {e['epochs']} epochs")
 
+    # Additional ModelCheckpoint callback for preemption
+    if config.tolerance.id is not None:
+        pass
+        # if 'model_checkpoint' in config.callbacks.keys():
+        #     callback_args = config.callbacks['model_checkpoint']
+        #     callback_args._name_ = 'model_checkpoint' # For the registry
+        #     # Save last two checkpoints to be extra fault tolerant
+        #     callback_args.save_top_k = 2
+        #     callback_args.monitor = 'trainer/epoch'
+        #     callback_args.mode = 'max'
+        #     callback_args.save_last = False
+        #     callback_args.filename = 'last'
+        #     # callback_args.save_on_train_epoch_end = True # this is False for the other checkpoint callback
+        #     ckpt_callback = utils.instantiate(registry.callbacks, callback_args)
+        #     # ckpt_callback.CHECKPOINT_NAME_LAST = 'last_' # now we have two last checkpoints, last.ckpt and last_.ckpt
+        #     callbacks.append(ckpt_callback)
+
     trainer = pl.Trainer(
         logger=logger,
         callbacks=callbacks,
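
The new preemption branch is currently a stub (pass), with the intended ModelCheckpoint configuration preserved as comments. For reference only, the plain pytorch-lightning equivalent of what those comments describe would look roughly like this (hypothetical sketch; the real code would go through the repo's registry via utils.instantiate):

from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch of the commented-out fault-tolerance checkpoint: keep the two
# highest-epoch checkpoints under the name "last" so an interrupted write
# still leaves one usable checkpoint behind.
ckpt_callback = ModelCheckpoint(
    save_top_k=2,
    monitor="trainer/epoch",  # assumes this metric is logged each step
    mode="max",
    save_last=False,
    filename="last",
)
# callbacks.append(ckpt_callback)
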
@@ -681,6 +698,31 @@ def train(config):
     trainer = create_trainer(config)
     model = SequenceLightningModule(config)
 
+    # Load pretrained_model if specified
+    if config.train.get("pretrained_model_path", None) is not None:
+        # PTL style. Note, method returns a new model object, and need to pass config.
+        model = SequenceLightningModule.load_from_checkpoint(
+            config.train.pretrained_model_path,
+            config=config,
+            strict=config.train.pretrained_model_strict_load,
+        )
+        print("Loaded pretrained model from", config.train.pretrained_model_path)
+
+    # Added by KS for pre-training
+    # [22-07-21 AG] refactored, untested
+    if config.train.get("ignore_pretrained_layers", False):
+        pretrained_dict = pretrained_model.state_dict()
+        model_dict = model.state_dict()
+        for k, v in model_dict.items():
+            for ignore_layer in config.train.ignore_pretrained_layers:
+                if ignore_layer in k:
+                    pretrained_dict[k] = v
+        model.load_state_dict(pretrained_dict)
+    if config.train.get("pretrained_freeze_encoder", False):
+        for name, param in model.named_parameters():
+            if not("decoder" in name): param.requires_grad = False
+
+
     # Run initial validation epoch (useful for debugging, finetuning)
     if config.train.validate_at_start:
         print("Running validation before training")
@@ -693,6 +735,82 @@ def train(config):
     if config.train.test:
         trainer.test(model)
 
+
+
+def preemption_setup(config):
+    if config.tolerance.id is None:
+        return config
+
+    # Create path ./logdir/id/ to store information for resumption
+    resume_dir = os.path.join(get_original_cwd(), config.tolerance.logdir, str(config.tolerance.id))
+
+    if os.path.exists(resume_dir):
+        print(f"Resuming from {resume_dir}")
+
+        # Load path to the last checkpoint
+        with open(os.path.join(resume_dir, "hydra.txt"), "r") as f:
+            hydra_paths = list(f.readlines())
+
+        # Look at the previous runs in reverse order
+        checkpoint_path = None
+        for hydra_path in reversed(hydra_paths):
+            hydra_path = hydra_path.rstrip('\n')
+
+            # Get the paths to the last.ckpt and last_.ckpt files
+            last_path = os.path.join(hydra_path, "checkpoints", "last.ckpt")
+            # last__path = os.path.join(hydra_path, "checkpoints", "last_.ckpt")
+            # last_exists, last__exists = os.path.exists(last_path), os.path.exists(last__path)
+
+            # if not last_exists or not last__exists:
+            #     # This run doesn't have both checkpoints, so skip it
+            #     print(f"\tSkipping {hydra_path}, not suitable for resuming (last_exists = {last_exists}, last__exists = {last__exists})")
+            #     continue
+
+            # # Read timestamp when checkpoints were modified
+            # # We want to load the _earlier_ checkpoint, since that is guaranteed to be uncorrupted
+            # last_timestamp = os.path.getmtime(last_path)
+            # last__timestamp = os.path.getmtime(last__path)
+            # print("\t\tlast_timestamp =", last_timestamp)
+            # print("\t\tlast__timestamp =", last__timestamp)
+
+            # if last_timestamp < last__timestamp:
+            #     checkpoint_path = last_path
+            # else:
+            #     checkpoint_path = last__path
+            # checkpoint_path = last_path
+            # config.train.ckpt = checkpoint_path
+
+            if os.path.exists(last_path):
+                print("\tFound checkpoint at", last_path)
+                config.train.ckpt = last_path
+                # HACK TODO
+                config.train.pretrained_model_path = None
+                config.train.pretrained_model_state_hook._name_ = None
+                # config.train.pretrained_model_reinit_hook._name_ = None
+                break
+
+        # If we didn't find a checkpoint
+        if checkpoint_path is None:
+            print("\tNo suitable checkpoint found, starting from scratch")
+
+        # Set wandb run id to resume
+        if os.path.exists(os.path.join(hydra_path, 'wandb')):
+            run_info = [e for e in os.listdir(os.path.join(hydra_path, 'wandb')) if e.startswith('run-')][0]
+            run_id = run_info.split('-')[-1]
+            try:
+                config.wandb.id = run_id
+            except AttributeError:
+                pass
+
+    os.makedirs(resume_dir, exist_ok=True)
+
+    # Store path to Hydra output folder
+    with open(os.path.join(resume_dir, 'hydra.txt'), 'a') as f:
+        f.write(os.getcwd() + '\n')
+
+    return config
+
+
 @hydra.main(config_path="configs", config_name="config.yaml")
 def main(config: OmegaConf):
 
@@ -705,6 +823,8 @@ def main(config: OmegaConf):
     # Pretty print config using Rich library
     utils.train.print_config(config, resolve=True)
 
+    config = preemption_setup(config)
+
     train(config)
 
 
