
Commit 9b5c424

relfrathideep22claude authored
Add AdaBoost classifier to linfa-ensemble (#427)
* feat: Add AdaBoost (Adaptive Boosting) to linfa-ensemble

Implements the SAMME (Stagewise Additive Modeling using a Multiclass Exponential loss function) algorithm for multi-class classification using ensemble learning.

## Features

- Sequential boosting with adaptive sample weighting
- Multi-class classification support (SAMME algorithm)
- Weighted voting for final predictions using model alpha values
- Automatic convergence handling and early stopping
- Resampling-based approach compatible with any base learner

## Implementation Details

- AdaBoost struct with model weights (alpha values) tracking
- AdaBoostParams following the ParamGuard pattern for validation
- Configurable n_estimators and learning_rate hyperparameters
- Full trait implementations: Fit, Predict, PredictInplace
- Comprehensive error handling with proper error types

## Testing

- 12 unit tests covering parameter validation and model training
- 6 doc tests for API documentation
- Achieves 90-93% accuracy on the Iris dataset with decision stumps
- Tests for different learning rates and tree depths

## Documentation

- Extensive inline documentation with algorithm explanation
- Working example (adaboost_iris.rs) with multiple configurations
- References to the original AdaBoost paper (Freund & Schapire, 1997)
- Comparison with the scikit-learn implementation

## Performance

- Successfully trains on the Iris dataset (150 samples, 3 classes)
- Supports decision stumps (depth=1) and shallow trees
- Model weights properly reflect learner performance

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix: remove redundant explicit link target in rustdoc

Fixes a rustdoc warning about redundant explicit links: changed [AdaBoost](AdaBoost) to [AdaBoost], as recommended by the rustdoc linter.

* test: add tests for edge cases to improve coverage

Adds three new tests to improve code coverage:

- test_adaboost_early_stopping_on_perfect_fit: tests early stopping on linearly separable data
- test_adaboost_single_class_error: tests error handling for single-class datasets
- test_adaboost_classes_method: tests that classes are properly identified

This should improve patch coverage from 81.69% to ~85%+.

* style: apply rustfmt formatting

Fix import ordering and line wrapping to match rustfmt standards.

* fix: address code review feedback for AdaBoost implementation

Implements all requested changes from the PR review:

1. Replace rand:: imports with ndarray_rand::rand:: for consistency
2. Change sample_weights from f32 to f64 for better precision
3. Fix the learning_rate cancellation bug in the weight update formula
   - Previously: weight *= ((alpha / learning_rate) as f32).exp()
   - Now: weight *= alpha.exp()
   - This ensures learning_rate actually affects sample weight updates
4. Fix the classes field to store actual labels (T::Elem) instead of usize
   - Made the AdaBoost struct generic over label type L
   - Stores original class labels for proper type safety
5. Remove the duplicate y_array definition in predict_inplace
6. Add base learner error details to the error message for better debugging
7. Add test_adaboost_different_learning_rates to verify learning_rate effects on model weights

All tests pass with no warnings or clippy issues.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Cleanup f64

* No tie breaking

* Avoid magic number

---------

Co-authored-by: Deep Rathi <deeprathi222@gmail.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
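The weight-update fix in item 3 of the review feedback is the heart of the SAMME step. For reference, a minimal sketch of that step, assuming `err` is the current learner's weighted training error, `n_classes` the number of classes, and `misclassified` a per-sample mask; the names and signature here are illustrative, not the crate's actual internals:

```rust
/// Illustrative SAMME boosting step (a sketch, not linfa-ensemble's code).
/// Returns the model weight (alpha) after up-weighting misclassified samples.
fn samme_step(
    weights: &mut [f64],    // per-sample weights, assumed normalized
    misclassified: &[bool], // true where the weak learner was wrong
    err: f64,               // weighted training error, 0 < err < 1
    n_classes: usize,
    learning_rate: f64,
) -> f64 {
    // SAMME model weight: alpha = lr * (ln((1 - err) / err) + ln(K - 1)).
    let alpha = learning_rate * (((1.0 - err) / err).ln() + ((n_classes - 1) as f64).ln());

    // The fixed update from the review: `weight *= alpha.exp()`, so
    // learning_rate no longer cancels out of the sample-weight update.
    for (w, &miss) in weights.iter_mut().zip(misclassified) {
        if miss {
            *w *= alpha.exp();
        }
    }

    // Renormalize so the weights remain a distribution.
    let total: f64 = weights.iter().sum();
    for w in weights.iter_mut() {
        *w /= total;
    }

    alpha
}
```

With learning_rate = 1.0 this reduces to classic multi-class AdaBoost; shrinking it damps each round's weight changes, which is why the lower-rate configurations in the example below need more estimators.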
1 parent db3cade commit 9b5c424

4 files changed: 888 additions & 0 deletions

File shown: adaboost_iris.rs (the example named in the commit message), 134 additions & 0 deletions
@@ -0,0 +1,134 @@
```rust
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
use linfa_ensemble::AdaBoostParams;
use linfa_trees::DecisionTree;
use ndarray_rand::rand::SeedableRng;
use rand::rngs::SmallRng;

fn adaboost_with_stumps(n_estimators: usize, learning_rate: f64) {
    // Load dataset
    let mut rng = SmallRng::seed_from_u64(42);
    let (train, test) = linfa_datasets::iris()
        .shuffle(&mut rng)
        .split_with_ratio(0.8);

    // Train AdaBoost model with decision tree stumps (max_depth=1)
    // Stumps are weak learners commonly used with AdaBoost
    let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng)
        .n_estimators(n_estimators)
        .learning_rate(learning_rate)
        .fit(&train)
        .unwrap();

    // Make predictions
    let predictions = model.predict(&test);
    println!("Final Predictions: \n{predictions:?}");

    let cm = predictions.confusion_matrix(&test).unwrap();
    println!("{cm:?}");
    println!(
        "Test accuracy: {:.2}%\nwith Decision Tree stumps (max_depth=1),\nn_estimators: {n_estimators},\nlearning_rate: {learning_rate}.\n",
        100.0 * cm.accuracy()
    );
    println!("Number of models trained: {}", model.n_estimators());
}

fn adaboost_with_shallow_trees(n_estimators: usize, learning_rate: f64, max_depth: usize) {
    let mut rng = SmallRng::seed_from_u64(42);
    let (train, test) = linfa_datasets::iris()
        .shuffle(&mut rng)
        .split_with_ratio(0.8);

    // Train AdaBoost model with shallow decision trees
    let model =
        AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(max_depth)), rng)
            .n_estimators(n_estimators)
            .learning_rate(learning_rate)
            .fit(&train)
            .unwrap();

    // Make predictions
    let predictions = model.predict(&test);
    println!("Final Predictions: \n{predictions:?}");

    let cm = predictions.confusion_matrix(&test).unwrap();
    println!("{cm:?}");
    println!(
        "Test accuracy: {:.2}%\nwith Decision Trees (max_depth={max_depth}),\nn_estimators: {n_estimators},\nlearning_rate: {learning_rate}.\n",
        100.0 * cm.accuracy()
    );

    // Display model weights
    println!("Model weights (alpha values):");
    for (i, weight) in model.weights().iter().enumerate() {
        println!("  Model {}: {:.4}", i + 1, weight);
    }
    println!();
}

fn main() {
    println!("{}", "=".repeat(80));
    println!("AdaBoost Examples on Iris Dataset");
    println!("{}", "=".repeat(80));
    println!();

    // Example 1: AdaBoost with decision stumps (most common configuration)
    println!("Example 1: AdaBoost with Decision Stumps");
    println!("{}", "-".repeat(80));
    adaboost_with_stumps(50, 1.0);
    println!();

    // Example 2: AdaBoost with lower learning rate
    println!("Example 2: AdaBoost with Lower Learning Rate");
    println!("{}", "-".repeat(80));
    adaboost_with_stumps(100, 0.5);
    println!();

    // Example 3: AdaBoost with shallow trees
    println!("Example 3: AdaBoost with Shallow Decision Trees");
    println!("{}", "-".repeat(80));
    adaboost_with_shallow_trees(50, 1.0, 2);
    println!();

    // Example 4: Comparing different configurations
    println!("Example 4: Comparing Configurations");
    println!("{}", "-".repeat(80));
    let configs = vec![
        (25, 1.0, 1, "Few stumps, high learning rate"),
        (50, 1.0, 1, "Medium stumps, high learning rate"),
        (100, 0.5, 1, "Many stumps, low learning rate"),
        (50, 1.0, 2, "Shallow trees, high learning rate"),
    ];

    for (n_est, lr, depth, desc) in configs {
        let mut rng = SmallRng::seed_from_u64(42);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        let model =
            AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(depth)), rng)
                .n_estimators(n_est)
                .learning_rate(lr)
                .fit(&train)
                .unwrap();

        let predictions = model.predict(&test);
        let cm = predictions.confusion_matrix(&test).unwrap();

        println!(
            "{desc:50} => Accuracy: {:.2}% (models trained: {})",
            100.0 * cm.accuracy(),
            model.n_estimators()
        );
    }

    println!();
    println!("{}", "=".repeat(80));
    println!("Notes:");
    println!("- AdaBoost works by training weak learners sequentially");
    println!("- Each learner focuses on samples misclassified by previous learners");
    println!("- Decision stumps (depth=1) are the most common weak learners");
    println!("- Lower learning_rate provides regularization but needs more estimators");
    println!("- Model weights (alpha) reflect each learner's contribution to prediction");
    println!("{}", "=".repeat(80));
}
```
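The alpha values the example prints are what weight each learner's vote at prediction time. A minimal sketch of that SAMME-style weighted vote, assuming each fitted learner predicts a class index and skipping tie breaking (as the commit does); the function is illustrative, not linfa-ensemble's internals:

```rust
/// Illustrative weighted majority vote: each model's predicted class
/// receives that model's alpha; the class with the largest total wins.
/// `predictions[m][i]` is model m's class index for sample i.
fn weighted_vote(predictions: &[Vec<usize>], alphas: &[f64], n_classes: usize) -> Vec<usize> {
    let n_samples = predictions[0].len();
    let mut labels = Vec::with_capacity(n_samples);
    for i in 0..n_samples {
        // Accumulate each model's alpha under the class it voted for.
        let mut scores = vec![0.0_f64; n_classes];
        for (preds, &alpha) in predictions.iter().zip(alphas) {
            scores[preds[i]] += alpha;
        }
        // Take the arg-max class; no explicit tie breaking.
        let best = scores
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(class, _)| class)
            .unwrap();
        labels.push(best);
    }
    labels
}
```

Running the example itself (e.g. `cargo run --example adaboost_iris` from the linfa-ensemble crate, assuming that is where the file lives) should land in the 90-93% Iris accuracy range the commit message reports.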

0 commit comments
