55 changes: 55 additions & 0 deletions algorithms/linfa-ensemble/src/algorithm.rs
@@ -10,8 +10,63 @@ use ndarray::{Array2, Axis, Zip};
use rand::Rng;
use std::{cmp::Eq, collections::HashMap, hash::Hash};

/// A fitted ensemble of [Decision Trees](DecisionTree), each trained on a random subset of features.
///
/// Check out the [EnsembleLearner] documentation for more information on the [RandomForest] interface.
pub type RandomForest<F, L> = EnsembleLearner<DecisionTree<F, L>>;

/// A fitted ensemble of learners for classification.
///
/// ## Structure
///
/// An Ensemble Learner is composed of a collection of fitted models of type `M`.
///
/// ## Fitting Algorithm
///
/// Given a [DatasetBase](DatasetBase) denoted as `D`:
/// 1. Create as many distinct bootstrapped subsets of the original dataset `D` as there are
///    distinct models to fit.
/// 2. Fit each distinct model on its own bootstrapped subset of `D`.
///
/// Note that the subset size, as well as the subset of features to use for each training subset,
/// can be specified in the [parameters](crate::EnsembleLearnerParams); a minimal sketch of the
/// bootstrap sampling step is shown below.
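///
/// The sketch below draws row indices with replacement for one training subset. The helper is
/// purely illustrative and not part of this crate's API; it only assumes the `rand` crate's
/// `Rng::gen_range`.
/// ```
/// use rand::Rng;
///
/// // Draw `subset_size` row indices from `0..n_samples`, sampling with replacement.
/// fn bootstrap_indices<R: Rng>(rng: &mut R, n_samples: usize, subset_size: usize) -> Vec<usize> {
///     (0..subset_size).map(|_| rng.gen_range(0..n_samples)).collect()
/// }
///
/// let mut rng = rand::thread_rng();
/// // e.g. a bootstrap proportion of 0.7 on 150 samples keeps 105 (possibly repeated) rows
/// let indices = bootstrap_indices(&mut rng, 150, 105);
/// assert_eq!(indices.len(), 105);
/// ```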
///
/// ## Prediction Algorithm
///
/// The prediction for each sample is obtained by majority voting across the fitted learners.
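///
/// As an illustration only (the crate performs this aggregation internally when you call
/// `predict`), majority voting for a single sample could look like:
/// ```
/// use std::collections::HashMap;
///
/// // Return the label predicted by the largest number of models for one sample.
/// fn majority_vote(votes: &[usize]) -> usize {
///     let mut counts: HashMap<usize, usize> = HashMap::new();
///     for &label in votes {
///         *counts.entry(label).or_insert(0) += 1;
///     }
///     counts
///         .into_iter()
///         .max_by_key(|&(_, count)| count)
///         .map(|(label, _)| label)
///         .unwrap()
/// }
///
/// assert_eq!(majority_vote(&[0, 2, 2, 1, 2]), 2);
/// ```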
///
/// ## Example
///
/// This example shows how to train a bagging model using 100 decision trees,
/// each trained on 70% of the training data (bootstrap sampling).
/// ```no_run
/// use linfa::prelude::{Fit, Predict};
/// use linfa_ensemble::EnsembleLearnerParams;
/// use linfa_trees::DecisionTree;
/// use ndarray_rand::rand::SeedableRng;
/// use rand::rngs::SmallRng;
///
/// // Load the Iris dataset and split it into train and test sets
/// let mut rng = SmallRng::seed_from_u64(42);
/// let (train, test) = linfa_datasets::iris()
///     .shuffle(&mut rng)
///     .split_with_ratio(0.8);
///
/// // Train the bagging model on the training set
/// let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
///     .ensemble_size(100) // Number of Decision Trees to fit
///     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
///     .fit(&train)
///     .unwrap();
///
/// // Make predictions on the test set
/// let predictions = bagging_model.predict(&test);
/// ```
///
/// ## References
///
/// * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
/// * [An Introduction to Statistical Learning](https://www.statlearning.com/)
pub struct EnsembleLearner<M> {
pub models: Vec<M>,
pub model_features: Vec<Vec<usize>>,
15 changes: 14 additions & 1 deletion algorithms/linfa-ensemble/src/hyperparams.rs
@@ -6,22 +6,26 @@ use linfa_trees::DecisionTreeParams;
use rand::rngs::ThreadRng;
use rand::Rng;

/// The set of valid hyper-parameters that can be specified for the fitting procedure of the
/// [Ensemble Learner](crate::EnsembleLearner).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
/// The number of models in the ensemble
pub ensemble_size: usize,
/// The proportion of the total number of training samples that should be given to each model for training
pub bootstrap_proportion: f64,
/// The proportion of the total number of training feature that should be given to each model for training
/// The proportion of the total number of training features that should be given to each model for training
pub feature_proportion: f64,
/// The model parameters for the base model
pub model_params: P,
pub rng: R,
}

/// A helper struct for building a set of [Ensemble Learner](crate::EnsembleLearner) hyper-parameters.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);

/// A helper struct for building a set of [Random Forest](crate::RandomForest) hyper-parameters.
pub type RandomForestParams<F, L, R> = EnsembleLearnerParams<DecisionTreeParams<F, L>, R>;

impl<P> EnsembleLearnerParams<P, ThreadRng> {
@@ -41,16 +45,25 @@ impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
})
}

/// Specifies the number of models to fit in the ensemble.
pub fn ensemble_size(mut self, size: usize) -> Self {
self.0.ensemble_size = size;
self
}

/// Sets the proportion of the total number of training samples that should be given to each model for training.
///
/// Note that the `proportion` should be in the interval (0, 1] in order to pass the
/// parameter validation check.
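///
/// A minimal sketch of how an out-of-range value would surface, assuming this builder
/// implements linfa's `ParamGuard` so that validation runs when the parameters are checked
/// (or when `fit` is called):
/// ```no_run
/// use linfa::ParamGuard;
/// use linfa_ensemble::EnsembleLearnerParams;
/// use linfa_trees::DecisionTree;
///
/// let params = EnsembleLearnerParams::new(DecisionTree::<f64, usize>::params())
///     .bootstrap_proportion(1.5); // outside (0, 1], so validation should fail
/// assert!(params.check_ref().is_err());
/// ```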
pub fn bootstrap_proportion(mut self, proportion: f64) -> Self {
self.0.bootstrap_proportion = proportion;
self
}

/// Sets the proportion of the total number of training features that should be given to each model for training.
///
/// Note that the `proportion` should be in the interval (0, 1] in order to pass the
/// parameter validation check.
pub fn feature_proportion(mut self, proportion: f64) -> Self {
self.0.feature_proportion = proportion;
self
14 changes: 9 additions & 5 deletions algorithms/linfa-ensemble/src/lib.rs
@@ -3,14 +3,18 @@
//! Ensemble methods combine the predictions of several base estimators built with a given
//! learning algorithm in order to improve generalizability / robustness over a single estimator.
//!
//! This crate (`linfa-ensemble`) provides pure Rust implementations of popular ensemble techniques, such as:
//! * [Bootstrap Aggregation](EnsembleLearner)
//! * [Random Forest](RandomForest)
//!
//! ## Bootstrap Aggregation (aka Bagging)
//!
//! A typical example of an ensemble method is Bootstrap Aggregation, which combines the predictions of
//! several decision trees (see `linfa-trees`) trained on different samples subset of the training dataset.
//! several decision trees (see [`linfa-trees`](linfa_trees)) trained on different sample subsets of the training dataset.
//!
//! ## Random Forest
//!
//! A special case of Bootstrap Aggregation using decision trees (see `linfa-trees`) with random feature
//! A special case of Bootstrap Aggregation using decision trees (see [`linfa-trees`](linfa_trees)) with random feature
//! selection. A typical number of features to select at random is $\sqrt{p}$, with $p$ being
//! the number of available features.
//!
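//! For example, a square-root rule on the 4 Iris features keeps $\sqrt{4} = 2$ of them,
//! i.e. a feature proportion of `2 / 4 = 0.5`. A minimal, purely illustrative way to derive
//! that proportion (the helper below is not part of this crate):
//! ```
//! // Feature proportion corresponding to the sqrt(p) rule.
//! fn sqrt_feature_proportion(n_features: usize) -> f64 {
//!     (n_features as f64).sqrt() / n_features as f64
//! }
//!
//! assert!((sqrt_feature_proportion(4) - 0.5).abs() < 1e-12);
//! ```
//!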
@@ -48,7 +52,7 @@
//! let predictions = bagging_model.predict(&test);
//! ```
//!
//! This example shows how to train a Random Forest model using 100 decision trees,
//! This example shows how to train a [Random Forest](RandomForest) model using 100 decision trees,
//! each trained on 70% of the training data (bootstrap sampling) and using only
//! 30% of the available features.
//!
@@ -66,15 +70,15 @@
//!     .split_with_ratio(0.8);
//!
//! // Train the random forest on the training set
//! let bagging_model = RandomForestParams::new(DecisionTree::params())
//! let random_forest = RandomForestParams::new(DecisionTree::params())
//!     .ensemble_size(100) // Number of Decision Trees to fit
//!     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
//!     .feature_proportion(0.3) // Select only 30% of the features
//!     .fit(&train)
//!     .unwrap();
//!
//! // Make predictions on the test set
//! let predictions = bagging_model.predict(&test);
//! let predictions = random_forest.predict(&test);
//! ```

mod algorithm;