Skip to content

Commit c7c2af5

Browse files
authored
Add generic ResidualChain composing method (#430)
* feat: add linfa-residual-sequence crate Implements ResidualSequence Struct and StackWith trait for composing regression models in a boosting / residual-stacking pattern. The second (and any further) model trains on the residuals left by the previous one; predictions are summed. Docs and tests were written with AI assistance. * remove doc link * move to composing/ module in linfa main crate * update docs * implement PredictInplace instead * remove unused param error * use one struct to implement stacking * add deep chain test * Rename to ResidualChain Implement Shrinkage implement paramguard for shrinkage * satisfy zola * can only shrink by if target has the same float type * work with predict inplace only * zola fix * add link in docs * simplify comparison * rename to residual_chain as consistent with struct * implement copy trait * add method `chain` which just chains self with an unshrunk corrector. rename stack_with -> chain_shrunk * add doc
1 parent b1f9ddb commit c7c2af5

3 files changed

Lines changed: 590 additions & 1 deletion

File tree

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ linfa-datasets = { path = "datasets", features = [
6363
"diabetes",
6464
"generate",
6565
] }
66+
linfa-linear = { path = "algorithms/linfa-linear" }
67+
linfa-svm = { path = "algorithms/linfa-svm" }
6668
statrs = "0.18"
6769

6870
[target.'cfg(not(windows))'.dependencies]

src/composing/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
//! Composition models
22
//!
3-
//! This module contains three composition models:
3+
//! This module contains four composition models:
44
//! * `MultiClassModel`: combine multiple binary decision models to a single multi-class model
55
//! * `MultiTargetModel`: combine multiple univariate models to a single multi-target model
66
//! * `Platt`: calibrate a classifier (i.e. SVC) to predicted posterior probabilities
7+
//! * `ResidualChain`: fit models sequentially on the residuals of the previous one
8+
//! (forward stagewise additive modeling / L2Boosting); see [`residual_chain::Stagewise`]
79
mod multi_class_model;
810
mod multi_target_model;
911
pub mod platt_scaling;
12+
pub mod residual_chain;
1013

1114
pub use multi_class_model::MultiClassModel;
1215
pub use multi_target_model::MultiTargetModel;

0 commit comments

Comments
 (0)