📝 Description
I would like to contribute a new module to the `linfa-trees` crate that implements the Random Forest algorithm for classification tasks. This will expand `linfa-trees` from single decision trees into ensemble learning, aligning it closely with scikit-learn's functionality in Python.
🚀 Motivation
Random Forests are a powerful ensemble learning method used widely in classification tasks. They provide:
- Robustness to overfitting
- Better generalization than single trees
- Feature importance estimates

Currently, `linfa-trees` provides support for single decision trees. By adding Random Forests, we unlock ensemble learning for the Rust ML ecosystem.
📐 Proposed Design
🔹 New Module
A new file will be added:
```
linfa-trees/src/decision_trees/random_forest.rs
```
This will include (a rough skeleton is sketched below):

- `RandomForestClassifier<F: Float>`
- `RandomForestParams<F>` (unchecked)
- `RandomForestValidParams<F>` (checked)
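For concreteness, here is a rough skeleton of how those three types could hang together, following the `Params`-wraps-`ValidParams` pattern used elsewhere in `linfa`; the field names are my assumptions, not a final design:

```rust
use std::marker::PhantomData;

use linfa::Float;
use linfa_trees::DecisionTree; // inside the crate this would be `crate::DecisionTree`

/// Validated hyperparameters (assumed field set, not final).
pub struct RandomForestValidParams<F: Float> {
    n_trees: usize,
    feature_subsample: f32,
    max_depth: Option<usize>,
    marker: PhantomData<F>,
}

/// Unchecked hyperparameters: a thin wrapper so `ParamGuard::check`
/// can hand out the validated form.
pub struct RandomForestParams<F: Float>(RandomForestValidParams<F>);

/// Fitted model: an ensemble of single decision trees, each trained on a
/// bootstrap sample with a random feature subset.
pub struct RandomForestClassifier<F: Float> {
    trees: Vec<DecisionTree<F, usize>>,
}
```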
🔹 Trait Implementations
I will implement the following traits according to `linfa` conventions (see the voting sketch after the list):

- `ParamGuard` for parameter validation
- `Fit` to train the forest using bootstrapped data and random feature subsetting
- `PredictInplace` and `Predict` to perform inference via majority voting
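To illustrate the inference path, a minimal majority-voting helper could look like the following. This is a pure-`std` sketch only; the real `PredictInplace` implementation would collect each tree's prediction per sample and then apply something like this per row:

```rust
use std::collections::HashMap;

/// Majority vote over the class labels predicted by each tree for a
/// single sample. Ties are broken arbitrarily here; a real
/// implementation may want a deterministic rule (e.g. lowest label wins).
fn majority_vote(votes: &[usize]) -> usize {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &label in votes {
        *counts.entry(label).or_insert(0) += 1;
    }
    counts
        .into_iter()
        .max_by_key(|&(_, count)| count)
        .map(|(label, _)| label)
        .expect("every sample receives at least one vote")
}
```

For example, `majority_vote(&[0, 2, 2, 1, 2])` returns `2`.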
🔹 Example
An example will be added in:

```
linfa-trees/examples/iris_random_forest.rs
```

using the Iris dataset from `linfa-datasets`.
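As a rough sketch of what that example might contain (only the dataset loading is existing `linfa-datasets` code, assuming its `iris` feature is enabled; the forest API is hypothetical and therefore commented out):

```rust
use linfa::prelude::*;

fn main() {
    // 80/20 train/validation split of the Iris dataset.
    let (train, valid) = linfa_datasets::iris().split_with_ratio(0.8);

    // Hypothetical calls mirroring the API preview further below:
    // let model = RandomForestClassifier::params().n_trees(100).fit(&train)?;
    // let pred = model.predict(&valid);
    // println!("accuracy: {}", pred.confusion_matrix(&valid)?.accuracy());
    let _ = (train, valid); // placeholder until the forest API lands
}
```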
🔹 Benchmark (Optional)
If approved, I can also add a benchmark using Criterion:
```
linfa-trees/benches/random_forest.rs
```
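A possible skeleton for that benchmark, assuming `criterion` and `linfa-datasets` (with the `iris` feature) as dev-dependencies; the forest call itself is commented out because the API is not final:

```rust
use criterion::{criterion_group, criterion_main, Criterion};

fn bench_random_forest(c: &mut Criterion) {
    let dataset = linfa_datasets::iris();

    c.bench_function("random_forest_fit_iris", |b| {
        b.iter(|| {
            // Hypothetical: RandomForestClassifier::params().n_trees(100).fit(&dataset).unwrap()
            criterion::black_box(&dataset)
        })
    });
}

criterion_group!(benches, bench_random_forest);
criterion_main!(benches);
```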
📁 File Integration Plan
- `src/lib.rs`: re-export `random_forest::*`
- `src/decision_trees/mod.rs`: add `pub mod random_forest;` (see the snippet below)
- `README.md`: update with a section on Random Forests and example usage
- `examples/iris_random_forest.rs`: demonstrates training and evaluation
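Concretely, the module wiring would be on the order of the following; the exact re-export path is an assumption and depends on how `decision_trees` is currently exposed from `lib.rs`:

```rust
// linfa-trees/src/decision_trees/mod.rs
pub mod random_forest;

// linfa-trees/src/lib.rs (or via the existing decision_trees re-export)
pub use crate::decision_trees::random_forest::*;
```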
📦 API Preview
```rust
let model = RandomForest::params()
    .n_trees(100)
    .feature_subsample(0.8)
    .max_depth(Some(10))
    .fit(&dataset)?;

let predictions = model.predict(&dataset);
let acc = predictions.confusion_matrix(&dataset)?.accuracy();
```
✅ Conformity with CONTRIBUTING.md
- Uses the `Float` trait for `f32`/`f64` compatibility
- Follows the `Params` → `ValidParams` validation pattern
- Implements `Fit`, `Predict`, and `PredictInplace` using `Dataset`
- Optional `serde` support via feature flag (sketched below)
- Will include unit tests and optionally benchmarks
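For the optional `serde` support I would follow the usual feature-gating pattern. This is a sketch only: other `linfa` crates rename the dependency (e.g. `serde_crate`), so the exact attribute may differ:

```rust
// Gate (de)serialization behind a `serde` feature flag.
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct RandomForestClassifier<F: linfa::Float> {
    // Assumes the inner `DecisionTree` is serializable under the
    // corresponding feature of linfa-trees.
    trees: Vec<linfa_trees::DecisionTree<F, usize>>,
}
```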
🙋♂️ Request
Please let me know if you're open to this contribution. I’d be happy to align with maintainers on:
- Feature scope (classifier first, regressor later?)
- Benchmarking standards
- Integration strategy (e.g., reuse of `DecisionTree`)
Looking forward to your guidance!