Skip to content

Commit 9c4c3d9

Browse files
acrellin authored and stefanv committed
Add optimized hyperparameters to model info in db; format hyperparam display (cesium-ml#139)
1 parent 13a408e commit 9c4c3d9

File tree

6 files changed

+81
-35
lines changed

6 files changed

+81
-35
lines changed

cesium_app/handlers/model.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,56 @@
2020
from distributed.client import _wait
2121

2222

23+
def _build_model_compute_statistics(fset_path, model_type, model_params,
                                    params_to_optimize, model_path):
    '''Build model and return summary statistics.

    The feature set is loaded from `fset_path`, a model is fit and scored
    against that same feature set, and the fitted model is serialized to
    `model_path` with `joblib.dump`.

    Parameters
    ----------
    fset_path : str
        Path to feature set NetCDF file.
    model_type : str
        Type of model to be built, e.g. 'RandomForestClassifier'.
    model_params : dict
        Dictionary with hyperparameter values to be used in model building.
        Keys are parameter names, values are the associated parameter values.
        These hyperparameters will be passed to the model constructor as-is
        (for hyperparameter optimization, see `params_to_optimize`).
    params_to_optimize : dict or list of dict
        During hyperparameter optimization, various model parameters
        are adjusted to give an optimal fit. This dictionary gives the
        different values that should be explored for each parameter. E.g.,
        `{'alpha': [1, 2], 'beta': [4, 5, 6]}` would fit models on all
        6 combinations of alpha and beta and compare the resulting models'
        goodness-of-fit. If None, only those hyperparameters specified in
        `model_params` will be used (passed to model constructor as-is).
    model_path : str
        Path indicating where serialized model will be saved.

    Returns
    -------
    score : float
        The model's training score.
    best_params : dict
        Dictionary of best hyperparameter values (keys are parameter names,
        values are the corresponding best values) determined by `scikit-learn`'s
        `GridSearchCV`. If no hyperparameter optimization is performed (i.e.
        `params_to_optimize` is None or is an empty dict), this will be an
        empty dict.
    '''
    fset = featureset.from_netcdf(fset_path, engine=cfg['xr_engine'])
    try:
        computed_model = build_model.build_model_from_featureset(
            featureset=fset, model_type=model_type,
            model_parameters=model_params,
            params_to_optimize=params_to_optimize)
        score = build_model.score_model(computed_model, fset)
        # `best_params_` is only read when an optimization grid was supplied;
        # otherwise there is nothing to report beyond `model_params`.
        best_params = computed_model.best_params_ if params_to_optimize else {}
        joblib.dump(computed_model, model_path)
    finally:
        # Release the feature set even when building, scoring or saving
        # fails, so a failed task does not leave it open.
        fset.close()

    return score, best_params
71+
72+
2373
class ModelHandler(BaseHandler):
2474
def _get_model(self, model_id):
2575
try:
@@ -42,14 +92,14 @@ def get(self, model_id=None):
4292
return self.success(model_info)
4393

4494
@tornado.gen.coroutine
45-
def _await_model(self, score_future, save_future, model):
95+
def _await_model_statistics(self, model_stats_future, model):
4696
try:
47-
yield save_future._result()
48-
score = yield score_future._result()
97+
score, best_params = yield model_stats_future._result()
4998

5099
model.task_id = None
51100
model.finished = datetime.datetime.now()
52101
model.train_score = score
102+
model.params.update(best_params)
53103
model.save()
54104

55105
self.action('cesium/SHOW_NOTIFICATION',
@@ -97,30 +147,15 @@ def post(self):
97147

98148
executor = yield self._get_executor()
99149

100-
fset = executor.submit(lambda path: featureset.from_netcdf(path,
101-
engine=cfg['xr_engine']), fset.file.uri)
102-
imputed_fset = executor.submit(featureset.Featureset.impute, fset)
103-
computed_model = executor.submit(
104-
build_model.build_model_from_featureset,
105-
featureset=imputed_fset, model_type=model_type,
106-
model_parameters=model_params,
107-
params_to_optimize=params_to_optimize)
108-
score_future = executor.submit(build_model.score_model, computed_model,
109-
imputed_fset)
110-
save_future = executor.submit(joblib.dump, computed_model, model_file.uri)
111-
112-
@tornado.gen.coroutine
113-
def _wait_and_call(callback, *args, futures=[]):
114-
yield _wait(futures_list)
115-
return callback(*args)
116-
117-
model.task_id = save_future.key
150+
model_stats_future = executor.submit(
151+
_build_model_compute_statistics, fset.file.uri, model_type,
152+
model_params, params_to_optimize, model_path)
153+
154+
model.task_id = model_stats_future.key
118155
model.save()
119156

120157
loop = tornado.ioloop.IOLoop.current()
121-
loop.add_callback(_wait_and_call, xr.Dataset.close, imputed_fset,
122-
futures=[computed_model, score_future, save_future])
123-
loop.spawn_callback(self._await_model, score_future, save_future, model)
158+
loop.spawn_callback(self._await_model_statistics, model_stats_future, model)
124159

125160
return self.success(data={'message': "Model training begun."},
126161
action='cesium/FETCH_MODELS')

cesium_app/tests/frontend/test_build_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_model_info_display(driver):
132132
driver.find_element_by_xpath("//td[contains(text(),'{}')]".format(m.name)).click()
133133
assert driver.find_element_by_xpath("//th[contains(text(),'Model Type')]")\
134134
.is_displayed()
135-
assert driver.find_element_by_xpath("//th[contains(text(),'Hyper "
136-
"Parameters')]").is_displayed()
135+
assert driver.find_element_by_xpath("//th[contains(text(),'Hyper"
136+
"parameters')]").is_displayed()
137137
assert driver.find_element_by_xpath("//th[contains(text(),'Training "
138138
"Data Score')]").is_displayed()

cesium_app/tests/frontend/test_datasets.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def test_add_new_dataset(driver):
3535
driver.implicitly_wait(1)
3636
status_td = driver.find_element_by_xpath(
3737
"//div[contains(text(),'Successfully uploaded new dataset')]")
38-
assert test_dataset_name in driver.page_source
3938

4039

4140
def test_dataset_info_display(driver):
@@ -62,4 +61,3 @@ def test_delete_dataset(driver):
6261
driver.implicitly_wait(1)
6362
status_td = driver.find_element_by_xpath(
6463
"//div[contains(text(),'Dataset deleted')]")
65-
assert test_dataset_name not in driver.page_source

cesium_app/tests/frontend/test_pipeline_sequentially.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ def test_pipeline_sequentially(driver):
2525
driver.implicitly_wait(1)
2626
status_td = driver.find_element_by_xpath(
2727
"//div[contains(text(),'Added new project')]")
28-
time.sleep(0.1)
29-
assert test_proj_name in driver.page_source
28+
driver.refresh()
3029

3130
# Ensure new project is selected
3231
proj_select = Select(driver.find_element_by_css_selector('[name=project]'))
@@ -54,7 +53,11 @@ def test_pipeline_sequentially(driver):
5453
driver.implicitly_wait(1)
5554
status_td = driver.find_element_by_xpath(
5655
"//div[contains(text(),'Successfully uploaded new dataset')]")
57-
assert test_dataset_name in driver.page_source
56+
driver.refresh()
57+
58+
# Ensure new project is selected
59+
proj_select = Select(driver.find_element_by_css_selector('[name=project]'))
60+
proj_select.select_by_visible_text(test_proj_name)
5861

5962
# Generate new feature set
6063
test_featureset_name = str(uuid.uuid4())

cesium_app/tests/frontend/test_projects.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ def test_create_project(driver):
2525
driver.implicitly_wait(1)
2626
status_td = driver.find_element_by_xpath(
2727
"//div[contains(text(),'Added new project')]")
28-
time.sleep(0.1)
29-
assert test_proj_name in driver.page_source
28+
driver.refresh()
3029

3130
proj_select = Select(driver.find_element_by_css_selector('[name=project]'))
3231
proj_select.select_by_visible_text(test_proj_name)

public/scripts/Models.jsx

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ let ModelInfo = props => (
165165
<thead>
166166
<tr>
167167
<th>Model Type</th>
168-
<th>Hyper Parameters</th>
168+
<th>Hyperparameters</th>
169169
<th>Training Data Score</th>
170170
</tr>
171171
</thead>
@@ -175,7 +175,18 @@ let ModelInfo = props => (
175175
{props.model.type}
176176
</td>
177177
<td>
178-
{JSON.stringify(props.model.params, null, 4)}
178+
<table>
179+
<tbody>
180+
{
181+
Object.keys(props.model.params).map(param => (
182+
<tr>
183+
<td>{param}</td>
184+
<td style={{ paddingLeft: "5px" }}>{JSON.stringify(props.model.params[param])}</td>
185+
</tr>
186+
))
187+
}
188+
</tbody>
189+
</table>
179190
</td>
180191
<td>
181192
{props.model.train_score}

0 commit comments

Comments
 (0)