from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin

ThundergbmBase = BaseEstimator
ThundergbmRegressorBase = RegressorMixin
ThundergbmClassifierBase = ClassifierMixin

import numpy as np
import scipy.sparse as sp

from sklearn.utils import check_X_y

from ctypes import CDLL, byref, c_float, c_int, c_long
from os import path
from sys import platform

dirname = path.dirname(path.abspath(__file__))

if platform == "linux" or platform == "linux2":
    shared_library_name = "libthundergbm.so"
elif platform == "win32":
    shared_library_name = "thundergbm.dll"
elif platform == "darwin":
    shared_library_name = "libthundergbm.dylib"
else:
    raise EnvironmentError("OS not supported!")

# Locate the shared library next to this file; fail early if it has not been built yet.
lib_path = path.abspath(path.join(dirname, shared_library_name))
if path.exists(lib_path):
    thundergbm = CDLL(lib_path)
else:
    raise RuntimeError("Please build the library first!")

OBJECTIVE_TYPE = ['reg:linear', 'reg:logistic', 'multi:softprob', 'multi:softmax', 'rank:pairwise', 'rank:ndcg']


class TGBMModel(ThundergbmBase, ThundergbmRegressorBase):
    """Scikit-learn style wrapper around the ThunderGBM shared library."""

    def __init__(self, depth=6, num_round=40,
                 n_device=1, min_child_weight=1.0, lambda_tgbm=1.0, gamma=1.0, max_num_bin=255,
                 verbose=0, column_sampling_rate=1.0, bagging=0,
                 n_parallel_trees=1, learning_rate=0.9, objective="reg:linear",
                 num_class=1, tree_method="auto"):
        self.depth = depth
        self.n_trees = num_round
        self.n_device = n_device
        self.min_child_weight = min_child_weight
        self.lambda_tgbm = lambda_tgbm
        self.gamma = gamma
        self.max_num_bin = max_num_bin
        self.verbose = verbose
        self.column_sampling_rate = column_sampling_rate
        self.bagging = bagging
        self.n_parallel_trees = n_parallel_trees
        self.learning_rate = learning_rate
        self.objective = objective
        self.num_class = num_class
        self.path = path
        self.tree_method = tree_method
        self.model = None
        self.tree_per_iter = -1
        self.group_label = None

    def fit(self, X, y):
        # Convert dense inputs to CSR so a single sparse code path handles both cases.
        if not sp.isspmatrix(X):
            X = sp.csr_matrix(X)
        X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
        self._sparse_fit(X, y)
        return self

    def _sparse_fit(self, X, y):
        X.data = np.asarray(X.data, dtype=np.float64, order='C')
        X.sort_indices()

        # Copy the CSR arrays and labels into ctypes buffers expected by the C API.
        data = (c_float * X.data.size)()
        data[:] = X.data
        indices = (c_int * X.indices.size)()
        indices[:] = X.indices
        indptr = (c_int * X.indptr.size)()
        indptr[:] = X.indptr
        label = (c_float * y.size)()
        label[:] = y
        # Out-parameters filled in by the native trainer: class labels, number of
        # classes, trees per boosting iteration, and an opaque model handle.
        group_label = (c_float * len(set(y)))()
        n_class = (c_int * 1)()
        n_class[0] = self.num_class
        tree_per_iter_ptr = (c_int * 1)()
        self.model = (c_long * 1)()
        thundergbm.sparse_train_scikit(X.shape[0], data, indptr, indices, label, self.depth, self.n_trees,
                                       self.n_device, c_float(self.min_child_weight), c_float(self.lambda_tgbm),
                                       c_float(self.gamma),
                                       self.max_num_bin, self.verbose, c_float(self.column_sampling_rate), self.bagging,
                                       self.n_parallel_trees, c_float(self.learning_rate),
                                       self.objective.encode('utf-8'),
                                       n_class, self.tree_method.encode('utf-8'), byref(self.model), tree_per_iter_ptr,
                                       group_label)
        self.num_class = n_class[0]
        self.tree_per_iter = tree_per_iter_ptr[0]
        self.group_label = [group_label[idx] for idx in range(len(set(y)))]
        if self.model is None:
            raise RuntimeError("The model returned is empty!")

    def predict(self, X):
        if self.model is None:
            raise ValueError("Please train the model first or load model from file!")
        if not sp.isspmatrix(X):
            X = sp.csr_matrix(X)
        X.data = np.asarray(X.data, dtype=np.float64, order='C')
        X.sort_indices()
        # Copy the CSR arrays into ctypes buffers and reserve one output slot per row.
        data = (c_float * X.data.size)()
        data[:] = X.data
        indices = (c_int * X.indices.size)()
        indices[:] = X.indices
        indptr = (c_int * X.indptr.size)()
        indptr[:] = X.indptr
        self.predict_label_ptr = (c_float * X.shape[0])()
        if self.group_label is not None:
            group_label = (c_float * len(self.group_label))()
            group_label[:] = self.group_label
        else:
            group_label = None
        thundergbm.sparse_predict_scikit(
            X.shape[0],
            data,
            indptr,
            indices,
            self.predict_label_ptr,
            byref(self.model),
            self.n_trees,
            self.tree_per_iter,
            self.objective.encode('utf-8'),
            self.num_class,
            c_float(self.learning_rate),
            group_label
        )
        predict_label = [self.predict_label_ptr[index] for index in range(0, X.shape[0])]
        self.predict_label = np.asarray(predict_label)
        return self.predict_label

    def save_model(self, model_path):
        if self.model is None:
            raise ValueError("Please train the model first or load model from file!")
        # group_label may be None for objectives that do not relabel classes.
        group_label = None
        if self.group_label is not None:
            group_label = (c_float * len(self.group_label))()
            group_label[:] = self.group_label
        thundergbm.save(
            model_path.encode('utf-8'),
            self.objective.encode('utf-8'),
            c_float(self.learning_rate),
            self.num_class,
            self.n_trees,
            self.tree_per_iter,
            byref(self.model),
            group_label
        )

    def load_model(self, model_path):
        self.model = (c_long * 1)()
        # Out-parameters restored from the model file.
        learning_rate = (c_float * 1)()
        n_class = (c_int * 1)()
        n_trees = (c_int * 1)()
        tree_per_iter = (c_int * 1)()
        thundergbm.load_model(
            model_path.encode('utf-8'),
            learning_rate,
            n_class,
            n_trees,
            tree_per_iter,
            byref(self.model)
        )
        if self.model is None:
            raise ValueError("Model is None.")
        self.learning_rate = learning_rate[0]
        self.num_class = n_class[0]
        self.n_trees = n_trees[0]
        self.tree_per_iter = tree_per_iter[0]
        group_label = (c_float * self.num_class)()
        thundergbm.load_config(
            model_path.encode('utf-8'),
            group_label
        )
        self.group_label = [group_label[idx] for idx in range(self.num_class)]
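

# A minimal usage sketch, not part of the original module: it assumes the shared
# library above has been built and that scikit-learn's bundled digits dataset is
# available. The file name and parameter values below are illustrative only.
if __name__ == "__main__":
    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split

    X, y = load_digits(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # Train a multi-class model, predict, then round-trip it through save/load.
    clf = TGBMModel(objective="multi:softmax", num_class=10, depth=6, num_round=40)
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print("first predictions:", y_pred[:10])

    clf.save_model("digits_tgbm.model")
    clf2 = TGBMModel(objective="multi:softmax")
    clf2.load_model("digits_tgbm.model")
    print("reloaded predictions:", clf2.predict(X_test)[:10])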