Skip to content

Commit c9e3ce4

Browse files
committed
feat: add linfa-residual-sequence crate
Implements ResidualSequence Struct and StackWith trait for composing regression models in a boosting / residual-stacking pattern. The second (and any further) model trains on the residuals left by the previous one; predictions are summed. Docs and tests were written with AI assistance.
1 parent b1f9ddb commit c9e3ce4

2 files changed

Lines changed: 355 additions & 0 deletions

File tree

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[package]
2+
name = "linfa-residual-sequence"
3+
version = "0.8.1"
4+
edition = "2018"
5+
description = "Model composition utilities for the linfa ML framework"
6+
license = "MIT OR Apache-2.0"
7+
repository = "https://github.com/rust-ml/linfa"
8+
keywords = ["machine-learning", "linfa", "ai", "ml", "residual"]
9+
categories = ["algorithms", "mathematics", "science"]
10+
11+
[dependencies]
12+
linfa = { version = "0.8.1", path = "../.." }
13+
ndarray = { version = "0.16" }
14+
thiserror = "2.0"
15+
16+
[dev-dependencies]
17+
linfa-linear = { path = "../linfa-linear" }
18+
linfa-svm = { path = "../linfa-svm" }
Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
//! Residual sequence model composition for the linfa ML framework.
2+
//!
3+
//! This crate provides [`ResidualSequence`], which fits models sequentially on
4+
//! the residuals of the previous one. Chain as many as you like via [`StackWith`]:
5+
//!
6+
//! 1. Fit `first` on `(X, Y)`
7+
//! 2. Compute residuals: `R = Y - first.predict(X)`
8+
//! 3. Fit `second` on `(X, R)`
9+
//! 4. Repeat for any further models stacked on top
10+
//!
11+
//! When predicting, all models' outputs are summed.
12+
//!
13+
//! This is the foundation of boosting / residual stacking.
14+
//!
15+
//! # Examples
16+
//!
17+
//! ## Linear + linear
18+
//!
19+
//! Two [`linfa_linear::LinearRegression`] models stacked: the second fits the
20+
//! residuals left by the first.
21+
//!
22+
//! ```
23+
//! use linfa::traits::{Fit, Predict};
24+
//! use linfa::DatasetBase;
25+
//! use linfa_linear::LinearRegression;
26+
//! use linfa_residual_sequence::{ResidualSequence, StackWith};
27+
//! use ndarray::{array, Array2};
28+
//!
29+
//! // y = 2x: perfectly linear, so the second model should see zero residuals.
30+
//! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64);
31+
//! let y = array![0., 2., 4., 6., 8.];
32+
//! let dataset = DatasetBase::new(x.clone(), y);
33+
//!
34+
//! let fitted = LinearRegression::default()
35+
//! .stack_with(LinearRegression::default())
36+
//! .fit(&dataset)
37+
//! .unwrap();
38+
//!
39+
//! let _preds = fitted.predict(&x);
40+
//! ```
41+
//!
42+
//! ## The second model learns nothing when the first fits perfectly
43+
//!
44+
//! If the first model already captures the data exactly, the residuals are all
45+
//! zero and the second model has nothing to learn — its parameters come out
46+
//! at (or very near) zero.
47+
//!
48+
//! ```
49+
//! use linfa::traits::{Fit, Predict};
50+
//! use linfa::DatasetBase;
51+
//! use linfa_linear::LinearRegression;
52+
//! use linfa_residual_sequence::StackWith;
53+
//! use ndarray::{array, Array2};
54+
//!
55+
//! // y = 2x: one linear model is enough to fit this perfectly.
56+
//! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64);
57+
//! let y = array![0., 2., 4., 6., 8.];
58+
//! let dataset = DatasetBase::new(x.clone(), y);
59+
//!
60+
//! let fitted = LinearRegression::default()
61+
//! .stack_with(LinearRegression::default())
62+
//! .fit(&dataset)
63+
//! .unwrap();
64+
//!
65+
//! // The second model trained on zero residuals — nothing left to correct.
66+
//! assert!(fitted.second.params().iter().all(|&c: &f64| c.abs() < 1e-10));
67+
//! assert!(fitted.second.intercept().abs() < 1e-10);
68+
//! ```
69+
//!
70+
//! ## Chained SVMs and linear regression
71+
//!
72+
//! A linear-kernel [`linfa_svm::Svm`] captures the overall trend; two
73+
//! Gaussian-kernel SVMs and a [`linfa_linear::LinearRegression`] then fit
74+
//! successive residuals in a four-model chain.
75+
//!
76+
//! ```
77+
//! use linfa::traits::{Fit, Predict};
78+
//! use linfa::DatasetBase;
79+
//! use linfa_linear::LinearRegression;
80+
//! use linfa_residual_sequence::{ResidualSequence, StackWith};
81+
//! use linfa_svm::Svm;
82+
//! use ndarray::Array;
83+
//!
84+
//! // y = sin(x): the linear SVM captures the slope; the RBF SVM captures
85+
//! // the curvature left in the residuals.
86+
//! let x = Array::linspace(0f64, 6., 20)
87+
//! .into_shape_with_order((20, 1))
88+
//! .unwrap();
89+
//! let y = x.column(0).mapv(f64::sin);
90+
//! let dataset = DatasetBase::new(x.clone(), y);
91+
//!
92+
//! let fitted = Svm::<f64, f64>::params()
93+
//! .c_svr(1., None)
94+
//! .linear_kernel()
95+
//! .stack_with(
96+
//! Svm::<f64, f64>::params()
97+
//! .c_svr(10., Some(0.1))
98+
//! .gaussian_kernel(1.),
99+
//! )
100+
//! .stack_with(LinearRegression::default())
101+
//! .stack_with(
102+
//! Svm::<f64, f64>::params()
103+
//! .c_svr(10., Some(0.1))
104+
//! .gaussian_kernel(3.),
105+
//! )
106+
//! .fit(&dataset)
107+
//! .unwrap();
108+
//!
109+
//! let _preds = fitted.predict(&x);
110+
//! ```
111+
112+
use linfa::dataset::{AsTargets, DatasetBase, Records};
113+
use linfa::traits::{Fit, Predict};
114+
use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2, RawDataClone};
115+
use std::ops::{Add, Sub};
116+
117+
type Arr2<D> = ArrayBase<D, Ix2>;
118+
119+
/// Error returned by [`ResidualSequence::fit`].
120+
///
121+
/// Wraps the error from whichever of the two model fits failed, keeping them
122+
/// distinguishable without requiring both models to share the same error type.
123+
#[derive(Debug, thiserror::Error)]
124+
pub enum ResidualSequenceError<E1, E2> {
125+
#[error("first model: {0}")]
126+
First(E1),
127+
#[error("second model: {0}")]
128+
Second(E2),
129+
// Satisfies the `Fit` trait's `E: From<linfa::error::Error>` bound.
130+
#[error(transparent)]
131+
Linfa(#[from] linfa::error::Error),
132+
}
133+
134+
/// Fits two models sequentially on the residuals of the first.
135+
///
136+
/// `first` is fit on the original dataset. `second` is fit on the residuals
137+
/// `Y - first.predict(X)`. See the [crate-level docs](crate) for details.
138+
#[derive(Debug, Clone)]
139+
pub struct ResidualSequence<F1, F2> {
140+
pub first: F1,
141+
pub second: F2,
142+
}
143+
144+
/// Extension trait that lets any model params type be composed into a [`ResidualSequence`].
145+
///
146+
/// # Example
147+
///
148+
/// ```
149+
/// use linfa::traits::Fit;
150+
/// use linfa::DatasetBase;
151+
/// use linfa_linear::LinearRegression;
152+
/// use linfa_residual_sequence::StackWith;
153+
/// use ndarray::{array, Array2};
154+
///
155+
/// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64);
156+
/// let y = array![0., 2., 4., 6., 8.];
157+
/// let dataset = DatasetBase::new(x.clone(), y);
158+
///
159+
/// let fitted = LinearRegression::default()
160+
/// .stack_with(LinearRegression::default())
161+
/// .fit(&dataset)
162+
/// .unwrap();
163+
/// ```
164+
pub trait StackWith: Sized {
165+
fn stack_with<F2>(self, second: F2) -> ResidualSequence<Self, F2>;
166+
}
167+
168+
impl<F1> StackWith for F1 {
169+
fn stack_with<F2>(self, second: F2) -> ResidualSequence<F1, F2> {
170+
ResidualSequence {
171+
first: self,
172+
second,
173+
}
174+
}
175+
}
176+
177+
/// Two fitted models produced by [`ResidualSequence::fit`].
178+
///
179+
/// Predicts by summing both models' outputs: `first.predict(X) + second.predict(X)`.
180+
#[derive(Debug, Clone)]
181+
pub struct FittedResidualSequence<R1, R2> {
182+
pub first: R1,
183+
pub second: R2,
184+
}
185+
186+
impl<F1, F2, D, T, E1, E2> Fit<Arr2<D>, T, ResidualSequenceError<E1, E2>>
187+
for ResidualSequence<F1, F2>
188+
where
189+
D: Data + RawDataClone,
190+
D::Elem: Copy + Sub<Output = D::Elem>,
191+
Arr2<D>: Records,
192+
F1: Fit<Arr2<D>, T, E1>,
193+
for<'a> F1::Object: Predict<&'a Arr2<D>, Array1<D::Elem>>,
194+
F2: Fit<Arr2<D>, Array1<D::Elem>, E2>,
195+
T: AsTargets<Elem = D::Elem, Ix = Ix1>,
196+
E1: std::error::Error + From<linfa::error::Error>,
197+
E2: std::error::Error + From<linfa::error::Error>,
198+
{
199+
type Object = FittedResidualSequence<F1::Object, F2::Object>;
200+
201+
fn fit(
202+
&self,
203+
dataset: &DatasetBase<Arr2<D>, T>,
204+
) -> Result<Self::Object, ResidualSequenceError<E1, E2>> {
205+
let first = self
206+
.first
207+
.fit(dataset)
208+
.map_err(ResidualSequenceError::First)?;
209+
210+
let y_pred = first.predict(dataset.records());
211+
let residuals = dataset
212+
.targets()
213+
.as_targets()
214+
.iter()
215+
.zip(y_pred.iter())
216+
.map(|(y, p)| *y - *p)
217+
.collect::<Array1<D::Elem>>();
218+
219+
let residual_dataset = DatasetBase::new(dataset.records().clone(), residuals);
220+
let second = self
221+
.second
222+
.fit(&residual_dataset)
223+
.map_err(ResidualSequenceError::Second)?;
224+
225+
Ok(FittedResidualSequence { first, second })
226+
}
227+
}
228+
229+
impl<'a, R1, R2, D> Predict<&'a Arr2<D>, Array1<D::Elem>> for FittedResidualSequence<R1, R2>
230+
where
231+
D: Data,
232+
D::Elem: Copy + Add<Output = D::Elem>,
233+
Arr2<D>: Records,
234+
for<'b> R1: Predict<&'b Arr2<D>, Array1<D::Elem>>,
235+
for<'b> R2: Predict<&'b Arr2<D>, Array1<D::Elem>>,
236+
{
237+
fn predict(&self, x: &'a Arr2<D>) -> Array1<D::Elem> {
238+
let pred1 = self.first.predict(x);
239+
let pred2 = self.second.predict(x);
240+
pred1 + pred2
241+
}
242+
}
243+
244+
#[cfg(test)]
245+
mod tests {
246+
use super::*;
247+
use linfa::error::Error as LinfaError;
248+
use linfa::DatasetBase;
249+
use ndarray::{array, Array1, Array2};
250+
251+
#[derive(thiserror::Error, Debug)]
252+
#[error("dummy error")]
253+
struct DummyError(#[from] LinfaError);
254+
255+
// Params that fits by recording the mean of the targets.
256+
struct MeanParams;
257+
258+
// Model that predicts the mean it saw during fit.
259+
struct MeanModel(f64);
260+
261+
impl Fit<Array2<f64>, Array1<f64>, DummyError> for MeanParams {
262+
type Object = MeanModel;
263+
fn fit(
264+
&self,
265+
dataset: &DatasetBase<Array2<f64>, Array1<f64>>,
266+
) -> Result<MeanModel, DummyError> {
267+
let mean = dataset.targets().iter().sum::<f64>() / dataset.targets().len() as f64;
268+
Ok(MeanModel(mean))
269+
}
270+
}
271+
272+
impl<'a> Predict<&'a Array2<f64>, Array1<f64>> for MeanModel {
273+
fn predict(&self, x: &'a Array2<f64>) -> Array1<f64> {
274+
Array1::from_elem(x.nrows(), self.0)
275+
}
276+
}
277+
278+
#[test]
279+
fn second_is_fit_on_residuals() {
280+
// targets = [1, 3]. first sees mean=2, predicts 2 for all.
281+
// residuals = [1-2, 3-2] = [-1, 1]. second sees mean=0.
282+
let model = ResidualSequence {
283+
first: MeanParams,
284+
second: MeanParams,
285+
};
286+
let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]);
287+
let fitted = model.fit(&dataset).unwrap();
288+
289+
assert_eq!(fitted.first.0, 2.0); // mean of [1, 3]
290+
assert_eq!(fitted.second.0, 0.0); // mean of residuals [-1, 1]
291+
}
292+
293+
#[test]
294+
fn predict_sums_both_models() {
295+
// first predicts 2.0, second predicts 0.0 → sum = 2.0
296+
let model = MeanParams.stack_with(MeanParams);
297+
let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]);
298+
let fitted = model.fit(&dataset).unwrap();
299+
300+
let records = array![[0.0_f64], [1.0]];
301+
let predictions = fitted.predict(&records);
302+
assert_eq!(predictions, array![2.0, 2.0]);
303+
}
304+
305+
#[test]
306+
fn predict_recovers_targets_when_residuals_fit_perfectly() {
307+
// If second perfectly fits the residuals, the combined prediction = original targets.
308+
struct FixedParams(f64);
309+
struct FixedModel(f64);
310+
311+
impl Fit<Array2<f64>, Array1<f64>, DummyError> for FixedParams {
312+
type Object = FixedModel;
313+
fn fit(
314+
&self,
315+
_dataset: &DatasetBase<Array2<f64>, Array1<f64>>,
316+
) -> Result<FixedModel, DummyError> {
317+
Ok(FixedModel(self.0))
318+
}
319+
}
320+
321+
impl<'a> Predict<&'a Array2<f64>, Array1<f64>> for FixedModel {
322+
fn predict(&self, x: &'a Array2<f64>) -> Array1<f64> {
323+
Array1::from_elem(x.nrows(), self.0)
324+
}
325+
}
326+
327+
// first predicts 3.0, second predicts 1.0 → sum = 4.0
328+
let model = FixedParams(3.0)
329+
.stack_with(FixedParams(1.0))
330+
.stack_with(FixedParams(0.0));
331+
let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]);
332+
let fitted = model.fit(&dataset).unwrap();
333+
334+
let predictions = fitted.predict(&array![[0.0_f64], [1.0]]);
335+
assert_eq!(predictions, array![4.0, 4.0]);
336+
}
337+
}

0 commit comments

Comments
 (0)