Skip to content

Commit ea0e32a

Browse files
committed
update readme and test
1 parent 69cb993 commit ea0e32a

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

README.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
1-
`gbt` is a library for gradient boosted trees with minimal coding required.
1+
`gbt` is a library for gradient boosted trees with minimal coding required. It is a thin wrapper around [`lightgbm`](https://lightgbm.readthedocs.io/en/v3.3.2/). Give it a `pandas.DataFrame`, and `gbt.train()` takes care of feature transforms (e.g. scaling for numerical features, label encoding for categorical features) and metrics printouts.
22

33
Example usage:
44

55
```
66
import pandas as pd
77
import gbt
88
9-
df = pd.read_csv("my_data.csv")
10-
gbt.train(
9+
df = pd.DataFrame({
10+
"a": [1, 2, 3, 4, 5, 6, 7],
11+
"b": ["a", "b", "c", None, "e", "f", "g"],
12+
"c": [1, 0, 1, 1, 0, 0, 1],
13+
"some_other_column": [0, 0, None, None, None, 3, 3]
14+
})
15+
gbt.train(
1116
df,
12-
target_column=...,
17+
recipe="binary",
18+
label_column="c",
1319
val_size=0.2, # fraction of the validation split
14-
categorical_feature_columns=[],
15-
numerical_feature_columns=[],
20+
categorical_feature_columns=["b"],
21+
numerical_feature_columns=["a"],
1622
)
1723
```
24+
25+
Supported "recipes": mape, l2, l2_rf, binary, multiclass.

gbt/api.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def train(
5656
df_test=None,
5757
recipe="l2",
5858
num_classes=None,
59+
label_column=None,
5960
categorical_feature_columns=None,
6061
numerical_feature_columns=None,
6162
preprocess_fn=None,
6263
sort_by_columns=None,
63-
label_column=None,
6464
add_categorical_stats=False,
6565
pretrain_size=0,
6666
val_size=0.1,
@@ -100,7 +100,9 @@ def train(
100100
)
101101
ds.df = df
102102
ds.preprocess()
103-
assert ds.features.shape[0] > 10
103+
if ds.features.shape[0] <= 10:
104+
import warnings
105+
warnings.warn(f"Too few samples: {ds.features.shape[0]}. Training may not converge.")
104106
train_ds, val_ds = ds.split(
105107
pretrain_size=pretrain_size, val_size=val_size, shuffle=False
106108
)
@@ -163,6 +165,8 @@ def train(
163165
"lambda_l2": 0.001,
164166
"num_class": num_classes,
165167
}
168+
else:
169+
raise ValueError(f"Unknown recipe: {recipe}. Supported: mape, l2, l2_rf, binary, multiclass")
166170

167171
print(parameters)
168172
model = LightGBMModel(parameters=parameters, rounds=100)

tests/test_gbt.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from os import path
22

3+
import pandas as pd
34
from gbt.api import train
45

6+
57
csd = path.dirname(path.realpath(__file__))
68

79

@@ -13,3 +15,20 @@ def test_model_can_train():
1315
sort_by_columns=None,
1416
label_column="label",
1517
)
18+
19+
20+
def test_the_readme_example():
21+
df = pd.DataFrame({
22+
"a": [1, 2, 3, 4, 5, 6, 7],
23+
"b": ["a", "b", "c", None, "e", "f", "g"],
24+
"c": [1, 0, 1, 1, 0, 0, 1],
25+
"some_other_column": [0, 0, None, None, None, 3, 3]
26+
})
27+
train(
28+
df,
29+
recipe="binary",
30+
label_column="c",
31+
val_size=0.2, # fraction of the validation split
32+
categorical_feature_columns=["b"],
33+
numerical_feature_columns=["a"],
34+
)

0 commit comments

Comments
 (0)