Commit 73faba1

revise tests to include num_workers param and add logic around determining predict func
1 parent 986f276

3 files changed: +64 additions, -69 deletions


elephas/spark_model.py

Lines changed: 24 additions & 9 deletions
@@ -198,24 +198,39 @@ def _fit(self, rdd: RDD, **kwargs):
         self.stop_server()
 
     def _predict(self, rdd: RDD):
-        rdd = rdd.zipWithIndex()
-        if self.num_workers:
-            rdd = rdd.repartition(self.num_workers)
         json_model = self.master_network.to_json()
         weights = self.master_network.get_weights()
         weights = rdd.context.broadcast(weights)
-        custom_objects = self.custom_objects
+        custom_objs = self.custom_objects
 
-        def _predict(model, custom_objects, data):
-            model = model_from_json(model, custom_objects)
+        def _predict(model_as_json, custom_objects, data):
+            model = model_from_json(model_as_json, custom_objects)
+            model.set_weights(weights.value)
+            data = np.array([x for x in data])
+            return model.predict(data)
+
+        def _predict_with_indices(model_as_json, custom_objects, data):
+            model = model_from_json(model_as_json, custom_objects)
             model.set_weights(weights.value)
             data, indices = zip(*data)
             data = np.array(data)
             return zip(model.predict(data), indices)
 
-        predictions_and_indices = rdd.mapPartitions(partial(_predict, json_model, custom_objects))
-        predictions_sorted_by_index = predictions_and_indices.sortBy(lambda x: x[1])
-        predictions = predictions_sorted_by_index.map(lambda x: x[0]).collect()
+        if self.num_workers and self.num_workers > 1:
+            # if there are multiple workers, we need to retrieve element indices and preserve them throughout
+            # the inference process, since we'll need to sort by index before returning the result, as repartitioning
+            # does not preserve ordering, but the users will expect prediction results which correspond to the ordering
+            # of samples they supplied.
+            rdd = rdd.zipWithIndex()
+            rdd = rdd.repartition(self.num_workers)
+            predictions_and_indices = rdd.mapPartitions(partial(_predict_with_indices, json_model, custom_objs))
+            predictions_sorted_by_index = predictions_and_indices.sortBy(lambda x: x[1])
+            predictions = predictions_sorted_by_index.map(lambda x: x[0]).collect()
+        else:
+            # if there are no workers specified or only a single worker, we don't need to worry about handling index
+            # values, since there will be no shuffling
+            predictions = rdd.mapPartitions(partial(_predict, json_model, custom_objs)).collect()
+
         return predictions
 
     def _evaluate(self, rdd: RDD, **kwargs):
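
The heart of this change is an ordering guarantee: repartition() shuffles elements across partitions and does not preserve their order, so when num_workers > 1 the code attaches indices with zipWithIndex() before the shuffle and restores the caller's ordering with sortBy() afterwards. A minimal PySpark-only sketch of that technique follows; fake_inference is a hypothetical stand-in for the per-partition model.predict call, and the names here are illustrative, not from the commit:

from pyspark import SparkContext

sc = SparkContext.getOrCreate()

# Stand-in for the input samples, spread over four initial partitions.
data = sc.parallelize(range(100), numSlices=4)

# Attach a stable index to every element before the shuffle.
indexed = data.zipWithIndex()          # yields (element, index) pairs

# repartition() redistributes elements and loses their original order.
shuffled = indexed.repartition(2)

def fake_inference(partition):
    # Hypothetical stand-in for running model.predict on one partition;
    # the index is carried through alongside each result.
    for element, index in partition:
        yield element * 2, index

results = (shuffled.mapPartitions(fake_inference)
                   .sortBy(lambda pair: pair[1])   # sort by index to restore input order
                   .map(lambda pair: pair[0])
                   .collect())

assert results == [x * 2 for x in range(100)]

Keeping a separate _predict for the single-worker path is a deliberate trade-off: with no shuffle there is nothing to reorder, so that branch skips the per-element index bookkeeping and the extra sortBy pass entirely.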

examples/ml_mlp_classification.py

Lines changed: 7 additions & 2 deletions
@@ -15,7 +15,7 @@
 # Define basic parameters
 batch_size = 64
 nb_classes = 10
-epochs = 1
+epochs = 20
 
 # Load data
 (x_train, y_train), (x_test, y_test) = mnist.load_data()
@@ -26,6 +26,11 @@
 x_test = x_test.astype("float32")
 x_train /= 255
 x_test /= 255
+
+x_train = x_train[:5000]
+x_test = x_test[:1000]
+y_train = y_train[:5000]
+y_test = y_test[:1000]
 print(x_train.shape[0], 'train samples')
 print(x_test.shape[0], 'test samples')
 
@@ -74,7 +79,7 @@
 # Evaluate Spark model by evaluating the underlying model
 prediction = fitted_pipeline.transform(test_df)
 pnl = prediction.select("label", "prediction")
-pnl.show(100)
+pnl.show(100, truncate=False)
 
 prediction_and_label = pnl.rdd.map(lambda row: (row.label, row.prediction))
 metrics = MulticlassMetrics(prediction_and_label)

tests/integration/test_end_to_end.py

Lines changed: 33 additions & 58 deletions
@@ -9,14 +9,20 @@
 import pytest
 import numpy as np
 
-
-# enumerate possible combinations for training mode and parameter server for a classification model
-@pytest.mark.parametrize('mode,parameter_server_mode', [('synchronous', None),
-                                                        ('asynchronous', 'http'),
-                                                        ('asynchronous', 'socket'),
-                                                        ('hogwild', 'http'),
-                                                        ('hogwild', 'socket')])
-def test_training_classification(spark_context, mode, parameter_server_mode, mnist_data, classification_model):
+# enumerate possible combinations for training mode and parameter server for a classification model while also validating
+# multiple workers for repartitioning
+@pytest.mark.parametrize('mode,parameter_server_mode,num_workers',
+                         [('synchronous', None, None),
+                          ('synchronous', None, 2),
+                          ('asynchronous', 'http', None),
+                          ('asynchronous', 'http', 2),
+                          ('asynchronous', 'socket', None),
+                          ('asynchronous', 'socket', 2),
+                          ('hogwild', 'http', None),
+                          ('hogwild', 'http', 2),
+                          ('hogwild', 'socket', None),
+                          ('hogwild', 'socket', 2)])
+def test_training_classification(spark_context, mode, parameter_server_mode, num_workers, mnist_data, classification_model):
     # Define basic parameters
     batch_size = 64
     epochs = 10
@@ -33,7 +39,7 @@ def test_training_classification(spark_context, mode, parameter_server_mode, mni
     rdd = to_simple_rdd(spark_context, x_train, y_train)
 
     # Initialize SparkModel from keras model and Spark context
-    spark_model = SparkModel(classification_model, frequency='epoch',
+    spark_model = SparkModel(classification_model, frequency='epoch', num_workers=num_workers,
                              mode=mode, parameter_server_mode=parameter_server_mode, port=4000 + random.randint(0, 800))
 
     # Train Spark model
@@ -57,13 +63,21 @@
     assert isclose(evals[1], spark_model.master_network.evaluate(x_test, y_test)[1], abs_tol=0.01)
 
 
-# enumerate possible combinations for training mode and parameter server for a regression model
-@pytest.mark.parametrize('mode,parameter_server_mode', [('synchronous', None),
-                                                        ('asynchronous', 'http'),
-                                                        ('asynchronous', 'socket'),
-                                                        ('hogwild', 'http'),
-                                                        ('hogwild', 'socket')])
-def test_training_regression(spark_context, mode, parameter_server_mode, boston_housing_dataset, regression_model):
+# enumerate possible combinations for training mode and parameter server for a regression model while also validating
+# multiple workers for repartitioning
+@pytest.mark.parametrize('mode,parameter_server_mode,num_workers',
+                         [('synchronous', None, None),
+                          ('synchronous', None, 2),
+                          ('asynchronous', 'http', None),
+                          ('asynchronous', 'http', 2),
+                          ('asynchronous', 'socket', None),
+                          ('asynchronous', 'socket', 2),
+                          ('hogwild', 'http', None),
+                          ('hogwild', 'http', 2),
+                          ('hogwild', 'socket', None),
+                          ('hogwild', 'socket', 2)])
+def test_training_regression(spark_context, mode, parameter_server_mode, num_workers, boston_housing_dataset,
+                             regression_model):
     x_train, y_train, x_test, y_test = boston_housing_dataset
     rdd = to_simple_rdd(spark_context, x_train, y_train)
 
@@ -72,7 +86,7 @@ def test_training_regression(spark_context, mode, parameter_server_mode, boston_
     epochs = 10
     sgd = SGD(lr=0.0000001)
     regression_model.compile(sgd, 'mse', ['mae'])
-    spark_model = SparkModel(regression_model, frequency='epoch', mode=mode,
+    spark_model = SparkModel(regression_model, frequency='epoch', mode=mode, num_workers=num_workers,
                              parameter_server_mode=parameter_server_mode, port=4000 + random.randint(0, 800))
 
     # Train Spark model
@@ -92,44 +106,5 @@ def test_training_regression(spark_context, mode, parameter_server_mode, boston_
     assert all(np.isclose(x, y, 0.01) for x, y in zip(predictions, spark_model.master_network.predict(x_test)))
 
     # assert we get the same evaluation results when calling evaluate on keras model directly
-    assert isclose(evals[0], spark_model.master_network.evaluate(x_test, y_test)[0], abs_tol=0.01)
-    assert isclose(evals[1], spark_model.master_network.evaluate(x_test, y_test)[1], abs_tol=0.01)
-
-
-def test_bug203_using_multiple_workers(spark_context, boston_housing_dataset, regression_model):
-    x_train, y_train, x_test, y_test = boston_housing_dataset
-    rdd = to_simple_rdd(spark_context, x_train, y_train)
-
-    # Define basic parameters
-    batch_size = 32
-    epochs = 10
-    sgd = SGD(lr=0.0000001)
-    import tensorflow as tf
-    regression_model.compile(sgd, 'mse', ['mae'])
-
-    spark_model_multiple_workers = SparkModel(regression_model,
-                                              frequency="epoch",
-                                              port=4000 + random.randint(0, 800),
-                                              mode="synchronous",
-                                              num_workers=2)
-
-    # Train Spark model
-    spark_model_multiple_workers.fit(rdd, epochs=epochs, batch_size=batch_size, verbose=0, validation_split=0.1)
-
-    # run inference on trained spark model
-    predictions = spark_model_multiple_workers.predict(x_test)
-    # run evaluation on trained spark model
-    evals = spark_model_multiple_workers.evaluate(x_test, y_test)
-
-    # assert we can supply rdd and get same prediction results when supplying numpy array
-    test_rdd = spark_context.parallelize(x_test)
-    assert all(np.isclose(x, y, 0.01) for x, y in zip(predictions, spark_model_multiple_workers.predict(test_rdd)))
-
-    # assert we get the same prediction result with calling predict on keras model directly
-    assert all(np.isclose(x, y, 0.01) for x, y in zip(predictions, spark_model_multiple_workers.master_network.predict(x_test))), (predictions, spark_model_multiple_workers.master_network.predict(x_test))
-
-    # assert we get the same evaluation results when calling evaluate on keras model directly
-    assert isclose(evals[0], spark_model_multiple_workers.master_network.evaluate(x_test, y_test)[0], abs_tol=1.0)
-    assert isclose(evals[1], spark_model_multiple_workers.master_network.evaluate(x_test, y_test)[1], abs_tol=1.0)
-
-
+    assert isclose(evals[0], spark_model.master_network.evaluate(x_test, y_test)[0], abs_tol=1.0)
+    assert isclose(evals[1], spark_model.master_network.evaluate(x_test, y_test)[1], abs_tol=1.0)
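
For context, here is a hedged usage sketch of the behavior these parametrized tests pin down: with num_workers=2, predict() must still return results in the same order as the input rows despite the repartitioning. The SparkModel arguments and the fit/predict calls mirror the tests above; the model, data shapes, and the to_simple_rdd import path are illustrative assumptions, not part of the commit:

import random
import numpy as np
from pyspark import SparkContext
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd  # assumed import path

spark_context = SparkContext.getOrCreate()

# Illustrative 13-feature regression data, shaped like the Boston housing tests above.
x_train, y_train = np.random.rand(64, 13), np.random.rand(64)
x_test = np.random.rand(16, 13)

model = Sequential([Dense(1, input_dim=13)])
model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

# num_workers=2 exercises the repartitioning branch of _predict shown above.
spark_model = SparkModel(model, frequency='epoch', mode='synchronous',
                         num_workers=2, port=4000 + random.randint(0, 800))

rdd = to_simple_rdd(spark_context, x_train, y_train)
spark_model.fit(rdd, epochs=1, batch_size=32, verbose=0, validation_split=0.1)

# Despite the shuffle across two workers, row i of predictions corresponds to x_test[i].
predictions = spark_model.predict(x_test)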
