From 8701fa14b6dd35b4f16b995fe66ad177376e7e0b Mon Sep 17 00:00:00 2001 From: ladezai <83292183+ladezai@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:57:35 +0100 Subject: [PATCH] :scroll: Improve linfa_ensemble documentation --- algorithms/linfa-ensemble/src/algorithm.rs | 55 ++++++++++++++++++++ algorithms/linfa-ensemble/src/hyperparams.rs | 15 +++++- algorithms/linfa-ensemble/src/lib.rs | 14 +++-- 3 files changed, 78 insertions(+), 6 deletions(-) diff --git a/algorithms/linfa-ensemble/src/algorithm.rs b/algorithms/linfa-ensemble/src/algorithm.rs index 407bb190f..e2cf531b8 100644 --- a/algorithms/linfa-ensemble/src/algorithm.rs +++ b/algorithms/linfa-ensemble/src/algorithm.rs @@ -10,8 +10,63 @@ use ndarray::{Array2, Axis, Zip}; use rand::Rng; use std::{cmp::Eq, collections::HashMap, hash::Hash}; +/// A fitted ensemble of [Decision Trees](DecisionTree) trained on a random subset of features. +/// +/// Check out the [EnsembleLearner] documentation for more information regarding the [RandomForest] interface. pub type RandomForest = EnsembleLearner>; +/// A fitted ensemble of learners for classification. +/// +/// ## Structure +/// +/// An Ensemble Learner is composed of a collection of fitted models of type `M`. +/// +/// ## Fitting Algorithm +/// +/// Given a [DatasetBase](DatasetBase) denoted as `D`, +/// 1. Create as many distinct bootstrapped subsets of the original dataset `D` as the number of +/// distinct models to fit. +/// 2. Fit each distinct model on a distinct bootstrapped subset of `D`. +/// +/// Note that the subset size, as well as the subset of features to use in each training subset, can +/// be specified in the [parameters](crate::EnsembleLearnerParams). +/// +/// ## Prediction Algorithm +/// +/// The prediction result is the result of majority voting across the fitted learners. +/// +/// ## Example +/// +/// This example shows how to train a bagging model using 100 decision trees, +/// each trained on 70% of the training data (bootstrap sampling). 
+/// ```no_run
/// use linfa::prelude::{Fit, Predict};
/// use linfa_ensemble::EnsembleLearnerParams;
/// use linfa_trees::DecisionTree;
/// use ndarray_rand::rand::SeedableRng;
/// use rand::rngs::SmallRng;
///
/// // Load Iris dataset
/// let mut rng = SmallRng::seed_from_u64(42);
/// let (train, test) = linfa_datasets::iris()
/// .shuffle(&mut rng)
/// .split_with_ratio(0.8);
///
/// // Train the model on the iris dataset
/// let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
/// .ensemble_size(100) // Number of Decision Trees to fit
/// .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
/// .fit(&train)
/// .unwrap();
///
/// // Make predictions on the test set
/// let predictions = bagging_model.predict(&test);
/// ```
///
/// ## References
///
/// * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
/// * [An Introduction to Statistical Learning](https://www.statlearning.com/) pub struct EnsembleLearner { pub models: Vec, pub model_features: Vec>, diff --git a/algorithms/linfa-ensemble/src/hyperparams.rs b/algorithms/linfa-ensemble/src/hyperparams.rs index 6652180fa..df457e8d6 100644 --- a/algorithms/linfa-ensemble/src/hyperparams.rs +++ b/algorithms/linfa-ensemble/src/hyperparams.rs @@ -6,22 +6,26 @@ use linfa_trees::DecisionTreeParams; use rand::rngs::ThreadRng; use rand::Rng; +/// The set of valid hyper-parameters that can be specified for the fitting procedure of the +/// [Ensemble Learner](crate::EnsembleLearner). 
#[derive(Clone, Copy, Debug, PartialEq)] pub struct EnsembleLearnerValidParams { /// The number of models in the ensemble pub ensemble_size: usize, /// The proportion of the total number of training samples that should be given to each model for training pub bootstrap_proportion: f64, - /// The proportion of the total number of training feature that should be given to each model for training + /// The proportion of the total number of training features that should be given to each model for training pub feature_proportion: f64, /// The model parameters for the base model pub model_params: P, pub rng: R, } +/// A helper struct for building a set of [Ensemble Learner](crate::EnsembleLearner) hyper-parameters. #[derive(Clone, Copy, Debug, PartialEq)] pub struct EnsembleLearnerParams(EnsembleLearnerValidParams); +/// A helper struct for building a set of [Random Forest](crate::RandomForest) hyper-parameters. pub type RandomForestParams = EnsembleLearnerParams, R>; impl

EnsembleLearnerParams { @@ -41,16 +45,25 @@ impl EnsembleLearnerParams { }) } + /// Specifies the number of models to fit in the ensemble. pub fn ensemble_size(mut self, size: usize) -> Self { self.0.ensemble_size = size; self } + /// Sets the proportion of the total number of training samples that should be given to each model for training + /// + /// Note that the `proportion` should be in the interval (0, 1] in order to pass the + /// parameter validation check. pub fn bootstrap_proportion(mut self, proportion: f64) -> Self { self.0.bootstrap_proportion = proportion; self } + /// Sets the proportion of the total number of training features that should be given to each model for training + /// + /// Note that the `proportion` should be in the interval (0, 1] in order to pass the + /// parameter validation check. pub fn feature_proportion(mut self, proportion: f64) -> Self { self.0.feature_proportion = proportion; self diff --git a/algorithms/linfa-ensemble/src/lib.rs b/algorithms/linfa-ensemble/src/lib.rs index 67d59da35..26a94f616 100644 --- a/algorithms/linfa-ensemble/src/lib.rs +++ b/algorithms/linfa-ensemble/src/lib.rs @@ -3,14 +3,18 @@ //! Ensemble methods combine the predictions of several base estimators built with a given //! learning algorithm in order to improve generalizability / robustness over a single estimator. //! +//! This crate (`linfa-ensemble`) provides pure Rust implementations of popular ensemble techniques, such as +//! * [Bootstrap Aggregation](EnsembleLearner) +//! * [Random Forest](RandomForest) +//! //! ## Bootstrap Aggregation (aka Bagging) //! //! A typical example of ensemble method is Bootstrap Aggregation, which combines the predictions of -//! several decision trees (see `linfa-trees`) trained on different samples subset of the training dataset. +//! several decision trees (see [`linfa-trees`](linfa_trees)) trained on different sample subsets of the training dataset. //! //! ## Random Forest //! -//! 
A special case of Bootstrap Aggregation using decision trees (see `linfa-trees`) with random feature +//! A special case of Bootstrap Aggregation using decision trees (see [`linfa-trees`](linfa_trees)) with random feature //! selection. A typical number of random prediction to be selected is $\sqrt{p}$ with $p$ being //! the number of available features. //! @@ -48,7 +52,7 @@ //! let predictions = bagging_model.predict(&test); //! ``` //! -//! This example shows how to train a Random Forest model using 100 decision trees, +//! This example shows how to train a [Random Forest](RandomForest) model using 100 decision trees, //! each trained on 70% of the training data (bootstrap sampling) and using only //! 30% of the available features. //! @@ -66,7 +70,7 @@ //! .split_with_ratio(0.8); //! //! // Train the model on the iris dataset -//! let bagging_model = RandomForestParams::new(DecisionTree::params()) +//! let random_forest = RandomForestParams::new(DecisionTree::params()) //! .ensemble_size(100) // Number of Decision Tree to fit //! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap //! .feature_proportion(0.3) // Select only 30% of the feature @@ -74,7 +78,7 @@ //! .unwrap(); //! //! // Make predictions on the test set -//! let predictions = bagging_model.predict(&test); +//! let predictions = random_forest.predict(&test); //! ``` mod algorithm;