Skip to content

Commit c971f0e

Browse files
committed
feat: add offset support to binary logistic regression
1 parent 12c6c73 commit c971f0e

3 files changed

Lines changed: 108 additions & 8 deletions

File tree

algorithms/linfa-logistic/src/error.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@ pub enum Error {
1919
InitialParameterFeaturesMismatch { rows: usize, n_features: usize },
2020
#[error("Columns of initial parameter ({cols}) must be the same as the number of classes ({n_classes})")]
2121
InitialParameterClassesMismatch { cols: usize, n_classes: usize },
22-
2322
#[error("gradient_tolerance must be a positive, finite number")]
2423
InvalidGradientTolerance,
2524
#[error("alpha must be a positive, finite number")]
2625
InvalidAlpha,
2726
#[error("Initial parameters must be finite")]
2827
InvalidInitialParameters,
28+
#[error("Offset must be finite")]
29+
InvalidOffset,
30+
#[error("Offset length ({offset_len}) must match the number of samples ({n_samples})")]
31+
OffsetLengthMismatch { offset_len: usize, n_samples: usize },
2932
}

algorithms/linfa-logistic/src/hyperparams.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use linfa::ParamGuard;
2-
use ndarray::{Array, Dimension};
2+
use ndarray::{Array, Array1, Dimension};
33

44
use crate::error::Error;
55
use crate::float::Float;
@@ -29,6 +29,7 @@ pub struct LogisticRegressionValidParams<F: Float, D: Dimension> {
2929
pub(crate) max_iterations: u64,
3030
pub(crate) gradient_tolerance: F,
3131
pub(crate) initial_params: Option<Array<F, D>>,
32+
pub(crate) offset: Option<Array1<F>>,
3233
}
3334

3435
impl<F: Float, D: Dimension> ParamGuard for LogisticRegressionParams<F, D> {
@@ -47,6 +48,11 @@ impl<F: Float, D: Dimension> ParamGuard for LogisticRegressionParams<F, D> {
4748
return Err(Error::InvalidInitialParameters);
4849
}
4950
}
51+
if let Some(ref offset) = self.0.offset {
52+
if offset.iter().any(|o| !o.is_finite()) {
53+
return Err(Error::InvalidOffset);
54+
}
55+
}
5056
Ok(&self.0)
5157
}
5258

@@ -65,6 +71,7 @@ impl<F: Float, D: Dimension> LogisticRegressionParams<F, D> {
6571
max_iterations: 100,
6672
gradient_tolerance: F::cast(1e-4),
6773
initial_params: None,
74+
offset: None,
6875
})
6976
}
7077

@@ -104,4 +111,9 @@ impl<F: Float, D: Dimension> LogisticRegressionParams<F, D> {
104111
self.0.initial_params = Some(params);
105112
self
106113
}
114+
115+
pub fn offset(mut self, offset: Array1<F>) -> Self {
116+
self.0.offset = Some(offset);
117+
self
118+
}
107119
}

algorithms/linfa-logistic/src/lib.rs

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ impl<F: Float, D: Dimension> LogisticRegressionValidParams<F, D> {
174174
x,
175175
target,
176176
alpha: self.alpha,
177+
offset: self.offset.clone(),
177178
}
178179
}
179180

@@ -232,6 +233,16 @@ impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
232233
let (x, y) = (dataset.records(), dataset.targets());
233234
let (labels, target) = label_classes(y)?;
234235
self.validate_data(x, &target)?;
236+
237+
if let Some(ref offset) = self.offset {
238+
if offset.len() != x.nrows() {
239+
return Err(Error::OffsetLengthMismatch {
240+
offset_len: offset.len(),
241+
n_samples: x.nrows(),
242+
});
243+
}
244+
}
245+
235246
let problem = self.setup_problem(x, target);
236247
let solver = self.setup_solver();
237248
let init_params = self.setup_init_params(x.ncols());
@@ -464,12 +475,19 @@ fn logistic_loss<F: Float, A: Data<Elem = F>>(
464475
y: &Array1<F>,
465476
alpha: F,
466477
w: &Array1<F>,
478+
offset: Option<&Array1<F>>,
467479
) -> F {
468480
let n_features = x.shape()[1];
469481
let (params, intercept) = convert_params(n_features, w);
470482
let yz = x.dot(&params.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
471483
let len = yz.len();
472-
let mut yz = yz.into_shape_with_order(len).unwrap() * y;
484+
let mut yz = yz.into_shape_with_order(len).unwrap();
485+
486+
if let Some(off) = offset {
487+
yz += off;
488+
}
489+
490+
yz *= y;
473491
yz.mapv_inplace(log_logistic);
474492
-yz.sum() + F::cast(0.5) * alpha * params.dot(&params)
475493
}
@@ -480,12 +498,19 @@ fn logistic_grad<F: Float, A: Data<Elem = F>>(
480498
y: &Array1<F>,
481499
alpha: F,
482500
w: &Array1<F>,
501+
offset: Option<&Array1<F>>,
483502
) -> Array1<F> {
484503
let n_features = x.shape()[1];
485504
let (params, intercept) = convert_params(n_features, w);
486505
let yz = x.dot(&params.into_shape_with_order((params.len(), 1)).unwrap()) + intercept;
487506
let len = yz.len();
488-
let mut yz = yz.into_shape_with_order(len).unwrap() * y;
507+
let mut yz = yz.into_shape_with_order(len).unwrap();
508+
509+
if let Some(off) = offset {
510+
yz += off;
511+
}
512+
513+
yz *= y;
489514
yz.mapv_inplace(logistic);
490515
yz -= F::one();
491516
yz *= y;
@@ -766,6 +791,7 @@ struct LogisticRegressionProblem<'a, F: Float, A: Data<Elem = F>, D: Dimension>
766791
x: &'a ArrayBase<A, Ix2>,
767792
target: Array<F, D>,
768793
alpha: F,
794+
offset: Option<Array1<F>>,
769795
}
770796

771797
type LogisticRegressionProblem1<'a, F, A> = LogisticRegressionProblem<'a, F, A, Ix1>;
@@ -778,7 +804,7 @@ impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem1<'_
778804
/// Apply the cost function to a parameter `p`
779805
fn cost(&self, p: &Self::Param) -> std::result::Result<Self::Output, argmin::core::Error> {
780806
let w = p.as_array();
781-
let cost = logistic_loss(self.x, &self.target, self.alpha, w);
807+
let cost = logistic_loss(self.x, &self.target, self.alpha, w, self.offset.as_ref());
782808
Ok(cost)
783809
}
784810
}
@@ -790,7 +816,13 @@ impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem1<'_, F,
790816
/// Compute the gradient at parameter `p`.
791817
fn gradient(&self, p: &Self::Param) -> std::result::Result<Self::Param, argmin::core::Error> {
792818
let w = p.as_array();
793-
let grad = ArgminParam(logistic_grad(self.x, &self.target, self.alpha, w));
819+
let grad = ArgminParam(logistic_grad(
820+
self.x,
821+
&self.target,
822+
self.alpha,
823+
w,
824+
self.offset.as_ref(),
825+
));
794826
Ok(grad)
795827
}
796828
}
@@ -906,7 +938,7 @@ mod test {
906938
.flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
907939
.zip(&expecteds)
908940
{
909-
assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w), *exp);
941+
assert_abs_diff_eq!(logistic_loss(&x, &y, alpha, w, None), *exp);
910942
}
911943
}
912944

@@ -967,7 +999,7 @@ mod test {
967999
.flat_map(|w| alphas.iter().map(move |&alpha| (w, alpha)))
9681000
.zip(&expecteds)
9691001
{
970-
let actual = logistic_grad(&x, &y, alpha, w);
1002+
let actual = logistic_grad(&x, &y, alpha, w, None);
9711003
assert!(actual.abs_diff_eq(exp, 1e-8));
9721004
}
9731005
}
@@ -1390,4 +1422,57 @@ mod test {
13901422
assert_abs_diff_eq!(model1.intercept(), model2.intercept());
13911423
assert!(model1.params().abs_diff_eq(model2.params(), 1e-6));
13921424
}
1425+
1426+
#[test]
1427+
fn rejects_mismatched_offset_length() {
1428+
let log_reg = LogisticRegression::default().offset(array![1.0, 2.0, 3.0]);
1429+
let x = array![[-1.0], [-0.01], [0.01], [1.0]];
1430+
let y = array![0, 0, 1, 1];
1431+
let res = log_reg.fit(&Dataset::new(x, y));
1432+
assert!(matches!(
1433+
res.unwrap_err(),
1434+
Error::OffsetLengthMismatch {
1435+
offset_len: 3,
1436+
n_samples: 4,
1437+
}
1438+
));
1439+
}
1440+
1441+
#[test]
1442+
fn zero_offset_same_as_no_offset() {
1443+
let x = array![[-1.0], [-0.01], [0.01], [1.0]];
1444+
let y = array![0, 0, 1, 1];
1445+
1446+
let model_none = LogisticRegression::default()
1447+
.fit(&Dataset::new(x.clone(), y.clone()))
1448+
.unwrap();
1449+
1450+
let model_zero = LogisticRegression::default()
1451+
.offset(array![0.0, 0.0, 0.0, 0.0])
1452+
.fit(&Dataset::new(x, y))
1453+
.unwrap();
1454+
1455+
assert_abs_diff_eq!(model_none.intercept(), model_zero.intercept());
1456+
assert!(model_none.params().abs_diff_eq(model_zero.params(), 1e-6));
1457+
}
1458+
1459+
#[test]
1460+
fn offset_changes_model() {
1461+
let x = array![[-1.0], [-0.01], [0.01], [1.0]];
1462+
let y = array![0, 0, 1, 1];
1463+
1464+
let model_none = LogisticRegression::default()
1465+
.fit(&Dataset::new(x.clone(), y.clone()))
1466+
.unwrap();
1467+
1468+
let model_offset = LogisticRegression::default()
1469+
.offset(array![1.0, 1.0, -1.0, -1.0])
1470+
.fit(&Dataset::new(x, y))
1471+
.unwrap();
1472+
1473+
assert!(
1474+
!model_none.params().abs_diff_eq(model_offset.params(), 1e-3),
1475+
"Offset should change the learned parameters"
1476+
);
1477+
}
13931478
}

0 commit comments

Comments (0)