Skip to content

Commit 426e7cc

Browse files
Tweak example
1 parent 73faba1 commit 426e7cc

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

examples/ml_mlp_regression.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@
1212

1313

1414
# Define basic parameters
15-
batch_size = 64
16-
epochs = 1
15+
batch_size = 16
16+
epochs = 100
1717

1818
# Load data
1919
(x_train, y_train), (x_test, y_test) = boston_housing.load_data()
2020

2121
x_train = x_train.astype("float32")
2222
x_test = x_test.astype("float32")
23-
x_train /= 255
24-
x_test /= 255
2523
print(x_train.shape[0], 'train samples')
2624
print(x_test.shape[0], 'test samples')
2725

@@ -41,7 +39,7 @@
4139
df = to_data_frame(sc, x_train, y_train)
4240
test_df = to_data_frame(sc, x_test, y_test)
4341

44-
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
42+
sgd = optimizers.SGD(lr=0.000001)
4543
sgd_conf = optimizers.serialize(sgd)
4644

4745
# Initialize Spark ML Estimator

0 commit comments

Comments
 (0)