Skip to content

Commit fc61f47

Browse files
committed
📜 Add docs and example for RandomForest type alias.
1 parent 359e015 commit fc61f47

1 file changed

Lines changed: 39 additions & 3 deletions

File tree

  • algorithms/linfa-ensemble/src

algorithms/linfa-ensemble/src/lib.rs

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@
55
//!
66
//! ## Bootstrap Aggregation (aka Bagging)
77
//!
8-
//! A typical example of ensemble method is Bootstrapo AGgregation, which combines the predictions of
8+
//! A typical example of ensemble method is Bootstrap Aggregation, which combines the predictions of
99
//! several decision trees (see `linfa-trees`) trained on different samples subset of the training dataset.
1010
//!
11+
//! ## Random Forest
12+
//!
13+
//! A special case of Bootstrap Aggregation using decision trees (see `linfa-trees`) with random feature
14+
//! selection. A typical number of random prediction to be selected is $\sqrt{p}$ with $p$ being
15+
//! the number of available features.
16+
//!
1117
//! ## Reference
1218
//!
1319
//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
20+
//! * [An Introduction to Statistical Learning](https://www.statlearning.com/)
1421
//!
1522
//! ## Example
1623
//!
@@ -32,15 +39,44 @@
3239
//!
3340
//! // Train the model on the iris dataset
3441
//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
35-
//! .ensemble_size(100)
36-
//! .bootstrap_proportion(0.7)
42+
//! .ensemble_size(100) // Number of Decision Tree to fit
43+
//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
3744
//! .fit(&train)
3845
//! .unwrap();
3946
//!
4047
//! // Make predictions on the test set
4148
//! let predictions = bagging_model.predict(&test);
4249
//! ```
4350
//!
51+
//! This example shows how to train a Random Forest model using 100 decision trees,
52+
//! each trained on 70% of the training data (bootstrap sampling) and using only
53+
//! 30% of the available features.
54+
//!
55+
//! ```no_run
56+
//! use linfa::prelude::{Fit, Predict};
57+
//! use linfa_ensemble::RandomForestParams;
58+
//! use linfa_trees::DecisionTree;
59+
//! use ndarray_rand::rand::SeedableRng;
60+
//! use rand::rngs::SmallRng;
61+
//!
62+
//! // Load Iris dataset
63+
//! let mut rng = SmallRng::seed_from_u64(42);
64+
//! let (train, test) = linfa_datasets::iris()
65+
//! .shuffle(&mut rng)
66+
//! .split_with_ratio(0.8);
67+
//!
68+
//! // Train the model on the iris dataset
69+
//! let bagging_model = RandomForestParams::new(DecisionTree::params())
70+
//! .ensemble_size(100) // Number of Decision Tree to fit
71+
//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
72+
//! .feature_proportion(0.3) // Select only 30% of the feature
73+
//! .fit(&train)
74+
//! .unwrap();
75+
//!
76+
//! // Make predictions on the test set
77+
//! let predictions = bagging_model.predict(&test);
78+
//! ```
79+
4480
mod algorithm;
4581
mod hyperparams;
4682

0 commit comments

Comments
 (0)