forked from rust-ml/linfa
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path hyperparams.rs
More file actions
73 lines (64 loc) · 2.17 KB
/
hyperparams.rs
File metadata and controls
73 lines (64 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
use linfa::{
error::{Error, Result},
ParamGuard,
};
use rand::rngs::ThreadRng;
use rand::Rng;
/// Validated hyperparameters for an ensemble learner.
///
/// This is the `Checked` type produced by `ParamGuard::check` / `check_ref` on
/// `EnsembleLearnerParams`, which reject an `ensemble_size` below one and a
/// `bootstrap_proportion` outside `(0, 1]`. Note that the fields are public, so a value
/// built directly (not via the guard) is not guaranteed to satisfy those constraints.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerValidParams<P, R> {
    /// The number of models in the ensemble
    pub ensemble_size: usize,
    /// The proportion of the total number of training samples that should be given to each model for training
    pub bootstrap_proportion: f64,
    /// The model parameters for the base model
    pub model_params: P,
    /// Random number generator owned by the ensemble. The builder requires `R: Rng + Clone`.
    /// NOTE(review): presumably used to draw the bootstrap samples; confirm against the fit code.
    pub rng: R,
}
/// Unchecked, builder-style hyperparameters for an ensemble learner.
///
/// Wraps an `EnsembleLearnerValidParams`. Configure it with the builder methods, then
/// validate it through its `ParamGuard` implementation (`check` / `check_ref`) to obtain
/// the checked parameters.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleLearnerParams<P, R>(EnsembleLearnerValidParams<P, R>);
impl<P> EnsembleLearnerParams<P, ThreadRng> {
    /// Creates ensemble hyperparameters around `model_params`, using the thread-local
    /// random number generator (`rand::thread_rng()`).
    ///
    /// The defaults are a single model and a bootstrap proportion of `1.0`. To supply a
    /// specific generator instead, use `EnsembleLearnerParams::new_fixed_rng`.
    pub fn new(model_params: P) -> EnsembleLearnerParams<P, ThreadRng> {
        let thread_local_rng = rand::thread_rng();
        Self::new_fixed_rng(model_params, thread_local_rng)
    }
}
impl<P, R: Rng + Clone> EnsembleLearnerParams<P, R> {
    /// Creates ensemble hyperparameters around `model_params`, using the caller-supplied `rng`.
    ///
    /// The defaults are a single model (`ensemble_size = 1`) and a bootstrap proportion
    /// of `1.0`.
    pub fn new_fixed_rng(model_params: P, rng: R) -> EnsembleLearnerParams<P, R> {
        let defaults = EnsembleLearnerValidParams {
            model_params,
            rng,
            ensemble_size: 1,
            bootstrap_proportion: 1.0,
        };
        Self(defaults)
    }

    /// Sets how many models the ensemble contains.
    ///
    /// The value is not validated here; `ParamGuard::check` rejects zero.
    pub fn ensemble_size(self, size: usize) -> Self {
        Self(EnsembleLearnerValidParams {
            ensemble_size: size,
            ..self.0
        })
    }

    /// Sets the fraction of the training samples handed to each model.
    ///
    /// The value is not validated here; `ParamGuard::check` requires it to lie in `(0, 1]`.
    pub fn bootstrap_proportion(self, proportion: f64) -> Self {
        Self(EnsembleLearnerValidParams {
            bootstrap_proportion: proportion,
            ..self.0
        })
    }
}
impl<P, R> ParamGuard for EnsembleLearnerParams<P, R> {
    type Checked = EnsembleLearnerValidParams<P, R>;
    type Error = Error;

    /// Validates the hyperparameters and returns a reference to the checked form.
    ///
    /// # Errors
    ///
    /// Returns `Error::Parameters` in either of these cases:
    /// * `bootstrap_proportion` is not in `(0, 1]`. NaN and infinities are included in this case.
    /// * `ensemble_size` is zero.
    fn check_ref(&self) -> Result<&Self::Checked> {
        let params = &self.0;
        // This is written as a positive in-range test and then negated. Every comparison
        // with NaN is false, so the old `p > 1.0 || p <= 0.0` form silently accepted NaN.
        let proportion_in_range =
            params.bootstrap_proportion > 0.0 && params.bootstrap_proportion <= 1.0;
        if !proportion_in_range {
            Err(Error::Parameters(format!(
                "Bootstrap proportion should be greater than zero and less than or equal to one, but was {}",
                params.bootstrap_proportion
            )))
        } else if params.ensemble_size == 0 {
            Err(Error::Parameters(format!(
                "Ensemble size should be greater than or equal to one, but was {}",
                params.ensemble_size
            )))
        } else {
            Ok(params)
        }
    }

    /// Validates the hyperparameters and, if they are valid, unwraps them into the checked form.
    ///
    /// # Errors
    ///
    /// This returns the same errors as `check_ref`.
    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}