[Feature] Add Random Forest Classifier to linfa-trees #389

@maxprogrammer007

📝 Description

I would like to contribute a new module to the linfa-trees crate that implements the Random Forest algorithm for classification tasks. This will expand linfa-trees from single decision trees into ensemble learning, aligning closely with scikit-learn's functionality in Python.


🚀 Motivation

Random Forests are a powerful ensemble learning method used widely in classification tasks. They provide:

  • Robustness to overfitting

  • Better generalization than single trees

  • Feature importance estimates

Currently, linfa-trees provides support for single decision trees. By adding Random Forests, we unlock ensemble learning for the Rust ML ecosystem.


📐 Proposed Design

🔹 New Module

A new file will be added:

linfa-trees/src/decision_trees/random_forest.rs

This will include:

  • RandomForestClassifier<F: Float>

  • RandomForestParams<F> (unchecked)

  • RandomForestValidParams<F> (checked)
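
The unchecked/checked split listed above can be sketched without any dependencies. This is only an illustration of the pattern, not the proposed implementation: the field names are taken from the API preview in this issue, `f64` stands in for the generic `F: Float`, the default values and the `ParamError` variants are hypothetical, and the inherent `check` method is a stand-in for linfa's `ParamGuard` trait.

```rust
use std::fmt;

/// Checked hyperparameters: only obtainable through `RandomForestParams::check`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomForestValidParams {
    n_trees: usize,
    feature_subsample: f64,
    max_depth: Option<usize>,
}

/// Unchecked builder that wraps the checked struct.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomForestParams(RandomForestValidParams);

/// Hypothetical error type for this sketch.
#[derive(Debug, PartialEq)]
pub enum ParamError {
    ZeroTrees,
    InvalidFeatureSubsample(f64),
    ZeroMaxDepth,
}

impl fmt::Display for ParamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ParamError::ZeroTrees => write!(f, "n_trees must be at least 1"),
            ParamError::InvalidFeatureSubsample(r) => {
                write!(f, "feature_subsample must be in (0, 1], got {}", r)
            }
            ParamError::ZeroMaxDepth => write!(f, "max_depth must be at least 1 when set"),
        }
    }
}

impl std::error::Error for ParamError {}

impl RandomForestParams {
    /// Assumed defaults; the real defaults are up to the maintainers.
    pub fn new() -> Self {
        Self(RandomForestValidParams {
            n_trees: 100,
            feature_subsample: 1.0,
            max_depth: None,
        })
    }

    pub fn n_trees(mut self, n: usize) -> Self {
        self.0.n_trees = n;
        self
    }

    pub fn feature_subsample(mut self, ratio: f64) -> Self {
        self.0.feature_subsample = ratio;
        self
    }

    pub fn max_depth(mut self, depth: Option<usize>) -> Self {
        self.0.max_depth = depth;
        self
    }

    /// Validate and unwrap into the checked parameter set.
    pub fn check(self) -> Result<RandomForestValidParams, ParamError> {
        let p = &self.0;
        if p.n_trees == 0 {
            return Err(ParamError::ZeroTrees);
        }
        // Written as a negated range test so that NaN is rejected too.
        if !(p.feature_subsample > 0.0 && p.feature_subsample <= 1.0) {
            return Err(ParamError::InvalidFeatureSubsample(p.feature_subsample));
        }
        if p.max_depth == Some(0) {
            return Err(ParamError::ZeroMaxDepth);
        }
        Ok(self.0)
    }
}
```

With this shape, `RandomForestParams::new().n_trees(0).check()` fails early, so a fitting routine only ever sees a `RandomForestValidParams`.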

🔹 Trait Implementations

I will implement the following traits according to linfa conventions:

  • ParamGuard for parameter validation

  • Fit to train the forest using bootstrapped data and random feature subsetting

  • PredictInplace and Predict to perform inference via majority voting
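
The three ingredients named above (bootstrapped rows, random feature subsets, and majority voting) might look roughly like this. It is a dependency-free sketch under assumptions: the function names are hypothetical, class labels are plain `usize`, and the small xorshift generator only exists to keep the example self-contained (an actual implementation would more likely use the `rand` crate and `ndarray` views).

```rust
use std::collections::HashMap;

/// Minimal xorshift64 generator, used here only to avoid external crates.
pub struct XorShift64(u64);

impl XorShift64 {
    pub fn new(seed: u64) -> Self {
        // Xorshift state must be non-zero.
        Self(if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed })
    }

    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Index in `0..n` (slight modulo bias, acceptable for a sketch). Requires n > 0.
    fn below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}

/// Bootstrap sample: draw `n_samples` row indices with replacement.
pub fn bootstrap_indices(n_samples: usize, rng: &mut XorShift64) -> Vec<usize> {
    (0..n_samples).map(|_| rng.below(n_samples)).collect()
}

/// Random feature subset: choose ceil(ratio * n_features) distinct columns
/// (at least one) with a partial Fisher-Yates shuffle, returned sorted.
pub fn feature_subset(n_features: usize, ratio: f64, rng: &mut XorShift64) -> Vec<usize> {
    if n_features == 0 {
        return Vec::new();
    }
    let k = ((ratio * n_features as f64).ceil() as usize).clamp(1, n_features);
    let mut cols: Vec<usize> = (0..n_features).collect();
    for i in 0..k {
        let j = i + rng.below(n_features - i);
        cols.swap(i, j);
    }
    cols.truncate(k);
    cols.sort_unstable();
    cols
}

/// Majority vote over the per-tree predictions for one sample.
/// Ties resolve to the smallest label; an empty slice yields `None`.
pub fn majority_vote(votes: &[usize]) -> Option<usize> {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &v in votes {
        *counts.entry(v).or_insert(0) += 1;
    }
    counts
        .into_iter()
        .max_by(|a, b| a.1.cmp(&b.1).then(b.0.cmp(&a.0)))
        .map(|(label, _)| label)
}
```

In a `Fit` implementation, each tree would be trained on the rows from `bootstrap_indices` restricted to the columns from `feature_subset`; prediction would then call `majority_vote` once per sample. Whether the feature subset is drawn per tree or per split is a design choice this sketch leaves open.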

🔹 Example

An example will be added in:

linfa-trees/examples/iris_random_forest.rs

Using the Iris dataset from linfa-datasets.

🔹 Benchmark (Optional)

If approved, I can also add a benchmark using Criterion:

linfa-trees/benches/random_forest.rs

📁 File Integration Plan

  • src/lib.rs: Re-export random_forest::*

  • src/decision_trees/mod.rs: pub mod random_forest;

  • README.md: Update with a section on Random Forests and example usage

  • examples/iris_random_forest.rs: Demonstrates training and evaluation


📦 API Preview

let model = RandomForestClassifier::params()
    .n_trees(100)
    .feature_subsample(0.8)
    .max_depth(Some(10))
    .fit(&dataset)?;

let predictions = model.predict(&dataset);
let acc = predictions.confusion_matrix(&dataset)?.accuracy();
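
For readers unfamiliar with the last line of the preview, the accuracy read off a confusion matrix is the sum of the diagonal divided by the total count. The following is a plain-Rust stand-in for that computation, not linfa's own `ConfusionMatrix` type; the function names and the `usize` label encoding are assumptions for illustration.

```rust
/// Build a confusion matrix where row = true class and column = predicted class.
/// Labels are assumed to be encoded as `0..n_classes`.
pub fn confusion_matrix(truth: &[usize], predicted: &[usize], n_classes: usize) -> Vec<Vec<usize>> {
    assert_eq!(truth.len(), predicted.len(), "label slices must have equal length");
    let mut m = vec![vec![0usize; n_classes]; n_classes];
    for (&t, &p) in truth.iter().zip(predicted) {
        m[t][p] += 1;
    }
    m
}

/// Accuracy = correctly classified samples (diagonal) / all samples.
/// Returns 0.0 for an empty matrix to avoid dividing by zero.
pub fn accuracy(matrix: &[Vec<usize>]) -> f64 {
    let total: usize = matrix.iter().flatten().sum();
    if total == 0 {
        return 0.0;
    }
    let correct: usize = (0..matrix.len()).map(|i| matrix[i][i]).sum();
    correct as f64 / total as f64
}
```

For example, with true labels `[0, 0, 1, 1, 2, 2]` and predictions `[0, 1, 1, 1, 2, 0]`, four of six samples land on the diagonal, so the accuracy is 4/6.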


✅ Conformity with CONTRIBUTING.md

  • Uses Float trait for f32/f64 compatibility

  • Follows the Params/ValidParams validation pattern (unchecked and checked parameter sets)

  • Implements Fit, Predict, and PredictInplace using Dataset

  • Optional serde support via feature flag

  • Will include unit tests and optionally benchmarks



🙋‍♂️ Request

Please let me know if you're open to this contribution. I’d be happy to align with maintainers on:

  • Feature scope (classifier first, regressor later?)

  • Benchmarking standards

  • Integration strategy (e.g., reuse of DecisionTree)

Looking forward to your guidance!
