Commit fd3b5c8

Make both torch.amp and apex.amp available as backend for mixed precision training (#91)
* UPDATE: add `torch.amp` as a new backend for mixed precision training

  Users can now choose the backend for mixed precision training by passing the keyword argument `amp_backend` to `LRFinder`.

* MAINT: remove the code for delaying gradient unscaling when gradient accumulation is enabled

  Since more advanced gradient-accumulation tricks can be implemented by overriding `LRFinder._train_batch()`, there is no need to handle this ourselves. Removing it also avoids surprises when an overflow occurs while training in lower precision. See also this section of the `apex` documentation: https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations

* MAINT: replace f-strings with `str.format()` for backward compatibility
1 parent 98f7a89 commit fd3b5c8
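
For context, the gradient-accumulation pattern referred to in the commit message (the case the removed unscaling-delay code used to handle) looks roughly like the sketch below in plain PyTorch. This is an illustrative sketch only: the function and argument names are invented here, and it is not the `LRFinder._train_batch()` implementation, which is where such logic would be overridden.

```python
def accumulated_step(model, optimizer, criterion, batches, accumulation_steps):
    """Illustrative gradient accumulation (not LRFinder code): gradients from
    several small batches are summed before a single optimizer step."""
    model.train()
    optimizer.zero_grad()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(batches):
        outputs = model(inputs)
        # Divide so the accumulated gradient matches one large-batch step.
        loss = criterion(outputs, labels) / accumulation_steps
        # With apex.amp one would instead wrap this in
        # `with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward()`.
        loss.backward()
        running_loss += loss.item()
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    return running_loss
```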

File tree

4 files changed (+369, -42 lines)

README.md

Lines changed: 34 additions & 15 deletions
@@ -105,21 +105,40 @@ lr_finder.reset()
 
 ### Mixed precision training
 
-Currently, we use [`apex`](https://github.com/NVIDIA/apex) as the dependency for mixed precision training.
-To enable mixed precision training, you just need to call `amp.initialize()` before running `LRFinder`. e.g.
-
-```python
-from torch_lr_finder import LRFinder
-from apex import amp
-
-# Add this line before running `LRFinder`
-model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
-
-lr_finder = LRFinder(model, optimizer, criterion, device='cuda')
-lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp')
-lr_finder.plot()
-lr_finder.reset()
-```
+Both `apex.amp` and `torch.amp` are supported now. Here are the examples:
+
+- Using [`apex.amp`](https://github.com/NVIDIA/apex):
+```python
+from torch_lr_finder import LRFinder
+from apex import amp
+
+# Add this line before running `LRFinder`
+model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
+
+lr_finder = LRFinder(model, optimizer, criterion, device='cuda', amp_backend='apex')
+lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp')
+lr_finder.plot()
+lr_finder.reset()
+```
+
+- Using [`torch.amp`](https://pytorch.org/docs/stable/notes/amp_examples.html):
+```python
+import torch
+from torch_lr_finder import LRFinder
+
+amp_config = {
+    'device_type': 'cuda',
+    'dtype': torch.float16,
+}
+grad_scaler = torch.cuda.amp.GradScaler()
+
+lr_finder = LRFinder(
+    model, optimizer, criterion, device='cuda',
+    amp_backend='torch', amp_config=amp_config, grad_scaler=grad_scaler
+)
+lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp')
+lr_finder.plot()
+lr_finder.reset()
+```
 
 Note that the benefit of mixed precision training requires an NVIDIA GPU with tensor cores (see also: [NVIDIA/apex #297](https://github.com/NVIDIA/apex/issues/297))
 

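For reference, the `amp_config` and `grad_scaler` arguments shown in the README example above follow PyTorch's standard `autocast` + `GradScaler` recipe. The sketch below shows that general pattern in plain PyTorch; it is an assumption-level illustration (the function name is invented here), not LRFinder's internal training step.

```python
import torch

def amp_train_step(model, optimizer, criterion, inputs, labels, grad_scaler):
    """Standard torch.amp training step (illustrative only)."""
    optimizer.zero_grad()
    # Run the forward pass in reduced precision where it is safe to do so.
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    grad_scaler.scale(loss).backward()  # scale the loss to avoid gradient underflow
    grad_scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    grad_scaler.update()                # adjust the scale factor for the next iteration
    return loss.item()
```

A scaler created with `torch.cuda.amp.GradScaler()`, as in the README snippet, can be passed in as `grad_scaler`.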
examples/mnist_with_amp.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
"""
Train a simple neural net for MNIST dataset with mixed precision training.

Examples
--------
- Run with `torch.amp`:
```bash
$ python mnist_with_amp.py --batch_size=32 --seed=42 --tqdm --amp_backend=torch
```
- Run without mixed precision training:
```bash
$ python mnist_with_amp.py --batch_size=32 --seed=42 --tqdm --amp_backend=""
```
"""

from argparse import ArgumentParser
import random
import sys
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms

from torch_lr_finder import LRFinder
from apex import amp


SEED = 0

def reset_seed(seed):
    """
    ref: https://forums.fast.ai/t/accumulating-gradients/33219/28
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def simple_timer(func):
    def wrapper(*args, **kwargs):
        st = time.time()
        func(*args, **kwargs)
        print('--- Time taken from {}: {} seconds'.format(
            func.__qualname__, time.time() - st
        ))
    return wrapper


# redirect output from tqdm
def conceal_stdout(enabled):
    if enabled:
        f = open(os.devnull, 'w')
        sys.stdout = f
        sys.stderr = f
    else:
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
        self.conv2_drop = nn.Dropout2d()
        self.net = nn.Sequential(
            self.conv1,          # (24, 24, 16)
            nn.MaxPool2d(2),     # (12, 12, 16)
            nn.ReLU(True),
            self.conv2,          # (10, 10, 32)
            self.conv2_drop,
            nn.MaxPool2d(2),     # (5, 5, 32)
            nn.ReLU(True),
        )
        self.fc1 = nn.Linear(5*5*32, 64)
        self.fc2 = nn.Linear(64, 16)

    def forward(self, x):
        x = self.net(x)
        x = x.view(-1, 5*5*32)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


@simple_timer
def warm_up(trainset):
    trainloader = DataLoader(trainset, batch_size=256, shuffle=True)

    device = torch.device('cuda')
    model = ConvNet()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5)
    criterion = nn.NLLLoss()

    conceal_stdout(True)
    lr_finder = LRFinder(model, optimizer, criterion, device='cuda')
    lr_finder.range_test(trainloader, end_lr=10, num_iter=10, step_mode='exp')
    conceal_stdout(False)


@simple_timer
def run_normal(trainset, batch_size, no_tqdm=True):
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda')
    model = ConvNet()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5)
    criterion = nn.NLLLoss()

    conceal_stdout(no_tqdm)
    lr_finder = LRFinder(model, optimizer, criterion, device='cuda')
    lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp')
    lr_finder.plot()
    conceal_stdout(no_tqdm and False)


@simple_timer
def run_amp_apex(trainset, batch_size, no_tqdm=True, opt_level='O1'):
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda')
    model = ConvNet()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5)
    criterion = nn.NLLLoss()

    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

    conceal_stdout(no_tqdm)
    lr_finder = LRFinder(model, optimizer, criterion, device='cuda', amp_backend='apex')
    lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp')
    lr_finder.plot()
    conceal_stdout(no_tqdm and False)

@simple_timer
def run_amp_torch(trainset, batch_size, no_tqdm=True):
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda')
    model = ConvNet()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5)
    criterion = nn.NLLLoss()

    amp_config = {
        'device_type': 'cuda',
        'dtype': torch.float16,
    }
    grad_scaler = torch.cuda.amp.GradScaler()

    conceal_stdout(no_tqdm)
    lr_finder = LRFinder(
        model, optimizer, criterion,
        amp_backend='torch', amp_config=amp_config, grad_scaler=grad_scaler
    )
    lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp')
    lr_finder.plot()
    conceal_stdout(no_tqdm and False)

def parse_args():
    parser = ArgumentParser(add_help=True)
    parser.add_argument('--amp_backend', type=str, default='',
                        help='Backend for auto-mixed precision training, available: '
                        '[torch, apex]. If not specified, amp is disabled.')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--data_folder', type=str, default='./data',
                        help='Location of MNIST dataset.')
    parser.add_argument('--cudnn_benchmark', action='store_true',
                        help='Add this flag to make cudnn auto-tuner able to find '
                        'the best algorithm on your machine. This may improve the '
                        'performance when you are running script of mixed precision '
                        'training.')
    parser.add_argument('--tqdm', action='store_true',
                        help='Add this flag to show the output from tqdm.')
    parser.add_argument('--warm_up', action='store_true',
                        help='Add this flag to run a warm-up snippet.')
    parser.add_argument('--opt_level', type=str, default='O1',
                        help='Optimization level for amp. (works only for `apex`)')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()

    # turn this mode on may improve the performance on some GPUs
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    trainset = datasets.MNIST(args.data_folder, train=True, download=True, transform=transform)

    reset_seed(args.seed)
    if args.warm_up:
        warm_up(trainset)

    if args.amp_backend == '':
        run_normal(trainset, args.batch_size, no_tqdm=not args.tqdm)
    elif args.amp_backend == 'apex':
        run_amp_apex(trainset, args.batch_size, no_tqdm=not args.tqdm, opt_level=args.opt_level)
    elif args.amp_backend == 'torch':
        run_amp_torch(trainset, args.batch_size, no_tqdm=not args.tqdm)
    else:
        print('Unknown amp backend: {}'.format(args.amp_backend))
