Skip to content

Commit cc2c6cb

Browse files
authored
Merge branch 'master' into release_1.4.2
2 parents 4b1e533 + 5e84292 commit cc2c6cb

File tree

3 files changed

+91
-18
lines changed

3 files changed

+91
-18
lines changed

README.md

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Basic usage example:
1212
```rust
1313
extern crate xgboost;
1414

15-
use xgboost::{parameters, dmatrix::DMatrix, booster::Booster};
15+
use xgboost::{parameters, DMatrix, Booster};
1616

1717
fn main() {
1818
// training matrix with 5 training examples and 3 features
@@ -37,14 +37,37 @@ fn main() {
3737
let mut dtest = DMatrix::from_dense(x_test, num_rows).unwrap();
3838
dtest.set_labels(y_test).unwrap();
3939

40-
// build overall training parameters
41-
let params = parameters::ParametersBuilder::default().build().unwrap();
40+
// configure objectives, metrics, etc.
41+
let learning_params = parameters::learning::LearningTaskParametersBuilder::default()
42+
.objective(parameters::learning::Objective::BinaryLogistic)
43+
.build().unwrap();
44+
45+
// configure the tree-based learning model's parameters
46+
let tree_params = parameters::tree::TreeBoosterParametersBuilder::default()
47+
.max_depth(2)
48+
.eta(1.0)
49+
.build().unwrap();
50+
51+
// overall configuration for Booster
52+
let booster_params = parameters::BoosterParametersBuilder::default()
53+
.booster_type(parameters::BoosterType::Tree(tree_params))
54+
.learning_params(learning_params)
55+
.verbose(true)
56+
.build().unwrap();
4257

4358
// specify datasets to evaluate against during training
4459
let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];
4560

61+
// overall configuration for training/evaluation
62+
let params = parameters::TrainingParametersBuilder::default()
63+
.dtrain(&dtrain) // dataset to train with
64+
.boost_rounds(2) // number of training iterations
65+
.booster_params(booster_params) // model parameters
66+
.evaluation_sets(Some(evaluation_sets)) // optional datasets to evaluate against in each iteration
67+
.build().unwrap();
68+
4669
// train model, and print evaluation data
47-
let bst = Booster::train(&params, &dtrain, 3, evaluation_sets).unwrap();
70+
let bst = Booster::train(&params).unwrap();
4871

4972
println!("{:?}", bst.predict(&dtest).unwrap());
5073
}

src/parameters/booster.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//!
1010
//! let tree_params = TreeBoosterParametersBuilder::default()
1111
//! .eta(0.2)
12-
//! .gamma(3)
12+
//! .gamma(3.0)
1313
//! .subsample(0.75)
1414
//! .build()
1515
//! .unwrap();

src/parameters/tree.rs

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,32 @@ impl Default for TreeMethod {
5353
fn default() -> Self { TreeMethod::Auto }
5454
}
5555

56+
impl From<String> for TreeMethod
57+
{
58+
fn from(s: String) -> Self
59+
{
60+
use std::borrow::Borrow;
61+
Self::from(s.borrow())
62+
}
63+
}
64+
65+
impl<'a> From<&'a str> for TreeMethod
66+
{
67+
fn from(s: &'a str) -> Self
68+
{
69+
match s
70+
{
71+
"auto" => TreeMethod::Auto,
72+
"exact" => TreeMethod::Exact,
73+
"approx" => TreeMethod::Approx,
74+
"hist" => TreeMethod::Hist,
75+
"gpu_exact" => TreeMethod::GpuExact,
76+
"gpu_hist" => TreeMethod::GpuHist,
77+
_ => panic!("no known tree_method for {}", s)
78+
}
79+
}
80+
}
81+
5682
/// Provides a modular way to construct and to modify the trees. This is an advanced parameter that is usually set
5783
/// automatically, depending on some other parameters. However, it could be also set explicitly by a user.
5884
#[derive(Clone)]
@@ -191,7 +217,7 @@ pub struct TreeBoosterParameters {
191217
///
192218
/// * range: [0,∞]
193219
/// * default: 0
194-
gamma: u32,
220+
gamma: f32,
195221

196222
/// Maximum depth of a tree, increase this value will make the model more complex / likely to be overfitting.
197223
/// 0 indicates no limit, limit is required for depth-wise grow policy.
@@ -208,7 +234,7 @@ pub struct TreeBoosterParameters {
208234
///
209235
/// * range: [0,∞]
210236
/// * default: 1
211-
min_child_weight: u32,
237+
min_child_weight: f32,
212238

213239
/// Maximum delta step we allow each tree’s weight estimation to be.
214240
/// If the value is set to 0, it means there is no constraint. If it is set to a positive value,
@@ -218,7 +244,7 @@ pub struct TreeBoosterParameters {
218244
///
219245
/// * range: [0,∞]
220246
/// * default: 0
221-
max_delta_step: u32,
247+
max_delta_step: f32,
222248

223249
/// Subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly collected half
224250
/// of the data instances to grow trees and this will prevent overfitting.
@@ -239,15 +265,21 @@ pub struct TreeBoosterParameters {
239265
/// * default: 1.0
240266
colsample_bylevel: f32,
241267

268+
/// Subsample ratio of columns for each node.
269+
///
270+
/// * range: (0.0, 1.0]
271+
/// * default: 1.0
272+
colsample_bynode: f32,
273+
242274
/// L2 regularization term on weights, increase this value will make model more conservative.
243275
///
244276
/// * default: 1
245-
lambda: u32,
277+
lambda: f32,
246278

247279
/// L1 regularization term on weights, increase this value will make model more conservative.
248280
///
249281
/// * default: 0
250-
alpha: u32,
282+
alpha: f32,
251283

252284
/// The tree construction algorithm used in XGBoost.
253285
#[builder(default = "TreeMethod::default()")]
@@ -270,7 +302,7 @@ pub struct TreeBoosterParameters {
270302

271303
/// Sequence of tree updaters to run, providing a modular way to construct and to modify the trees.
272304
///
273-
/// * default: [TreeUpdater::GrowColMaker, TreeUpdater::Prune]
305+
/// * default: vec![]
274306
updater: Vec<TreeUpdater>,
275307

276308
/// This is a parameter of the ‘refresh’ updater plugin. When this flag is true, tree leafs as well as tree nodes'
@@ -300,6 +332,11 @@ pub struct TreeBoosterParameters {
300332
/// * default: 256
301333
max_bin: u32,
302334

335+
/// Number of trees to train in parallel for boosted random forest.
336+
///
337+
/// * default: 1
338+
num_parallel_tree: u32,
339+
303340
/// The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
304341
///
305342
/// * default: [`Predictor::Cpu`](enum.Predictor.html#variant.Cpu)
@@ -310,24 +347,26 @@ impl Default for TreeBoosterParameters {
310347
fn default() -> Self {
311348
TreeBoosterParameters {
312349
eta: 0.3,
313-
gamma: 0,
350+
gamma: 0.0,
314351
max_depth: 6,
315-
min_child_weight: 1,
316-
max_delta_step: 0,
352+
min_child_weight: 1.0,
353+
max_delta_step: 0.0,
317354
subsample: 1.0,
318355
colsample_bytree: 1.0,
319356
colsample_bylevel: 1.0,
320-
lambda: 1,
321-
alpha: 0,
357+
colsample_bynode: 1.0,
358+
lambda: 1.0,
359+
alpha: 0.0,
322360
tree_method: TreeMethod::default(),
323361
sketch_eps: 0.03,
324362
scale_pos_weight: 1.0,
325-
updater: vec![TreeUpdater::GrowColMaker, TreeUpdater::Prune],
363+
updater: Vec::new(),
326364
refresh_leaf: true,
327365
process_type: ProcessType::default(),
328366
grow_policy: GrowPolicy::default(),
329367
max_leaves: 0,
330368
max_bin: 256,
369+
num_parallel_tree: 1,
331370
predictor: Predictor::default(),
332371
}
333372
}
@@ -347,19 +386,29 @@ impl TreeBoosterParameters {
347386
v.push(("subsample".to_owned(), self.subsample.to_string()));
348387
v.push(("colsample_bytree".to_owned(), self.colsample_bytree.to_string()));
349388
v.push(("colsample_bylevel".to_owned(), self.colsample_bylevel.to_string()));
389+
v.push(("colsample_bynode".to_owned(), self.colsample_bynode.to_string()));
350390
v.push(("lambda".to_owned(), self.lambda.to_string()));
351391
v.push(("alpha".to_owned(), self.alpha.to_string()));
352392
v.push(("tree_method".to_owned(), self.tree_method.to_string()));
353393
v.push(("sketch_eps".to_owned(), self.sketch_eps.to_string()));
354394
v.push(("scale_pos_weight".to_owned(), self.scale_pos_weight.to_string()));
355-
v.push(("updater".to_owned(), self.updater.iter().map(|u| u.to_string()).collect::<Vec<String>>().join(",")));
356395
v.push(("refresh_leaf".to_owned(), (self.refresh_leaf as u8).to_string()));
357396
v.push(("process_type".to_owned(), self.process_type.to_string()));
358397
v.push(("grow_policy".to_owned(), self.grow_policy.to_string()));
359398
v.push(("max_leaves".to_owned(), self.max_leaves.to_string()));
360399
v.push(("max_bin".to_owned(), self.max_bin.to_string()));
400+
v.push(("num_parallel_tree".to_owned(), self.num_parallel_tree.to_string()));
361401
v.push(("predictor".to_owned(), self.predictor.to_string()));
362402

403+
// Don't pass anything to XGBoost if the user didn't specify anything.
404+
// This allows XGBoost to figure it out on it's own, and suppresses the
405+
// warning message during training.
406+
// See: https://github.com/davechallis/rust-xgboost/issues/7
407+
if self.updater.len() != 0
408+
{
409+
v.push(("updater".to_owned(), self.updater.iter().map(|u| u.to_string()).collect::<Vec<String>>().join(",")));
410+
}
411+
363412
v
364413
}
365414
}
@@ -370,6 +419,7 @@ impl TreeBoosterParametersBuilder {
370419
Interval::new_open_closed(0.0, 1.0).validate(&self.subsample, "subsample")?;
371420
Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bytree, "colsample_bytree")?;
372421
Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bylevel, "colsample_bylevel")?;
422+
Interval::new_open_closed(0.0, 1.0).validate(&self.colsample_bynode, "colsample_bynode")?;
373423
Interval::new_open_open(0.0, 1.0).validate(&self.sketch_eps, "sketch_eps")?;
374424
Ok(())
375425
}

0 commit comments

Comments
 (0)