diff --git a/doc/whats_new/v0.9.rst b/doc/whats_new/v0.9.rst
index 42bec8647..2085fe0b3 100644
--- a/doc/whats_new/v0.9.rst
+++ b/doc/whats_new/v0.9.rst
@@ -7,3 +7,10 @@ Version 0.9.0
 =============
 
 **In Development**
+
+:mod:`skopt.searchcv`
+---------------------
+- |Fix| Fix :obj:`skopt.searchcv.BayesSearchCV` for scikit-learn >= 0.24.
+  :pr:`988`
+- |API| Deprecate :class:`skopt.searchcv.BayesSearchCV` parameter `iid=`.
+  :pr:`988`
diff --git a/skopt/searchcv.py b/skopt/searchcv.py
index 342952c22..a37bc79ad 100644
--- a/skopt/searchcv.py
+++ b/skopt/searchcv.py
@@ -1,21 +1,17 @@
+import warnings
+
 try:
     from collections.abc import Sized
 except ImportError:
     from collections import Sized
-from collections import defaultdict
-from functools import partial
 
 import numpy as np
 from scipy.stats import rankdata
 
-import sklearn
-from sklearn.base import is_classifier, clone
-from joblib import Parallel, delayed
 from sklearn.model_selection._search import BaseSearchCV
 from sklearn.utils import check_random_state
-from sklearn.utils.fixes import MaskedArray
-from sklearn.utils.validation import indexable, check_is_fitted
+from sklearn.utils.validation import check_is_fitted
 
 try:
     from sklearn.metrics import check_scoring
 except ImportError:
@@ -115,11 +111,6 @@ class BayesSearchCV(BaseSearchCV):
             - A string, giving an expression as a function of n_jobs,
               as in '2*n_jobs'
 
-    iid : boolean, default=True
-        If True, the data is assumed to be identically distributed across
-        the folds, and the loss minimized is the total loss per sample,
-        and not the mean loss across the folds.
-
     cv : int, cross-validation generator or an iterable, optional
         Determines the cross-validation splitting strategy.
         Possible inputs for cv are:
@@ -289,7 +280,7 @@ class BayesSearchCV(BaseSearchCV):
 
     def __init__(self, estimator, search_spaces, optimizer_kwargs=None,
                  n_iter=50, scoring=None, fit_params=None, n_jobs=1,
-                 n_points=1, iid=True, refit=True, cv=None, verbose=0,
+                 n_points=1, iid='deprecated', refit=True, cv=None, verbose=0,
                  pre_dispatch='2*n_jobs', random_state=None,
                  error_score='raise', return_train_score=False):
 
@@ -305,9 +296,14 @@ def __init__(self, estimator, search_spaces, optimizer_kwargs=None,
         # in the constructor and be passed in ``fit``.
         self.fit_params = fit_params
 
+        if iid != "deprecated":
+            warnings.warn("The `iid` parameter has been deprecated "
+                          "and will be ignored.")
+        self.iid = iid  # For sklearn repr pprint
+
         super(BayesSearchCV, self).__init__(
             estimator=estimator, scoring=scoring,
-            n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
+            n_jobs=n_jobs, refit=refit, cv=cv, verbose=verbose,
             pre_dispatch=pre_dispatch, error_score=error_score,
             return_train_score=return_train_score)
 
@@ -364,170 +360,11 @@ def _check_search_space(self, search_space):
                 "Search space should be provided as a dict or list of dict,"
                 "got %s" % search_space)
 
-    # copied for compatibility with 0.19 sklearn from 0.18 BaseSearchCV
-    @property
-    def best_score_(self):
-        check_is_fitted(self, 'cv_results_')
-        return self.cv_results_['mean_test_score'][self.best_index_]
-
-    # copied for compatibility with 0.19 sklearn from 0.18 BaseSearchCV
-    @property
-    def best_params_(self):
-        check_is_fitted(self, 'cv_results_')
-        return self.cv_results_['params'][self.best_index_]
-
     @property
     def optimizer_results_(self):
         check_is_fitted(self, '_optim_results')
         return self._optim_results
 
-    # copied for compatibility with 0.19 sklearn from 0.18 BaseSearchCV
-    def _fit(self, X, y, groups, parameter_iterable):
-        """
-        Actual fitting, performing the search over parameters.
-        Taken from https://github.com/scikit-learn/scikit-learn/blob/0.18.X
-            .../sklearn/model_selection/_search.py
-        """
-        estimator = self.estimator
-        cv = sklearn.model_selection._validation.check_cv(
-            self.cv, y, classifier=is_classifier(estimator))
-        self.scorer_ = check_scoring(
-            self.estimator, scoring=self.scoring)
-
-        X, y, groups = indexable(X, y, groups)
-        n_splits = cv.get_n_splits(X, y, groups)
-        if self.verbose > 0 and isinstance(parameter_iterable, Sized):
-            n_candidates = len(parameter_iterable)
-            print("Fitting {0} folds for each of {1} candidates, totalling"
-                  " {2} fits".format(n_splits, n_candidates,
-                                     n_candidates * n_splits))
-
-        base_estimator = clone(self.estimator)
-        pre_dispatch = self.pre_dispatch
-
-        cv_iter = list(cv.split(X, y, groups))
-        out = Parallel(
-            n_jobs=self.n_jobs, verbose=self.verbose,
-            pre_dispatch=pre_dispatch
-        )(delayed(sklearn.model_selection._validation._fit_and_score)(
-                clone(base_estimator),
-                X, y, self.scorer_,
-                train, test, self.verbose, parameters,
-                fit_params=self.fit_params,
-                return_train_score=self.return_train_score,
-                return_n_test_samples=True,
-                return_times=True, return_parameters=True,
-                error_score=self.error_score
-            )
-            for parameters in parameter_iterable
-            for train, test in cv_iter)
-
-        # if one choose to see train score, "out" will contain train score info
-        if self.return_train_score:
-            (train_scores, test_scores, test_sample_counts,
-             fit_time, score_time, parameters) = zip(*out)
-        else:
-            (test_scores, test_sample_counts,
-             fit_time, score_time, parameters) = zip(*out)
-
-        candidate_params = parameters[::n_splits]
-        n_candidates = len(candidate_params)
-
-        results = dict()
-
-        def _store(key_name, array, weights=None, splits=False, rank=False):
-            """A small helper to store the scores/times to the cv_results_"""
-            array = np.array(array, dtype=np.float64).reshape(n_candidates,
-                                                              n_splits)
-            if splits:
-                for split_i in range(n_splits):
-                    results["split%d_%s"
-                            % (split_i, key_name)] = array[:, split_i]
-
-            array_means = np.average(array, axis=1, weights=weights)
-            results['mean_%s' % key_name] = array_means
-            # Weighted std is not directly available in numpy
-            array_stds = np.sqrt(np.average((array -
-                                             array_means[:, np.newaxis]) ** 2,
-                                            axis=1, weights=weights))
-            results['std_%s' % key_name] = array_stds
-
-            if rank:
-                results["rank_%s" % key_name] = np.asarray(
-                    rankdata(-array_means, method='min'), dtype=np.int32)
-
-        # Computed the (weighted) mean and std for test scores alone
-        # NOTE test_sample counts (weights) remain the same for all candidates
-        test_sample_counts = np.array(test_sample_counts[:n_splits],
-                                      dtype=np.int)
-
-        _store('test_score', test_scores, splits=True, rank=True,
-               weights=test_sample_counts if self.iid else None)
-        if self.return_train_score:
-            _store('train_score', train_scores, splits=True)
-        _store('fit_time', fit_time)
-        _store('score_time', score_time)
-
-        best_index = np.flatnonzero(results["rank_test_score"] == 1)[0]
-        best_parameters = candidate_params[best_index]
-
-        # Use one MaskedArray and mask all the places where the param is not
-        # applicable for that candidate. Use defaultdict as each candidate may
-        # not contain all the params
-        param_results = defaultdict(partial(np.ma.array,
-                                            np.empty(n_candidates,),
-                                            mask=True,
-                                            dtype=object))
-        for cand_i, params in enumerate(candidate_params):
-            for name, value in params.items():
-                # An all masked empty array gets created for the key
-                # `"param_%s" % name` at the first occurence of `name`.
-                # Setting the value at an index also unmasks that index
-                param_results["param_%s" % name][cand_i] = value
-
-        results.update(param_results)
-
-        # Store a list of param dicts at the key 'params'
-        results['params'] = candidate_params
-
-        self.cv_results_ = results
-        self.best_index_ = best_index
-        self.n_splits_ = n_splits
-
-        if self.refit:
-            # fit the best estimator using the entire dataset
-            # clone first to work around broken estimators
-            best_estimator = clone(base_estimator).set_params(
-                **best_parameters)
-            if y is not None:
-                best_estimator.fit(X, y, **self.fit_params)
-            else:
-                best_estimator.fit(X, **self.fit_params)
-            self.best_estimator_ = best_estimator
-        return self
-
-    def _fit_best_model(self, X, y):
-        """Fit the estimator copy with best parameters found to the
-        provided data.
-
-        Parameters
-        ----------
-        X : array-like, shape = [n_samples, n_features]
-            Input data, where n_samples is the number of samples and
-            n_features is the number of features.
-
-        y : array-like, shape = [n_samples] or [n_samples, n_output],
-            Target relative to X for classification or regression.
-
-        Returns
-        -------
-        self
-        """
-        self.best_estimator_ = clone(self.estimator)
-        self.best_estimator_.set_params(**self.best_params_)
-        self.best_estimator_.fit(X, y, **(self.fit_params or {}))
-        return self
-
     def _make_optimizer(self, params_space):
         """Instantiate skopt Optimizer class.
 
@@ -556,10 +393,9 @@ def _make_optimizer(self, params_space):
 
         return optimizer
 
-    def _step(self, X, y, search_space, optimizer, groups=None, n_points=1):
+    def _step(self, search_space, optimizer, evaluate_candidates, n_points=1):
         """Generate n_jobs parameters and evaluate them in parallel.
""" - # get parameter values to evaluate params = optimizer.ask(n_points=n_points) @@ -569,33 +405,10 @@ def _step(self, X, y, search_space, optimizer, groups=None, n_points=1): # make lists into dictionaries params_dict = [point_asdict(search_space, p) for p in params] - # HACK: self.cv_results_ is reset at every call to _fit, keep current - all_cv_results = self.cv_results_ - - # HACK: this adds compatibility with different versions of sklearn - refit = self.refit - self.refit = False - self._fit(X, y, groups, params_dict) - self.refit = refit - - # merge existing and new cv_results_ - for k in self.cv_results_: - all_cv_results[k].extend(self.cv_results_[k]) - - all_cv_results["rank_test_score"] = list(np.asarray( - rankdata(-np.array(all_cv_results['mean_test_score']), - method='min'), dtype=np.int32)) - if self.return_train_score: - all_cv_results["rank_train_score"] = list(np.asarray( - rankdata(-np.array(all_cv_results['mean_train_score']), - method='min'), dtype=np.int32)) - self.cv_results_ = all_cv_results - self.best_index_ = np.argmax(self.cv_results_['mean_test_score']) - - # feed the point and objective back into optimizer - local_results = self.cv_results_['mean_test_score'][-len(params):] - - # optimizer minimizes objective, hence provide negative score + all_results = evaluate_candidates(params_dict) + # Feed the point and objective value back into optimizer + # Optimizer minimizes objective, hence provide negative score + local_results = all_results["mean_test_score"][-len(params):] return optimizer.tell(params, [-score for score in local_results]) @property @@ -621,10 +434,8 @@ def total_iterations(self): return total_iter - def _run_search(self, x): - pass - - def fit(self, X, y=None, groups=None, callback=None): + # TODO: Accept callbacks via the constructor? + def fit(self, X, y=None, *, groups=None, callback=None, **fit_params): """Run fit on the estimator with randomly drawn parameters. Parameters @@ -645,18 +456,31 @@ def fit(self, X, y=None, groups=None, callback=None): combination tested. If list of callables, then each callable in the list is called. 
""" + self._callbacks = check_callback(callback) + if self.optimizer_kwargs is None: + self.optimizer_kwargs_ = {} + else: + self.optimizer_kwargs_ = dict(self.optimizer_kwargs) + + super().fit(X=X, y=y, groups=groups, **fit_params) + + # BaseSearchCV never ranked train scores, + # but apparently we used to ship this (back-compat) + if self.return_train_score: + self.cv_results_["rank_train_score"] = \ + rankdata(-np.array(self.cv_results_["mean_train_score"]), + method='min').astype(int) + return self + + def _run_search(self, evaluate_candidates): # check if space is a single dict, convert to list if so search_spaces = self.search_spaces if isinstance(search_spaces, dict): search_spaces = [search_spaces] - callbacks = check_callback(callback) + callbacks = self._callbacks - if self.optimizer_kwargs is None: - self.optimizer_kwargs_ = {} - else: - self.optimizer_kwargs_ = dict(self.optimizer_kwargs) random_state = check_random_state(self.random_state) self.optimizer_kwargs_['random_state'] = random_state @@ -668,9 +492,6 @@ def fit(self, X, y=None, groups=None, callback=None): optimizers.append(self._make_optimizer(search_space)) self.optimizers_ = optimizers # will save the states of the optimizers - self.cv_results_ = defaultdict(list) - self.best_index_ = None - self.multimetric_ = False self._optim_results = [] n_points = self.n_points @@ -689,17 +510,11 @@ def fit(self, X, y=None, groups=None, callback=None): n_points_adjusted = min(n_iter, n_points) optim_result = self._step( - X, y, search_space, optimizer, - groups=groups, n_points=n_points_adjusted + search_space, optimizer, + evaluate_candidates, n_points=n_points_adjusted ) n_iter -= n_points if eval_callbacks(callbacks, optim_result): break self._optim_results.append(optim_result) - - # Refit the best model on the the whole dataset - if self.refit: - self._fit_best_model(X, y) - - return self