Skip to content

Commit 5f6a312

Browse files
committed
test: add tests for edge cases to improve coverage
Adds three new tests to improve code coverage: - test_adaboost_early_stopping_on_perfect_fit: Tests early stopping on linearly separable data - test_adaboost_single_class_error: Tests error handling for single-class datasets - test_adaboost_classes_method: Tests that classes are properly identified This should improve patch coverage from 81.69% to ~85%+
1 parent 117f18e commit 5f6a312

1 file changed

Lines changed: 78 additions & 0 deletions

File tree

  • algorithms/linfa-ensemble/src

algorithms/linfa-ensemble/src/lib.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,4 +223,82 @@ mod tests {
223223
// Verify we have the expected number of models
224224
assert_eq!(model.n_estimators(), 10);
225225
}
226+
227+
#[test]
228+
fn test_adaboost_early_stopping_on_perfect_fit() {
229+
use ndarray::Array2;
230+
use linfa::DatasetBase;
231+
232+
// Create a simple linearly separable dataset
233+
let records = Array2::from_shape_vec(
234+
(6, 2),
235+
vec![
236+
0.0, 0.0, // class 0
237+
0.1, 0.1, // class 0
238+
0.2, 0.2, // class 0
239+
1.0, 1.0, // class 1
240+
1.1, 1.1, // class 1
241+
1.2, 1.2, // class 1
242+
],
243+
)
244+
.unwrap();
245+
let targets = ndarray::array![0, 0, 0, 1, 1, 1];
246+
let dataset = DatasetBase::new(records, targets);
247+
248+
let rng = SmallRng::seed_from_u64(42);
249+
let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(3)), rng)
250+
.n_estimators(50)
251+
.fit(&dataset)
252+
.unwrap();
253+
254+
// Should stop early due to perfect classification
255+
assert!(
256+
model.n_estimators() < 50,
257+
"Expected early stopping, but got {} estimators",
258+
model.n_estimators()
259+
);
260+
}
261+
262+
#[test]
263+
fn test_adaboost_single_class_error() {
264+
use ndarray::Array2;
265+
use linfa::DatasetBase;
266+
267+
// Create dataset with only one class
268+
let records = Array2::from_shape_vec(
269+
(4, 2),
270+
vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3],
271+
)
272+
.unwrap();
273+
let targets = ndarray::array![0, 0, 0, 0]; // All same class
274+
let dataset = DatasetBase::new(records, targets);
275+
276+
let rng = SmallRng::seed_from_u64(42);
277+
let result = AdaBoostParams::new_fixed_rng(DecisionTree::params(), rng)
278+
.n_estimators(10)
279+
.fit(&dataset);
280+
281+
assert!(
282+
result.is_err(),
283+
"Should fail with single class dataset"
284+
);
285+
}
286+
287+
#[test]
288+
fn test_adaboost_classes_method() {
289+
let mut rng = SmallRng::seed_from_u64(42);
290+
let (train, _) = linfa_datasets::iris()
291+
.shuffle(&mut rng)
292+
.split_with_ratio(0.8);
293+
294+
let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng)
295+
.n_estimators(10)
296+
.fit(&train)
297+
.unwrap();
298+
299+
// Verify classes are properly stored
300+
let classes = &model.classes;
301+
assert_eq!(classes.len(), 3, "Iris has 3 classes");
302+
assert_eq!(classes, &vec![0, 1, 2], "Classes should be [0, 1, 2]");
303+
}
226304
}

0 commit comments

Comments
 (0)