From c971f0e1bef962d387f1bead326fe9b0582eafd1 Mon Sep 17 00:00:00 2001 From: espenblyn Date: Thu, 19 Mar 2026 15:42:27 +0100 Subject: [PATCH] feat: add offset support to binary logistic regression --- algorithms/linfa-logistic/src/error.rs | 5 +- algorithms/linfa-logistic/src/hyperparams.rs | 14 ++- algorithms/linfa-logistic/src/lib.rs | 97 ++++++++++++++++++-- 3 files changed, 108 insertions(+), 8 deletions(-) diff --git a/algorithms/linfa-logistic/src/error.rs b/algorithms/linfa-logistic/src/error.rs index c4a45659e..4eca04e21 100644 --- a/algorithms/linfa-logistic/src/error.rs +++ b/algorithms/linfa-logistic/src/error.rs @@ -19,11 +19,14 @@ pub enum Error { InitialParameterFeaturesMismatch { rows: usize, n_features: usize }, #[error("Columns of initial parameter ({cols}) must be the same as the number of classes ({n_classes})")] InitialParameterClassesMismatch { cols: usize, n_classes: usize }, - #[error("gradient_tolerance must be a positive, finite number")] InvalidGradientTolerance, #[error("alpha must be a positive, finite number")] InvalidAlpha, #[error("Initial parameters must be finite")] InvalidInitialParameters, + #[error("Offset must be finite")] + InvalidOffset, + #[error("Offset length ({offset_len}) must match the number of samples ({n_samples})")] + OffsetLengthMismatch { offset_len: usize, n_samples: usize }, } diff --git a/algorithms/linfa-logistic/src/hyperparams.rs b/algorithms/linfa-logistic/src/hyperparams.rs index 32c072477..52ea52066 100644 --- a/algorithms/linfa-logistic/src/hyperparams.rs +++ b/algorithms/linfa-logistic/src/hyperparams.rs @@ -1,5 +1,5 @@ use linfa::ParamGuard; -use ndarray::{Array, Dimension}; +use ndarray::{Array, Array1, Dimension}; use crate::error::Error; use crate::float::Float; @@ -29,6 +29,7 @@ pub struct LogisticRegressionValidParams { pub(crate) max_iterations: u64, pub(crate) gradient_tolerance: F, pub(crate) initial_params: Option>, + pub(crate) offset: Option>, } impl ParamGuard for LogisticRegressionParams { @@ -47,6 +48,11 @@ impl ParamGuard for LogisticRegressionParams { return Err(Error::InvalidInitialParameters); } } + if let Some(ref offset) = self.0.offset { + if offset.iter().any(|o| !o.is_finite()) { + return Err(Error::InvalidOffset); + } + } Ok(&self.0) } @@ -65,6 +71,7 @@ impl LogisticRegressionParams { max_iterations: 100, gradient_tolerance: F::cast(1e-4), initial_params: None, + offset: None, }) } @@ -104,4 +111,9 @@ impl LogisticRegressionParams { self.0.initial_params = Some(params); self } + + pub fn offset(mut self, offset: Array1) -> Self { + self.0.offset = Some(offset); + self + } } diff --git a/algorithms/linfa-logistic/src/lib.rs b/algorithms/linfa-logistic/src/lib.rs index ea2386d91..36fe09cf0 100644 --- a/algorithms/linfa-logistic/src/lib.rs +++ b/algorithms/linfa-logistic/src/lib.rs @@ -174,6 +174,7 @@ impl LogisticRegressionValidParams { x, target, alpha: self.alpha, + offset: self.offset.clone(), } } @@ -232,6 +233,16 @@ impl, T: AsSingleTargets> let (x, y) = (dataset.records(), dataset.targets()); let (labels, target) = label_classes(y)?; self.validate_data(x, &target)?; + + if let Some(ref offset) = self.offset { + if offset.len() != x.nrows() { + return Err(Error::OffsetLengthMismatch { + offset_len: offset.len(), + n_samples: x.nrows(), + }); + } + } + let problem = self.setup_problem(x, target); let solver = self.setup_solver(); let init_params = self.setup_init_params(x.ncols()); @@ -464,12 +475,19 @@ fn logistic_loss>( y: &Array1, alpha: F, w: &Array1, + offset: Option<&Array1>, ) -> F { let n_features = x.shape()[1]; let (params, intercept) = convert_params(n_features, w); let yz = x.dot(¶ms.into_shape_with_order((params.len(), 1)).unwrap()) + intercept; let len = yz.len(); - let mut yz = yz.into_shape_with_order(len).unwrap() * y; + let mut yz = yz.into_shape_with_order(len).unwrap(); + + if let Some(off) = offset { + yz += off; + } + + yz *= y; yz.mapv_inplace(log_logistic); -yz.sum() + F::cast(0.5) * alpha * params.dot(¶ms) } @@ -480,12 +498,19 @@ fn logistic_grad>( y: &Array1, alpha: F, w: &Array1, + offset: Option<&Array1>, ) -> Array1 { let n_features = x.shape()[1]; let (params, intercept) = convert_params(n_features, w); let yz = x.dot(¶ms.into_shape_with_order((params.len(), 1)).unwrap()) + intercept; let len = yz.len(); - let mut yz = yz.into_shape_with_order(len).unwrap() * y; + let mut yz = yz.into_shape_with_order(len).unwrap(); + + if let Some(off) = offset { + yz += off; + } + + yz *= y; yz.mapv_inplace(logistic); yz -= F::one(); yz *= y; @@ -766,6 +791,7 @@ struct LogisticRegressionProblem<'a, F: Float, A: Data, D: Dimension> x: &'a ArrayBase, target: Array, alpha: F, + offset: Option>, } type LogisticRegressionProblem1<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix1>; @@ -778,7 +804,7 @@ impl> CostFunction for LogisticRegressionProblem1<'_ /// Apply the cost function to a parameter `p` fn cost(&self, p: &Self::Param) -> std::result::Result { let w = p.as_array(); - let cost = logistic_loss(self.x, &self.target, self.alpha, w); + let cost = logistic_loss(self.x, &self.target, self.alpha, w, self.offset.as_ref()); Ok(cost) } } @@ -790,7 +816,13 @@ impl> Gradient for LogisticRegressionProblem1<'_, F, /// Compute the gradient at parameter `p`. fn gradient(&self, p: &Self::Param) -> std::result::Result { let w = p.as_array(); - let grad = ArgminParam(logistic_grad(self.x, &self.target, self.alpha, w)); + let grad = ArgminParam(logistic_grad( + self.x, + &self.target, + self.alpha, + w, + self.offset.as_ref(), + )); Ok(grad) } } @@ -906,7 +938,7 @@ mod test { .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha))) .zip(&expecteds) { - assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w), *exp); + assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w, None), *exp); } } @@ -967,7 +999,7 @@ mod test { .flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha))) .zip(&expecteds) { - let actual = logistic_grad(&x, &y, alpha, w); + let actual = logistic_grad(&x, &y, alpha, w, None); assert!(actual.abs_diff_eq(exp, 1e-8)); } } @@ -1390,4 +1422,57 @@ mod test { assert_abs_diff_eq!(model1.intercept(), model2.intercept()); assert!(model1.params().abs_diff_eq(model2.params(), 1e-6)); } + + #[test] + fn rejects_mismatched_offset_length() { + let log_reg = LogisticRegression::default().offset(array![1.0, 2.0, 3.0]); + let x = array![[-1.0], [-0.01], [0.01], [1.0]]; + let y = array![0, 0, 1, 1]; + let res = log_reg.fit(&Dataset::new(x, y)); + assert!(matches!( + res.unwrap_err(), + Error::OffsetLengthMismatch { + offset_len: 3, + n_samples: 4, + } + )); + } + + #[test] + fn zero_offset_same_as_no_offset() { + let x = array![[-1.0], [-0.01], [0.01], [1.0]]; + let y = array![0, 0, 1, 1]; + + let model_none = LogisticRegression::default() + .fit(&Dataset::new(x.clone(), y.clone())) + .unwrap(); + + let model_zero = LogisticRegression::default() + .offset(array![0.0, 0.0, 0.0, 0.0]) + .fit(&Dataset::new(x, y)) + .unwrap(); + + assert_abs_diff_eq!(model_none.intercept(), model_zero.intercept()); + assert!(model_none.params().abs_diff_eq(model_zero.params(), 1e-6)); + } + + #[test] + fn offset_changes_model() { + let x = array![[-1.0], [-0.01], [0.01], [1.0]]; + let y = array![0, 0, 1, 1]; + + let model_none = LogisticRegression::default() + .fit(&Dataset::new(x.clone(), y.clone())) + .unwrap(); + + let model_offset = LogisticRegression::default() + .offset(array![1.0, 1.0, -1.0, -1.0]) + .fit(&Dataset::new(x, y)) + .unwrap(); + + assert!( + !model_none.params().abs_diff_eq(model_offset.params(), 1e-3), + "Offset should change the learned parameters" + ); + } }