Commit 182e121

Update README.md
1 parent d4d27c3 commit 182e121

File tree

1 file changed (+17, -17 lines)


README.md

Lines changed: 17 additions & 17 deletions
@@ -35,7 +35,7 @@ Keras callback object for SWA.
 
 **swa_freq** - Frequency of weight averaging. Used with cyclic schedules.
 
-**batch_size** - Batch size. Only needed in the Keras API when using both batch normalization and a data generator.
+**batch_size** - Batch size the model is being trained with (only needed when using batch normalization).
 
 **verbose** - Verbosity mode, 0 or 1.
 
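The `batch_size` argument changed above matters because batch normalization stores running activation statistics during training; once SWA replaces the weights with their average, those stored statistics no longer match the averaged model and must be re-estimated by forward passes over the data in batches. A minimal, hypothetical numpy sketch of that batchwise re-estimation (none of these names come from the `swa` package):

```python
import numpy as np

def recompute_bn_stats(activations, batch_size):
    """Re-estimate batch-norm moments from batched passes over the data."""
    means, vars_, counts = [], [], []
    for i in range(0, len(activations), batch_size):
        batch = activations[i:i + batch_size]
        means.append(batch.mean(axis=0))
        vars_.append(batch.var(axis=0))
        counts.append(len(batch))
    w = np.asarray(counts, dtype=float)
    w /= w.sum()
    means, vars_ = np.asarray(means), np.asarray(vars_)
    mean = np.sum(w[:, None] * means, axis=0)
    # law of total variance: within-batch variance plus spread of batch means
    var = np.sum(w[:, None] * (vars_ + (means - mean) ** 2), axis=0)
    return mean, var

rng = np.random.default_rng(0)
acts = rng.normal(loc=2.0, scale=3.0, size=(1000, 4))
mean, var = recompute_bn_stats(acts, batch_size=32)
assert np.allclose(mean, acts.mean(axis=0))
assert np.allclose(var, acts.var(axis=0))
```

The count-weighted combination makes the batched estimate exact even when the last batch is smaller than `batch_size`.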
@@ -52,15 +52,15 @@ The default schedule is `'manual'`, allowing the learning rate to be controlled
 
 #### Example
 
-For Keras (with constant LR)
+For TensorFlow Keras (with constant LR)
 ```python
-from sklearn.datasets.samples_generator import make_blobs
-from keras.utils import to_categorical
-from keras.models import Sequential
-from keras.layers import Dense
-from keras.optimizers import SGD
+from sklearn.datasets import make_blobs
+from tensorflow.keras.utils import to_categorical
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Dense
+from tensorflow.keras.optimizers import SGD
 
-from swa.keras import SWA
+from swa.tfkeras import SWA
 
 # make dataset
 X, y = make_blobs(n_samples=1000,
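The averaging these examples rely on can be sketched in a few lines: from `start_epoch` on, the callback keeps a running mean of the weights it samples at averaging points. An illustrative sketch assuming a simple incremental-mean update (the names are not the package's internals):

```python
def swa_update(w_swa, w, n_models):
    """Fold the current weights w into the running average of n_models."""
    return (w_swa * n_models + w) / (n_models + 1)

# weights sampled at successive averaging points (scalars for illustration)
w_swa, n = 0.0, 0
for w in [1.0, 2.0, 3.0, 4.0]:
    w_swa = swa_update(w_swa, w, n)
    n += 1
print(w_swa)  # 2.5, the mean of the sampled weights
```

In the real callback the same update is applied elementwise to every weight tensor of the model.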
@@ -92,16 +92,15 @@ swa = SWA(start_epoch=start_epoch,
 model.fit(X, y, epochs=epochs, verbose=1, callbacks=[swa])
 ```
 
-Or for Keras in Tensorflow (with Cyclic LR)
-
+Or for Keras (with Cyclic LR)
 ```python
-from sklearn.datasets.samples_generator import make_blobs
-from tensorflow.keras.utils import to_categorical
-from tensorflow.keras.models import Sequential
-from tensorflow.keras.layers import Dense, BatchNormalization
-from tensorflow.keras.optimizers import SGD
+from sklearn.datasets import make_blobs
+from keras.utils import to_categorical
+from keras.models import Sequential
+from keras.layers import Dense, BatchNormalization
+from keras.optimizers import SGD
 
-from swa.tfkeras import SWA
+from swa.keras import SWA
 
 # make dataset
 X, y = make_blobs(n_samples=1000,
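For the cyclic schedule in this example, `swa_lr` and `swa_lr2` bound the learning rate and `swa_freq` sets the cycle length in epochs. A hedged sketch of one common form, in the style of the SWA paper (Izmailov et al., 2018) rather than necessarily the package's exact formula:

```python
def cyclic_lr(epoch, lr_low, lr_high, cycle_length):
    """Linearly anneal from lr_high down to lr_low within each cycle."""
    t = (epoch % cycle_length) / max(cycle_length - 1, 1)
    return (1 - t) * lr_high + t * lr_low

# e.g. swa_lr=0.001, swa_lr2=0.003, swa_freq=3
lrs = [round(cyclic_lr(e, 0.001, 0.003, 3), 4) for e in range(6)]
print(lrs)  # [0.003, 0.002, 0.001, 0.003, 0.002, 0.001]
```

Weights are then sampled for averaging at the low point of each cycle, which is why `swa_freq` doubles as the averaging frequency.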
@@ -130,10 +129,11 @@ swa = SWA(start_epoch=start_epoch,
           swa_lr=0.001,
           swa_lr2=0.003,
           swa_freq=3,
+          batch_size=32, # needed when using batch norm
           verbose=1)
 
 # train
-model.fit(X, y, epochs=epochs, verbose=1, callbacks=[swa])
+model.fit(X, y, batch_size=32, epochs=epochs, verbose=1, callbacks=[swa])
 ```
 
 Output
