forked from rust-ml/linfa
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathalgorithm.rs
More file actions
177 lines (162 loc) · 5.93 KB
/
algorithm.rs
File metadata and controls
177 lines (162 loc) · 5.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
use crate::EnsembleLearnerValidParams;
use linfa::{
dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned, Records},
error::Error,
traits::*,
DatasetBase,
};
use linfa_trees::DecisionTree;
use ndarray::{Array2, Axis, Zip};
use rand::Rng;
use std::{cmp::Eq, collections::HashMap, hash::Hash};
/// A fitted ensemble of [Decision Trees](DecisionTree), each trained on a random subset of
/// the features.
///
/// This is a type alias for [`EnsembleLearner`] specialised to [`DecisionTree`] models; see
/// the [`EnsembleLearner`] documentation for the fitting and prediction interface.
pub type RandomForest<F, L> = EnsembleLearner<DecisionTree<F, L>>;
/// A fitted ensemble of learners for classification.
///
/// ## Structure
///
/// An Ensemble Learner is composed of a collection of fitted models of type `M`.
///
/// ## Fitting Algorithm
///
/// Given a [DatasetBase](DatasetBase) denoted as `D`,
/// 1. Create as many distinct bootstrapped subsets of the original dataset `D` as there are
///    models to fit.
/// 2. Fit each model on its own bootstrapped subset of `D`.
///
/// Note that the subset size, as well as the subset of features to use in each training subset,
/// can be specified in the [parameters](crate::EnsembleLearnerParams).
///
/// ## Prediction Algorithm
///
/// The prediction result is the result of majority voting across the fitted learners.
///
/// ## Example
///
/// This example shows how to train a bagging model using 100 decision trees,
/// each trained on 70% of the training data (bootstrap sampling).
/// ```no_run
/// use linfa::prelude::{Fit, Predict};
/// use linfa_ensemble::EnsembleLearnerParams;
/// use linfa_trees::DecisionTree;
/// use ndarray_rand::rand::SeedableRng;
/// use rand::rngs::SmallRng;
///
/// // Load Iris dataset
/// let mut rng = SmallRng::seed_from_u64(42);
/// let (train, test) = linfa_datasets::iris()
///     .shuffle(&mut rng)
///     .split_with_ratio(0.8);
///
/// // Train the model on the iris dataset
/// let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
///     .ensemble_size(100) // Number of decision trees to fit
///     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
///     .fit(&train)
///     .unwrap();
///
/// // Make predictions on the test set
/// let predictions = bagging_model.predict(&test);
/// ```
///
/// ## References
///
/// * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
/// * [An Introduction to Statistical Learning](https://www.statlearning.com/)
pub struct EnsembleLearner<M> {
/// The fitted models that make up the ensemble.
pub models: Vec<M>,
/// For each model, in the same order as `models`, the column indices of the input records
/// that are selected and handed to that model at prediction time.
pub model_features: Vec<Vec<usize>>,
}
impl<M> EnsembleLearner<M> {
    /// Returns a lazy iterator over the predictions of the individual models.
    ///
    /// The `i`-th model is paired with the `i`-th entry of `x` and asked to predict on it.
    /// Pairing stops as soon as either the models or the entries of `x` run out, so the
    /// iterator yields `min(self.models.len(), x.len())` predictions.
    pub fn generate_predictions<'b, R: Records, T>(
        &'b self,
        x: &'b [R],
    ) -> impl Iterator<Item = T> + 'b
    where
        M: Predict<&'b R, T>,
    {
        // Walk models and their matching records in lock-step.
        let model_with_records = self.models.iter().zip(x);
        model_with_records.map(|(model, records)| model.predict(records))
    }
}
impl<F: Clone, T, M> PredictInplace<Array2<F>, T> for EnsembleLearner<M>
where
    M: PredictInplace<Array2<F>, T>,
    <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug,
    T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
{
    /// Predicts by majority vote across the ensemble and writes the winners into `y`.
    ///
    /// Each model predicts on the columns of `x` listed in its entry of `model_features`.
    /// For every target position the label with the most votes wins. Ties are resolved
    /// deterministically in favour of the label that was first voted for by the earliest
    /// model in `self.models`, so repeated calls on the same input give the same output.
    ///
    /// # Panics
    ///
    /// * if the number of rows of `x` differs from the length of `y` along axis 0,
    /// * if a model's prediction does not have the same shape as `y`,
    /// * if the ensemble contains no model while `y` has at least one element.
    fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
        let y_array = y.as_targets();
        assert_eq!(
            x.nrows(),
            y_array.len_of(Axis(0)),
            "The number of data points must match the number of outputs."
        );

        // Build, for every model, the view of `x` restricted to the features it was fitted on.
        let sub_datas = self
            .model_features
            .iter()
            .map(|feat| x.select(Axis(1), feat))
            .collect::<Vec<_>>();
        let predictions = self.generate_predictions(&sub_datas);

        // Same shape as `y_array`; each cell maps a label to
        // `(number of votes, index of the first model that voted for it)`.
        let mut prediction_maps = y_array.map(|_| HashMap::new());
        for (model_idx, prediction) in predictions.enumerate() {
            let p_arr = prediction.as_targets();
            assert_eq!(p_arr.shape(), y_array.shape());
            // Record this model's vote in the tally of every target position.
            Zip::from(&mut prediction_maps)
                .and(&p_arr)
                .for_each(|map, val| {
                    map.entry(*val).or_insert((0usize, model_idx)).0 += 1;
                });
        }

        // For each position pick the label with the most votes. `HashMap` iteration order is
        // arbitrary, so the first-voter index is part of the key to make ties reproducible:
        // a given model casts exactly one vote per position, hence no two labels share it.
        let mut y_array = y.as_targets_mut();
        for (slot, votes) in y_array.iter_mut().zip(prediction_maps.iter()) {
            let (winner, _) = votes
                .iter()
                .max_by_key(|(_, tally)| (tally.0, std::cmp::Reverse(tally.1)))
                .expect("cannot aggregate votes: the ensemble contains no fitted model");
            *slot = *winner;
        }
    }

    /// Builds the default target container for `x` by delegating to the first model.
    ///
    /// # Panics
    ///
    /// Panics if the ensemble contains no model.
    fn default_target(&self, x: &Array2<F>) -> T {
        self.models
            .first()
            .expect("cannot build a default target: the ensemble contains no fitted model")
            .default_target(x)
    }
}
impl<D, T, P: Fit<Array2<D>, T::Owned, Error>, R: Rng + Clone> Fit<Array2<D>, T, Error>
    for EnsembleLearnerValidParams<P, R>
where
    D: Clone,
    T: FromTargetArrayOwned,
    T::Elem: Copy + Eq + Hash,
    T::Owned: AsTargets,
{
    type Object = EnsembleLearner<P::Object>;

    /// Fits `ensemble_size` models, each on its own bootstrap sample of `dataset`.
    ///
    /// Every sample contains `ceil(nrows * bootstrap_proportion)` rows and
    /// `ceil(ncols * feature_proportion)` features. The feature indices selected for each
    /// sample are stored alongside the fitted model so the same columns can be picked again
    /// at prediction time.
    ///
    /// Sampling uses a clone of the stored random number generator, so `self` is not mutated.
    ///
    /// # Errors
    ///
    /// Returns the error of the underlying model parameters as soon as fitting one of the
    /// bootstrap samples fails.
    fn fit(
        &self,
        dataset: &DatasetBase<Array2<D>, T>,
    ) -> core::result::Result<Self::Object, Error> {
        let mut models = Vec::with_capacity(self.ensemble_size);
        let mut model_features = Vec::with_capacity(self.ensemble_size);
        let mut rng = self.rng.clone();

        // Number of rows and of features drawn for each bootstrap sample (rounded up).
        let dataset_size =
            ((dataset.records.nrows() as f64) * self.bootstrap_proportion).ceil() as usize;
        let n_feat = dataset.records.ncols();
        let n_sub = ((n_feat as f64) * self.feature_proportion).ceil() as usize;

        // `take` bounds the sample stream to exactly `ensemble_size` draws, which also
        // terminates cleanly when `ensemble_size` is zero.
        let samples = dataset
            .bootstrap_with_indices((dataset_size, n_sub), &mut rng)
            .take(self.ensemble_size);
        for (train, _, feature_selected) in samples {
            // Propagate a fitting failure to the caller instead of panicking.
            models.push(self.model_params.fit(&train)?);
            model_features.push(feature_selected);
        }

        Ok(EnsembleLearner {
            models,
            model_features,
        })
    }
}