diff --git a/configs/mnist_pl.yaml b/configs/mnist_pl.yaml
new file mode 100644
index 0000000..496f13b
--- /dev/null
+++ b/configs/mnist_pl.yaml
@@ -0,0 +1,29 @@
+module: basic
+output_dir: $SCRATCH/pytorch-examples/output
+name: mnist_pl
+
+data:
+    name: mnist
+    data_path: $SCRATCH/pytorch-examples/mnist/data
+    num_workers: 8
+    batch_size: 64
+
+model:
+    name: cnn
+    input_shape: [1, 28, 28]
+    conv_sizes: [8, 16]
+    fc_sizes: [32]
+    output_size: 10
+
+loss:
+    name: CrossEntropyLoss
+
+optimizer:
+    name: Adam
+    lr: 0.001
+
+metrics:
+    acc: Accuracy
+
+trainer:
+    max_epochs: 4
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000..60ccf19
--- /dev/null
+++ b/modules/__init__.py
@@ -0,0 +1,17 @@
+"""
+Python module for holding our PyTorch Lightning modules.
+
+Modules here subclass LightningModule and implement the logic for
+constructing the model as well as training and evaluation.
+"""
+
+import importlib
+
+def get_module(name, config):
+    """Factory function for constructing a Lightning module.
+
+    Imports the submodule of this package called `name` and returns the
+    result of calling that submodule's own get_module(config).
+    """
+    module = importlib.import_module('.' + name, 'modules')
+    return module.get_module(config)
diff --git a/modules/basic.py b/modules/basic.py
new file mode 100644
index 0000000..c1235c9
--- /dev/null
+++ b/modules/basic.py
@@ -0,0 +1,67 @@
+# System
+import logging
+
+# Externals
+import torch
+import pytorch_lightning as pl
+
+# Locals
+from models import get_model
+
+
+class BasicModule(pl.LightningModule):
+    """PL Module for basic single-model examples.
+
+    The model, loss function and optimizer are all built from the
+    corresponding sections of the configuration dictionary.
+    """
+
+    def __init__(self, config):
+        """Build the model and loss function described by `config`.
+
+        `config` must provide 'model' and 'loss' sections; its 'optimizer'
+        section is read later by configure_optimizers.
+        """
+        super().__init__()
+        self.config = config
+
+        # Construct the model
+        self.model = get_model(**config['model'])
+
+        # Construct the loss function. Work on a copy so that popping the
+        # class name does not strip it from the caller's configuration.
+        loss_config = dict(config['loss'])
+        Loss = getattr(torch.nn, loss_config.pop('name'))
+        self.loss_func = Loss(**loss_config)
+
+    def configure_optimizers(self):
+        """Return the torch.optim optimizer named in the configuration."""
+        logging.info('configure_optimizers')
+        # Copy before popping so self.config keeps its 'name' entry and
+        # this method can be called repeatedly.
+        optimizer_config = dict(self.config['optimizer'])
+        Optim = getattr(torch.optim, optimizer_config.pop('name'))
+        return Optim(self.model.parameters(), **optimizer_config)
+
+    def forward(self, x):
+        """Apply the wrapped model to `x`."""
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        """Compute, log and return the loss for one training batch."""
+        batch_input, batch_target = batch
+        batch_output = self.model(batch_input)
+        batch_loss = self.loss_func(batch_output, batch_target)
+        self.log('train_loss', batch_loss)
+        return batch_loss
+
+    def validation_step(self, batch, batch_idx):
+        """Compute and log the loss for one validation batch."""
+        batch_input, batch_target = batch
+        batch_output = self.model(batch_input)
+        batch_loss = self.loss_func(batch_output, batch_target)
+        self.log("valid_loss", batch_loss)
+
+def get_module(config):
+    """Factory hook used by modules.get_module to build a BasicModule."""
+    return BasicModule(config)
diff --git a/scripts/train_pl_cgpu.sh b/scripts/train_pl_cgpu.sh
new file mode 100755
index 0000000..29f7c47
--- /dev/null
+++ b/scripts/train_pl_cgpu.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+#SBATCH -C gpu
+#SBATCH --cpus-per-task=10
+#SBATCH --ntasks-per-node=8
+#SBATCH --gpus-per-task=1
+#SBATCH --gpu-bind=none
+#SBATCH --time 30
+#SBATCH -J train-cgpu
+#SBATCH -o logs/%x-%j.out
+
+# Setup software
+module load pytorch/1.10.2-gpu
+# Workaround for cudnn module lib path bug
+export LD_LIBRARY_PATH=/usr/common/software/sles15_cgpu/cudnn/8.3.2/lib:$LD_LIBRARY_PATH
+
+# Run the training
+srun -l -u python train_pl.py "$@"
diff --git a/train_pl.py b/train_pl.py
new file mode 100644
index 0000000..c27b0b1
--- /dev/null
+++ b/train_pl.py
@@ -0,0 +1,82 @@
+"""
+Main training script for NERSC PyTorch lightning examples
+"""
+
+# System
+import os
+import argparse
+import logging
+
+# Externals
+import yaml
+import numpy as np
+import pytorch_lightning as pl
+from pytorch_lightning.strategies import DDPStrategy
+from pytorch_lightning.callbacks import DeviceStatsMonitor
+
+# Locals
+from datasets import get_data_loaders
+from modules import get_module
+from utils.logging import config_logging
+
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser()
+    add_arg = parser.add_argument
+    add_arg('config', nargs='?', default='configs/mnist_pl.yaml',
+            help='YAML configuration file')
+    return parser.parse_args()
+
+def load_config(args):
+    """Load the YAML file at `args.config` and return it as a dict.
+
+    Only 'output_dir' has its environment variables expanded here; all
+    other entries are returned exactly as written in the file.
+    """
+    # NOTE(review): FullLoader is kept as-is; yaml.safe_load would be the
+    # safer choice if configs could ever come from an untrusted source.
+    with open(args.config) as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+    config['output_dir'] = os.path.expandvars(config['output_dir'])
+    return config
+
+def main():
+    """Main training script function"""
+
+    # Initialization
+    args = parse_args()
+
+    # Load configuration
+    config = load_config(args)
+
+    # Setup logging
+    config_logging(verbose=False)
+
+    # Load the datasets
+    train_data_loader, valid_data_loader = get_data_loaders(**config['data'])
+
+    # Load the PL module
+    module = get_module(config['module'], config)
+
+    # Prepare callbacks
+    callbacks = [
+        DeviceStatsMonitor(),
+    ]
+
+    # Create the trainer
+    pl_logger = pl.loggers.CSVLogger(config['output_dir'], name=config['name'])
+    # Environment values are strings but the Trainer needs an integer node
+    # count; fall back to a single node when not running under SLURM.
+    num_nodes = int(os.environ.get('SLURM_JOB_NUM_NODES', '1'))
+    trainer = pl.Trainer(gpus=-1, num_nodes=num_nodes,
+                         strategy=DDPStrategy(find_unused_parameters=False),
+                         logger=pl_logger,
+                         enable_progress_bar=False,
+                         callbacks=callbacks,
+                         **config['trainer'])
+    trainer.fit(module, train_data_loader, valid_data_loader)
+
+    logging.info('All done!')
+
+if __name__ == '__main__':
+    main()