Skip to content

Commit 791ac52

Browse files
committed
improve python interface
1 parent c398d03 commit 791ac52

File tree

5 files changed

+218
-221
lines changed

5 files changed

+218
-221
lines changed

python/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ cd python && python setup.py install
99
```
1010
Or you can install via pip
1111
```bash
12-
pip3 install -U thundergbm
12+
pip3 install thundergbm
1313
```
1414
* After you have successfully installed ThunderGBM, you can import TGBMModel:
1515
```python

python/setup.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,31 @@
33
from shutil import copyfile
44
from sys import platform
55

6-
76
dirname = path.dirname(path.abspath(__file__))
87

98
if platform == "linux" or platform == "linux2":
10-
lib_path = path.abspath(path.join(dirname, '../build/lib/libthundergbm.so'))
9+
lib_path = path.abspath(path.join(dirname, '../build/lib/libthundergbm.so'))
1110
elif platform == "win32":
12-
lib_path = path.abspath(path.join(dirname, '../build/bin/Debug/thundergbm.dll'))
11+
lib_path = path.abspath(path.join(dirname, '../build/bin/Debug/thundergbm.dll'))
1312
elif platform == "darwin":
14-
lib_path = path.abspath(path.join(dirname, '../build/lib/libthundergbm.dylib'))
15-
else :
16-
print ("OS not supported!")
17-
exit()
13+
lib_path = path.abspath(path.join(dirname, '../build/lib/libthundergbm.dylib'))
14+
else:
15+
print("OS not supported!")
16+
exit()
1817
if not path.exists(path.join(dirname, "thundergbm", path.basename(lib_path))):
19-
copyfile(lib_path, path.join(dirname, "thundergbm", path.basename(lib_path)))
18+
copyfile(lib_path, path.join(dirname, "thundergbm", path.basename(lib_path)))
2019
setuptools.setup(name="thundergbm",
21-
version="0.0.7",
22-
packages=["thundergbm"],
23-
package_dir={"python": "thundergbm"},
24-
description="A Fast GBM Library on GPUs and CPUs",
25-
long_description="""The mission of ThunderGBM is to help users easily and efficiently apply GBDTs and Random Forests to solve problems. ThunderGBM exploits GPUs and multi-core CPUs to achieve high efficiency""",
26-
long_description_content_type="text/plain",
27-
url="https://github.com/zeyiwen/thundergbm",
28-
package_data = {"thundergbm": [path.basename(lib_path)]},
29-
install_requires=['numpy','scipy','scikit-learn'],
30-
classifiers=[
31-
"Programming Language :: Python :: 3",
32-
"License :: OSI Approved :: Apache Software License",
33-
],
34-
python_requires=">=3"
35-
)
20+
version="0.3.2",
21+
packages=["thundergbm"],
22+
package_dir={"python": "thundergbm"},
23+
description="A Fast GBM Library on GPUs and CPUs",
24+
long_description="""The mission of ThunderGBM is to help users easily and efficiently apply GBDTs and Random Forests to solve problems. ThunderGBM exploits GPUs and multi-core CPUs to achieve high efficiency""",
25+
long_description_content_type="text/plain",
26+
url="https://github.com/zeyiwen/thundergbm",
27+
package_data={"thundergbm": [path.basename(lib_path)]},
28+
install_requires=['numpy', 'scipy', 'scikit-learn'],
29+
classifiers=[
30+
"Programming Language :: Python :: 3",
31+
"License :: OSI Approved :: Apache Software License",
32+
],
33+
)

python/thundergbm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
* Description :
1010
"""
1111
name = "thundergbm"
12-
from .thundergbm_scikit import *
12+
from .thundergbm import *

python/thundergbm/thundergbm.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from sklearn.base import BaseEstimator
2+
from sklearn.base import RegressorMixin, ClassifierMixin
3+
4+
ThundergbmBase = BaseEstimator
5+
ThundergbmRegressorBase = RegressorMixin
6+
ThundergbmClassifierBase = ClassifierMixin
7+
8+
import numpy as np
9+
import scipy.sparse as sp
10+
11+
from sklearn.utils import check_X_y
12+
13+
from ctypes import *
14+
from os import path
15+
from sys import platform
16+
17+
dirname = path.dirname(path.abspath(__file__))
18+
19+
if platform == "linux" or platform == "linux2":
20+
shared_library_name = "libthundergbm.so"
21+
elif platform == "win32":
22+
shared_library_name = "thundergbm.dll"
23+
elif platform == "darwin":
24+
shared_library_name = "libthundergbm.dylib"
25+
else:
26+
raise EnvironmentError("OS not supported!")
27+
28+
if path.exists(path.abspath(path.join(dirname, shared_library_name))):
29+
lib_path = path.abspath(path.join(dirname, shared_library_name))
30+
else:
31+
if platform == "linux" or platform == "linux2":
32+
lib_path = path.join(dirname, shared_library_name)
33+
34+
if path.exists(lib_path):
35+
thundergbm = CDLL(lib_path)
36+
else:
37+
raise RuntimeError("Please build the library first!")
38+
39+
OBJECTIVE_TYPE = ['reg:linear', 'reg:logistic', 'multi:softprob', 'multi:softmax', 'rank:pairwise', 'rank:ndcg']
40+
41+
42+
class TGBMModel(ThundergbmBase, ThundergbmRegressorBase):
43+
def __init__(self, depth=6, num_round=40,
44+
n_device=1, min_child_weight=1.0, lambda_tgbm=1.0, gamma=1.0, max_num_bin=255,
45+
verbose=0, column_sampling_rate=1.0, bagging=0,
46+
n_parallel_trees=1, learning_rate=0.9, objective="reg:linear",
47+
num_class=1, tree_method="auto"):
48+
self.depth = depth
49+
self.n_trees = num_round
50+
self.n_device = n_device
51+
self.min_child_weight = min_child_weight
52+
self.lambda_tgbm = lambda_tgbm
53+
self.gamma = gamma
54+
self.max_num_bin = max_num_bin
55+
self.verbose = verbose
56+
self.column_sampling_rate = column_sampling_rate
57+
self.bagging = bagging
58+
self.n_parallel_trees = n_parallel_trees
59+
self.learning_rate = learning_rate
60+
self.objective = objective
61+
self.num_class = num_class
62+
self.path = path
63+
self.tree_method = tree_method
64+
self.model = None
65+
self.tree_per_iter = -1
66+
self.group_label = None
67+
68+
def fit(self, X, y):
69+
sparse = sp.isspmatrix(X)
70+
if sparse is False:
71+
X = sp.csr_matrix(X)
72+
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
73+
74+
fit = self._sparse_fit
75+
76+
fit(X, y)
77+
return self
78+
79+
def _sparse_fit(self, X, y):
80+
X.data = np.asarray(X.data, dtype=np.float64, order='C')
81+
X.sort_indices()
82+
83+
data = (c_float * X.data.size)()
84+
data[:] = X.data
85+
indices = (c_int * X.indices.size)()
86+
indices[:] = X.indices
87+
indptr = (c_int * X.indptr.size)()
88+
indptr[:] = X.indptr
89+
label = (c_float * y.size)()
90+
label[:] = y
91+
# self.group_label
92+
group_label = (c_float * len(set(y)))()
93+
n_class = (c_int * 1)()
94+
n_class[0] = self.num_class
95+
tree_per_iter_ptr = (c_int * 1)()
96+
self.model = (c_long * 1)()
97+
# self._train_succeed = (c_int * 1)()
98+
thundergbm.sparse_train_scikit(X.shape[0], data, indptr, indices, label, self.depth, self.n_trees,
99+
self.n_device, c_float(self.min_child_weight), c_float(self.lambda_tgbm),
100+
c_float(self.gamma),
101+
self.max_num_bin, self.verbose, c_float(self.column_sampling_rate), self.bagging,
102+
self.n_parallel_trees, c_float(self.learning_rate),
103+
self.objective.encode('utf-8'),
104+
n_class, self.tree_method.encode('utf-8'), byref(self.model), tree_per_iter_ptr,
105+
group_label)
106+
self.num_class = n_class[0]
107+
self.tree_per_iter = tree_per_iter_ptr[0]
108+
self.group_label = [group_label[idx] for idx in range(len(set(y)))]
109+
if self.model is None:
110+
print("The model returned is empty!")
111+
exit()
112+
113+
def predict(self, X):
114+
if self.model is None:
115+
print("Please train the model first or load model from file!")
116+
raise ValueError
117+
sparse = sp.isspmatrix(X)
118+
if sparse is False:
119+
X = sp.csr_matrix(X)
120+
X.data = np.asarray(X.data, dtype=np.float64, order='C')
121+
X.sort_indices()
122+
data = (c_float * X.data.size)()
123+
data[:] = X.data
124+
indices = (c_int * X.indices.size)()
125+
indices[:] = X.indices
126+
indptr = (c_int * X.indptr.size)()
127+
indptr[:] = X.indptr
128+
self.predict_label_ptr = (c_float * X.shape[0])()
129+
if self.group_label is not None:
130+
group_label = (c_float * len(self.group_label))()
131+
group_label[:] = self.group_label
132+
else:
133+
group_label = None
134+
thundergbm.sparse_predict_scikit(
135+
X.shape[0],
136+
data,
137+
indptr,
138+
indices,
139+
self.predict_label_ptr,
140+
byref(self.model),
141+
self.n_trees,
142+
self.tree_per_iter,
143+
self.objective.encode('utf-8'),
144+
self.num_class,
145+
c_float(self.learning_rate),
146+
group_label
147+
)
148+
predict_label = [self.predict_label_ptr[index] for index in range(0, X.shape[0])]
149+
self.predict_label = np.asarray(predict_label)
150+
return self.predict_label
151+
152+
def save_model(self, model_path):
153+
if self.model is None:
154+
print("Please train the model first or load model from file!")
155+
raise ValueError
156+
if self.group_label is not None:
157+
group_label = (c_float * len(self.group_label))()
158+
group_label[:] = self.group_label
159+
thundergbm.save(
160+
model_path.encode('utf-8'),
161+
self.objective.encode('utf-8'),
162+
c_float(self.learning_rate),
163+
self.num_class,
164+
self.n_trees,
165+
self.tree_per_iter,
166+
byref(self.model),
167+
group_label
168+
)
169+
170+
def load_model(self, model_path):
171+
self.model = (c_long * 1)()
172+
learning_rate = (c_float * 1)()
173+
n_class = (c_int * 1)()
174+
n_trees = (c_int * 1)()
175+
tree_per_iter = (c_int * 1)()
176+
thundergbm.load_model(
177+
model_path.encode('utf-8'),
178+
learning_rate,
179+
n_class,
180+
n_trees,
181+
tree_per_iter,
182+
byref(self.model)
183+
)
184+
if self.model is None:
185+
raise ValueError("Model is None.")
186+
self.learning_rate = learning_rate[0]
187+
self.num_class = n_class[0]
188+
self.n_trees = n_trees[0]
189+
self.tree_per_iter = tree_per_iter[0]
190+
group_label = (c_float * self.num_class)()
191+
thundergbm.load_config(
192+
model_path.encode('utf-8'),
193+
group_label
194+
)
195+
self.group_label = [group_label[idx] for idx in range(self.num_class)]

0 commit comments

Comments
 (0)