@@ -35,7 +35,7 @@ Keras callback object for SWA.
35
35
36
36
** swa_freq** - Frequency of weight averaging. Used with cyclic schedules.
37
37
38
- ** batch_size** - Batch size. Only needed in the Keras API when using both batch normalization and a data generator .
38
+ ** batch_size** - Batch size the model is being trained with (only needed when using batch normalization).
39
39
40
40
** verbose** - Verbosity mode, 0 or 1.
41
41
@@ -52,15 +52,15 @@ The default schedule is `'manual'`, allowing the learning rate to be controlled
52
52
53
53
#### Example
54
54
55
- For Keras (with constant LR)
55
+ For TensorFlow Keras (with constant LR)
56
56
``` python
57
- from sklearn.datasets.samples_generator import make_blobs
58
- from keras.utils import to_categorical
59
- from keras.models import Sequential
60
- from keras.layers import Dense
61
- from keras.optimizers import SGD
57
+ from sklearn.datasets import make_blobs
58
+ from tensorflow.keras.utils import to_categorical
59
+ from tensorflow.keras.models import Sequential
60
+ from tensorflow.keras.layers import Dense
61
+ from tensorflow.keras.optimizers import SGD
62
62
63
- from swa.keras import SWA
63
+ from swa.tfkeras import SWA
64
64
65
65
# make dataset
66
66
X, y = make_blobs(n_samples = 1000,
@@ -92,16 +92,15 @@ swa = SWA(start_epoch=start_epoch,
92
92
model.fit(X, y, epochs = epochs, verbose = 1, callbacks = [swa])
93
93
```
94
94
95
- Or for Keras in Tensorflow (with Cyclic LR)
96
-
95
+ Or for Keras (with Cyclic LR)
97
96
``` python
98
- from sklearn.datasets.samples_generator import make_blobs
99
- from tensorflow.keras.utils import to_categorical
100
- from tensorflow.keras.models import Sequential
101
- from tensorflow.keras.layers import Dense, BatchNormalization
102
- from tensorflow.keras.optimizers import SGD
97
+ from sklearn.datasets import make_blobs
98
+ from keras.utils import to_categorical
99
+ from keras.models import Sequential
100
+ from keras.layers import Dense, BatchNormalization
101
+ from keras.optimizers import SGD
103
102
104
- from swa.tfkeras import SWA
103
+ from swa.keras import SWA
105
104
106
105
# make dataset
107
106
X, y = make_blobs(n_samples = 1000,
@@ -130,10 +129,11 @@ swa = SWA(start_epoch=start_epoch,
130
129
swa_lr = 0.001,
131
130
swa_lr2 = 0.003,
132
131
swa_freq = 3,
132
+ batch_size = 32, # needed when using batch norm
133
133
verbose = 1)
134
134
135
135
# train
136
- model.fit(X, y, epochs = epochs, verbose = 1 , callbacks = [swa])
136
+ model.fit(X, y, batch_size = 32, epochs = epochs, verbose = 1, callbacks = [swa])
137
137
```
138
138
139
139
Output
0 commit comments