20
20
from distributed .client import _wait
21
21
22
22
23
+ def _build_model_compute_statistics (fset_path , model_type , model_params ,
24
+ params_to_optimize , model_path ):
25
+ '''Build model and return summary statistics.
26
+
27
+ Parameters
28
+ ----------
29
+ fset_path : str
30
+ Path to feature set NetCDF file.
31
+ model_type : str
32
+ Type of model to be built, e.g. 'RandomForestClassifier'.
33
+ model_params : dict
34
+ Dictionary with hyperparameter values to be used in model building.
35
+ Keys are parameter names, values are the associated parameter values.
36
+ These hyperparameters will be passed to the model constructor as-is
37
+ (for hyperparameter optimization, see `params_to_optimize`).
38
+ params_to_optimize : dict or list of dict
39
+ During hyperparameter optimization, various model parameters
40
+ are adjusted to give an optimal fit. This dictionary gives the
41
+ different values that should be explored for each parameter. E.g.,
42
+ `{'alpha': [1, 2], 'beta': [4, 5, 6]}` would fit models on all
43
+ 6 combinations of alpha and beta and compare the resulting models'
44
+ goodness-of-fit. If None, only those hyperparameters specified in
45
+ `model_parameters` will be used (passed to model constructor as-is).
46
+ model_path : str
47
+ Path indicating where serialized model will be saved.
48
+
49
+ Returns
50
+ -------
51
+ score : float
52
+ The model's training score.
53
+ best_params : dict
54
+ Dictionary of best hyperparameter values (keys are parameter names,
55
+ values are the corresponding best values) determined by `scikit-learn`'s
56
+ `GridSearchCV`. If no hyperparameter optimization is performed (i.e.
57
+ `params_to_optimize` is None or is an empty dict, this will be an empty
58
+ dict.
59
+ '''
60
+ fset = featureset .from_netcdf (fset_path , engine = cfg ['xr_engine' ])
61
+ computed_model = build_model .build_model_from_featureset (
62
+ featureset = fset , model_type = model_type ,
63
+ model_parameters = model_params ,
64
+ params_to_optimize = params_to_optimize )
65
+ score = build_model .score_model (computed_model , fset )
66
+ best_params = computed_model .best_params_ if params_to_optimize else {}
67
+ joblib .dump (computed_model , model_path )
68
+ fset .close ()
69
+
70
+ return score , best_params
71
+
72
+
23
73
class ModelHandler (BaseHandler ):
24
74
def _get_model (self , model_id ):
25
75
try :
@@ -42,14 +92,14 @@ def get(self, model_id=None):
42
92
return self .success (model_info )
43
93
44
94
@tornado .gen .coroutine
45
- def _await_model (self , score_future , save_future , model ):
95
+ def _await_model_statistics (self , model_stats_future , model ):
46
96
try :
47
- yield save_future ._result ()
48
- score = yield score_future ._result ()
97
+ score , best_params = yield model_stats_future ._result ()
49
98
50
99
model .task_id = None
51
100
model .finished = datetime .datetime .now ()
52
101
model .train_score = score
102
+ model .params .update (best_params )
53
103
model .save ()
54
104
55
105
self .action ('cesium/SHOW_NOTIFICATION' ,
@@ -97,30 +147,15 @@ def post(self):
97
147
98
148
executor = yield self ._get_executor ()
99
149
100
- fset = executor .submit (lambda path : featureset .from_netcdf (path ,
101
- engine = cfg ['xr_engine' ]), fset .file .uri )
102
- imputed_fset = executor .submit (featureset .Featureset .impute , fset )
103
- computed_model = executor .submit (
104
- build_model .build_model_from_featureset ,
105
- featureset = imputed_fset , model_type = model_type ,
106
- model_parameters = model_params ,
107
- params_to_optimize = params_to_optimize )
108
- score_future = executor .submit (build_model .score_model , computed_model ,
109
- imputed_fset )
110
- save_future = executor .submit (joblib .dump , computed_model , model_file .uri )
111
-
112
- @tornado .gen .coroutine
113
- def _wait_and_call (callback , * args , futures = []):
114
- yield _wait (futures_list )
115
- return callback (* args )
116
-
117
- model .task_id = save_future .key
150
+ model_stats_future = executor .submit (
151
+ _build_model_compute_statistics , fset .file .uri , model_type ,
152
+ model_params , params_to_optimize , model_path )
153
+
154
+ model .task_id = model_stats_future .key
118
155
model .save ()
119
156
120
157
loop = tornado .ioloop .IOLoop .current ()
121
- loop .add_callback (_wait_and_call , xr .Dataset .close , imputed_fset ,
122
- futures = [computed_model , score_future , save_future ])
123
- loop .spawn_callback (self ._await_model , score_future , save_future , model )
158
+ loop .spawn_callback (self ._await_model_statistics , model_stats_future , model )
124
159
125
160
return self .success (data = {'message' : "Model training begun." },
126
161
action = 'cesium/FETCH_MODELS' )
0 commit comments