|
5 | 5 | //! |
6 | 6 | //! ## Bootstrap Aggregation (aka Bagging) |
7 | 7 | //! |
8 | | -//! A typical example of ensemble method is Bootstrapo AGgregation, which combines the predictions of |
| 8 | +//! A typical example of ensemble method is Bootstrap Aggregation, which combines the predictions of |
9 | 9 | //! several decision trees (see `linfa-trees`) trained on different sample subsets of the training dataset. |
10 | 10 | //! |
| 11 | +//! ## Random Forest |
| 12 | +//! |
| 13 | +//! A special case of Bootstrap Aggregation using decision trees (see `linfa-trees`) with random feature |
| 14 | +//! selection. A typical number of random features to be selected is $\sqrt{p}$ with $p$ being |
| 15 | +//! the number of available features. |
| 16 | +//! |
11 | 17 | //! ## Reference |
12 | 18 | //! |
13 | 19 | //! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html) |
| 20 | +//! * [An Introduction to Statistical Learning](https://www.statlearning.com/) |
14 | 21 | //! |
15 | 22 | //! ## Example |
16 | 23 | //! |
|
32 | 39 | //! |
33 | 40 | //! // Train the model on the iris dataset |
34 | 41 | //! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params()) |
35 | | -//! .ensemble_size(100) |
36 | | -//! .bootstrap_proportion(0.7) |
| 42 | +//! .ensemble_size(100) // Number of decision trees to fit |
| 43 | +//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap |
37 | 44 | //! .fit(&train) |
38 | 45 | //! .unwrap(); |
39 | 46 | //! |
40 | 47 | //! // Make predictions on the test set |
41 | 48 | //! let predictions = bagging_model.predict(&test); |
42 | 49 | //! ``` |
43 | 50 | //! |
| 51 | +//! This example shows how to train a Random Forest model using 100 decision trees, |
| 52 | +//! each trained on 70% of the training data (bootstrap sampling) and using only |
| 53 | +//! 30% of the available features. |
| 54 | +//! |
| 55 | +//! ```no_run |
| 56 | +//! use linfa::prelude::{Fit, Predict}; |
| 57 | +//! use linfa_ensemble::RandomForestParams; |
| 58 | +//! use linfa_trees::DecisionTree; |
| 59 | +//! use ndarray_rand::rand::SeedableRng; |
| 60 | +//! use rand::rngs::SmallRng; |
| 61 | +//! |
| 62 | +//! // Load Iris dataset |
| 63 | +//! let mut rng = SmallRng::seed_from_u64(42); |
| 64 | +//! let (train, test) = linfa_datasets::iris() |
| 65 | +//! .shuffle(&mut rng) |
| 66 | +//! .split_with_ratio(0.8); |
| 67 | +//! |
| 68 | +//! // Train the model on the iris dataset |
| 69 | +//! let bagging_model = RandomForestParams::new(DecisionTree::params()) |
| 70 | +//! .ensemble_size(100) // Number of decision trees to fit |
| 71 | +//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap |
| 72 | +//! .feature_proportion(0.3) // Select only 30% of the features |
| 73 | +//! .fit(&train) |
| 74 | +//! .unwrap(); |
| 75 | +//! |
| 76 | +//! // Make predictions on the test set |
| 77 | +//! let predictions = bagging_model.predict(&test); |
| 78 | +//! ``` |
| 79 | +
|
44 | 80 | mod algorithm; |
45 | 81 | mod hyperparams; |
46 | 82 |
|
|
0 commit comments