Add support and example for pytorch-lightning #5

Open · wants to merge 4 commits into main
29 changes: 29 additions & 0 deletions configs/mnist_pl.yaml
@@ -0,0 +1,29 @@
module: basic
output_dir: $SCRATCH/pytorch-examples/output
name: mnist_pl

data:
name: mnist
data_path: $SCRATCH/pytorch-examples/mnist/data
num_workers: 8
batch_size: 64

model:
name: cnn
input_shape: [1, 28, 28]
conv_sizes: [8, 16]
fc_sizes: [32]
output_size: 10

loss:
name: CrossEntropyLoss

optimizer:
name: Adam
lr: 0.001

metrics:
acc: Accuracy

trainer:
max_epochs: 4
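The model block above is passed verbatim to models.get_model(**config['model']) in modules/basic.py. The models package itself is not part of this diff; a minimal sketch of what that factory might look like, mirroring the modules factory below (the build_model entry point is an assumed name):

# Hypothetical sketch of the models package factory this config feeds.
# Not part of this PR; build_model is an assumed name.
import importlib

def get_model(name, **model_args):
    """Import models.<name> and delegate construction to its build_model()."""
    module = importlib.import_module('.' + name, 'models')
    return module.build_model(**model_args)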
13 changes: 13 additions & 0 deletions modules/__init__.py
@@ -0,0 +1,13 @@
"""
Python module for holding our PyTorch Lightning modules.

Modules here inherit from pl.LightningModule and implement the logic for
constructing the model as well as the training and validation steps.
"""

import importlib

def get_module(name, config):
"""Factory function for constructing a Trainer"""
module = importlib.import_module('.' + name, 'modules')
return module.get_module(config)
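For reference, a small usage sketch of this factory with the example config (illustrative only, not part of the diff):

import yaml

from modules import get_module

with open('configs/mnist_pl.yaml') as f:
    config = yaml.safe_load(f)

# Imports modules.basic and calls its get_module(config), yielding a BasicModule
module = get_module(config['module'], config)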
50 changes: 50 additions & 0 deletions modules/basic.py
@@ -0,0 +1,50 @@
# System
import logging

# Externals
import torch
import pytorch_lightning as pl

# Locals
from models import get_model


class BasicModule(pl.LightningModule):
"""PL Module for basic single-model examples"""

def __init__(self, config):
super().__init__()
self.config = config

# Construct the model
self.model = get_model(**config['model'])

# Construct the loss function
loss_config = config['loss']
Loss = getattr(torch.nn, loss_config.pop('name'))
self.loss_func = Loss(**loss_config)

def configure_optimizers(self):
logging.info('configure_optimizers')
optimizer_config = self.config['optimizer']
Optim = getattr(torch.optim, optimizer_config.pop('name'))
return Optim(self.model.parameters(), **optimizer_config)

def forward(self, x):
return self.model(x)

def training_step(self, batch, batch_idx):
batch_input, batch_target = batch
batch_output = self.model(batch_input)
batch_loss = self.loss_func(batch_output, batch_target)
self.log('train_loss', batch_loss)
return batch_loss

def validation_step(self, batch, batch_idx):
batch_input, batch_target = batch
batch_output = self.model(batch_input)
batch_loss = self.loss_func(batch_output, batch_target)
self.log("valid_loss", batch_loss)

def get_module(config):
return BasicModule(config)
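Note that BasicModule never reads config['metrics'], so the acc: Accuracy entry in the example config is currently unused. A sketch of one way it could be wired up, assuming torchmetrics is installed (everything beyond the diff here is an assumption, and recent torchmetrics versions also require a task argument to Accuracy):

import torch
import torchmetrics

from modules.basic import BasicModule


class MetricsModule(BasicModule):
    """BasicModule extended to build and log the metrics named in the config."""

    def __init__(self, config):
        super().__init__(config)
        # A ModuleDict keeps metric state on the same device as the module
        self.metrics = torch.nn.ModuleDict({
            name: getattr(torchmetrics, cls_name)()
            for name, cls_name in config.get('metrics', {}).items()
        })

    def validation_step(self, batch, batch_idx):
        batch_input, batch_target = batch
        batch_output = self.model(batch_input)
        self.log('valid_loss', self.loss_func(batch_output, batch_target))
        # Update each metric and let Lightning handle epoch-level aggregation
        for name, metric in self.metrics.items():
            metric(batch_output, batch_target)
            self.log(f'valid_{name}', metric)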
17 changes: 17 additions & 0 deletions scripts/train_pl_cgpu.sh
@@ -0,0 +1,17 @@
#!/bin/bash
#SBATCH -C gpu
#SBATCH --cpus-per-task=10
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-task=1
#SBATCH --gpu-bind=none
#SBATCH --time 30
#SBATCH -J train-cgpu
#SBATCH -o logs/%x-%j.out

# Setup software
module load pytorch/1.10.2-gpu
# Workaround for cudnn module lib path bug
export LD_LIBRARY_PATH=/usr/common/software/sles15_cgpu/cudnn/8.3.2/lib:$LD_LIBRARY_PATH

# Run the training
srun -l -u python train_pl.py "$@"
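The script sets no --nodes directive, so the node count comes from the sbatch command line, e.g. sbatch -N 2 scripts/train_pl_cgpu.sh configs/mnist_pl.yaml launches 8 ranks per node (16 in total); train_pl.py then reads SLURM_JOB_NUM_NODES to size the Trainer accordingly.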
73 changes: 73 additions & 0 deletions train_pl.py
@@ -0,0 +1,73 @@
"""
Main training script for NERSC PyTorch lightning examples
"""

# System
import os
import argparse
import logging

# Externals
import yaml
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks import DeviceStatsMonitor

# Locals
from datasets import get_data_loaders
from modules import get_module
from utils.logging import config_logging

def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
add_arg = parser.add_argument
add_arg('config', nargs='?', default='configs/mnist_pl.yaml',
help='YAML configuration file')
return parser.parse_args()

def load_config(args):
with open(args.config) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
config['output_dir'] = os.path.expandvars(config['output_dir'])
return config

def main():
"""Main training script function"""

# Initialization
args = parse_args()

# Load configuration
config = load_config(args)

# Setup logging
config_logging(verbose=False)

# Load the datasets
train_data_loader, valid_data_loader = get_data_loaders(**config['data'])

# Load the PL module
module = get_module(config['module'], config)

# Prepare callbacks
callbacks = [
DeviceStatsMonitor(),
]

# Create the trainer
pl_logger = pl.loggers.CSVLogger(config['output_dir'], name=config['name'])
    # SLURM env vars are strings; Trainer expects an integer node count
    num_nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
trainer = pl.Trainer(gpus=-1, num_nodes=num_nodes,
strategy=DDPStrategy(find_unused_parameters=False),
logger=pl_logger,
enable_progress_bar=False,
callbacks=callbacks,
**config['trainer'])
trainer.fit(module, train_data_loader, valid_data_loader)

logging.info('All done!')

if __name__ == '__main__':
main()
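One compatibility note, as an assumption about Lightning versions newer than the pytorch/1.10.2 module loaded above: the gpus argument is deprecated from PyTorch Lightning 1.7 and removed in 2.0, so on recent versions the equivalent Trainer construction would be:

# Equivalent Trainer construction on PyTorch Lightning >= 1.7 (sketch):
trainer = pl.Trainer(accelerator='gpu', devices=-1, num_nodes=num_nodes,
                     strategy=DDPStrategy(find_unused_parameters=False),
                     logger=pl_logger,
                     enable_progress_bar=False,
                     callbacks=callbacks,
                     **config['trainer'])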