// NOTE(review): this file is a scraped unified-diff rendering of a Rust source
// (commit page); the leading digits on each line are fused old/new line
// numbers, not code. Comments below annotate the change for review.
@@ -174,6 +174,7 @@ impl<F: Float, D: Dimension> LogisticRegressionValidParams<F, D> {
// Hunk: setup_problem() now also copies the optional per-sample offset into
// the problem struct. clone() is required because `self` is only borrowed and
// the problem owns its offset (Option<Array1<F>>).
174174 x,
175175 target,
176176 alpha : self . alpha ,
177+ offset : self . offset . clone ( ) ,
177178 }
178179 }
179180
@@ -232,6 +233,16 @@ impl<C: Ord + Clone, F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = C>>
// Hunk: fit() gains an early validation that a user-supplied offset has
// exactly one entry per sample, returning a structured error instead of
// letting ndarray panic (or mis-broadcast) later inside the loss/gradient.
232233 let ( x, y) = ( dataset. records ( ) , dataset. targets ( ) ) ;
233234 let ( labels, target) = label_classes ( y) ?;
234235 self . validate_data ( x, & target) ?;
236+
// Length check: offset.len() must equal the number of rows (samples) in x.
// NOTE(review): assumes Error::OffsetLengthMismatch { offset_len, n_samples }
// exists in this crate's Error enum — the variant definition is outside this
// diff view; confirm it was added in the same commit.
237+ if let Some ( ref offset) = self . offset {
238+ if offset. len ( ) != x. nrows ( ) {
239+ return Err ( Error :: OffsetLengthMismatch {
240+ offset_len : offset. len ( ) ,
241+ n_samples : x. nrows ( ) ,
242+ } ) ;
243+ }
244+ }
245+
235246 let problem = self . setup_problem ( x, target) ;
236247 let solver = self . setup_solver ( ) ;
237248 let init_params = self . setup_init_params ( x. ncols ( ) ) ;
@@ -464,12 +475,19 @@ fn logistic_loss<F: Float, A: Data<Elem = F>>(
// Hunk: logistic_loss gains an `offset` parameter (Option<&Array1<F>>).
// The offset is added to the linear predictor (x·w + intercept) BEFORE the
// element-wise multiplication by the labels y — standard GLM offset
// semantics (a fixed, known additive term in the linear predictor).
464475 y : & Array1 < F > ,
465476 alpha : F ,
466477 w : & Array1 < F > ,
478+ offset : Option < & Array1 < F > > ,
467479) -> F {
468480 let n_features = x. shape ( ) [ 1 ] ;
469481 let ( params, intercept) = convert_params ( n_features, w) ;
470482 let yz = x. dot ( & params. into_shape_with_order ( ( params. len ( ) , 1 ) ) . unwrap ( ) ) + intercept;
471483 let len = yz. len ( ) ;
// Old code multiplied by y in the same expression as the reshape; the split
// lets the offset be folded into the predictor before the sign flip by y.
472- let mut yz = yz. into_shape_with_order ( len) . unwrap ( ) * y;
484+ let mut yz = yz. into_shape_with_order ( len) . unwrap ( ) ;
485+
486+ if let Some ( off) = offset {
487+ yz += off;
488+ }
489+
490+ yz *= y;
473491 yz. mapv_inplace ( log_logistic) ;
// Correct: the L2 penalty (0.5 * alpha * ||params||^2) applies only to the
// learned weights, never to the fixed offset or the intercept.
474492 -yz. sum ( ) + F :: cast ( 0.5 ) * alpha * params. dot ( & params)
475493}
@@ -480,12 +498,19 @@ fn logistic_grad<F: Float, A: Data<Elem = F>>(
// Hunk: logistic_grad gains the same `offset` parameter, applied identically
// to logistic_loss (added to the predictor before multiplying by y), keeping
// the gradient mathematically consistent with the modified loss.
480498 y : & Array1 < F > ,
481499 alpha : F ,
482500 w : & Array1 < F > ,
501+ offset : Option < & Array1 < F > > ,
483502) -> Array1 < F > {
484503 let n_features = x. shape ( ) [ 1 ] ;
485504 let ( params, intercept) = convert_params ( n_features, w) ;
486505 let yz = x. dot ( & params. into_shape_with_order ( ( params. len ( ) , 1 ) ) . unwrap ( ) ) + intercept;
487506 let len = yz. len ( ) ;
488- let mut yz = yz. into_shape_with_order ( len) . unwrap ( ) * y;
507+ let mut yz = yz. into_shape_with_order ( len) . unwrap ( ) ;
508+
509+ if let Some ( off) = offset {
510+ yz += off;
511+ }
512+
513+ yz *= y;
489514 yz. mapv_inplace ( logistic) ;
490515 yz -= F :: one ( ) ;
// The function's tail (projection of yz back through x, plus the alpha
// regularization term on the weights) is cut off by the hunk boundary —
// presumably unchanged by this commit; confirm in the full file.
491516 yz *= y;
@@ -766,6 +791,7 @@ struct LogisticRegressionProblem<'a, F: Float, A: Data<Elem = F>, D: Dimension>
// Hunk: the optimization-problem struct stores the optional owned offset
// alongside the borrowed records, target, and alpha. It is Option<Array1<F>>
// (owned), matching the clone made in setup_problem().
766791 x : & ' a ArrayBase < A , Ix2 > ,
767792 target : Array < F , D > ,
768793 alpha : F ,
794+ offset : Option < Array1 < F > > ,
769795}
770796
// Ix1 alias = binary problem; an Ix2 (multinomial) alias presumably follows
// just past this view.
771797type LogisticRegressionProblem1 < ' a , F , A > = LogisticRegressionProblem < ' a , F , A , Ix1 > ;
@@ -778,7 +804,7 @@ impl<F: Float, A: Data<Elem = F>> CostFunction for LogisticRegressionProblem1<'_
// Hunk: the binary (Ix1) cost function now forwards the stored offset to
// logistic_loss via Option::as_ref (Option<Array1<F>> -> Option<&Array1<F>>).
// NOTE(review): only the Problem1 (binary) impl is updated in this diff — if
// a multinomial Problem2 impl exists elsewhere in the file, it may silently
// ignore the offset; worth confirming before merging.
778804 /// Apply the cost function to a parameter `p`
779805 fn cost ( & self , p : & Self :: Param ) -> std:: result:: Result < Self :: Output , argmin:: core:: Error > {
780806 let w = p. as_array ( ) ;
781- let cost = logistic_loss ( self . x , & self . target , self . alpha , w) ;
807+ let cost = logistic_loss ( self . x , & self . target , self . alpha , w, self . offset . as_ref ( ) ) ;
782808 Ok ( cost)
783809 }
784810}
@@ -790,7 +816,13 @@ impl<F: Float, A: Data<Elem = F>> Gradient for LogisticRegressionProblem1<'_, F,
// Hunk: the binary gradient mirrors the cost change, forwarding the offset to
// logistic_grad. The call is reformatted to one argument per line (rustfmt).
790816 /// Compute the gradient at parameter `p`.
791817 fn gradient ( & self , p : & Self :: Param ) -> std:: result:: Result < Self :: Param , argmin:: core:: Error > {
792818 let w = p. as_array ( ) ;
793- let grad = ArgminParam ( logistic_grad ( self . x , & self . target , self . alpha , w) ) ;
819+ let grad = ArgminParam ( logistic_grad (
820+ self . x ,
821+ & self . target ,
822+ self . alpha ,
823+ w,
824+ self . offset . as_ref ( ) ,
825+ ) ) ;
794826 Ok ( grad)
795827 }
796828}
@@ -906,7 +938,7 @@ mod test {
// Hunk: existing loss tests updated for the new signature by passing
// offset = None; expected values are unchanged, so this also documents that
// None behaves exactly like the pre-change code path.
906938 . flat_map ( |w| alphas. iter ( ) . map ( move |& alpha| ( w, alpha) ) )
907939 . zip ( & expecteds)
908940 {
909- assert_abs_diff_eq ! ( logistic_loss( & x, & y, alpha, w) , * exp) ;
941+ assert_abs_diff_eq ! ( logistic_loss( & x, & y, alpha, w, None ) , * exp) ;
910942 }
911943 }
912944
912944
@@ -967,7 +999,7 @@ mod test {
967999 . flat_map ( |w| alphas. iter ( ) . map ( move |& alpha| ( w, alpha) ) )
9681000 . zip ( & expecteds)
9691001 {
970- let actual = logistic_grad ( & x, & y, alpha, w) ;
1002+ let actual = logistic_grad ( & x, & y, alpha, w, None ) ;
9711003 assert ! ( actual. abs_diff_eq( exp, 1e-8 ) ) ;
9721004 }
9731005 }
@@ -1390,4 +1422,57 @@ mod test {
13901422 assert_abs_diff_eq ! ( model1. intercept( ) , model2. intercept( ) ) ;
13911423 assert ! ( model1. params( ) . abs_diff_eq( model2. params( ) , 1e-6 ) ) ;
13921424 }
1425+
// Test 1: a 3-element offset against 4 samples must fail with the new
// structured error, and the error must carry both lengths. Relies on
// matches! + a Debug impl for Error (unwrap_err requires Debug).
1426+ #[ test]
1427+ fn rejects_mismatched_offset_length ( ) {
1428+ let log_reg = LogisticRegression :: default ( ) . offset ( array ! [ 1.0 , 2.0 , 3.0 ] ) ;
1429+ let x = array ! [ [ -1.0 ] , [ -0.01 ] , [ 0.01 ] , [ 1.0 ] ] ;
1430+ let y = array ! [ 0 , 0 , 1 , 1 ] ;
1431+ let res = log_reg. fit ( & Dataset :: new ( x, y) ) ;
1432+ assert ! ( matches!(
1433+ res. unwrap_err( ) ,
1434+ Error :: OffsetLengthMismatch {
1435+ offset_len: 3 ,
1436+ n_samples: 4 ,
1437+ }
1438+ ) ) ;
1439+ }
1440+
// Test 2: an all-zero offset must be a no-op — fitted intercept and params
// match a fit with no offset at all (to 1e-6 on the params).
1441+ #[ test]
1442+ fn zero_offset_same_as_no_offset ( ) {
1443+ let x = array ! [ [ -1.0 ] , [ -0.01 ] , [ 0.01 ] , [ 1.0 ] ] ;
1444+ let y = array ! [ 0 , 0 , 1 , 1 ] ;
1445+
1446+ let model_none = LogisticRegression :: default ( )
1447+ . fit ( & Dataset :: new ( x. clone ( ) , y. clone ( ) ) )
1448+ . unwrap ( ) ;
1449+
1450+ let model_zero = LogisticRegression :: default ( )
1451+ . offset ( array ! [ 0.0 , 0.0 , 0.0 , 0.0 ] )
1452+ . fit ( & Dataset :: new ( x, y) )
1453+ . unwrap ( ) ;
1454+
1455+ assert_abs_diff_eq ! ( model_none. intercept( ) , model_zero. intercept( ) ) ;
1456+ assert ! ( model_none. params( ) . abs_diff_eq( model_zero. params( ) , 1e-6 ) ) ;
1457+ }
1458+
// Test 3: a non-zero offset (aligned against the labels here) must actually
// move the learned parameters — guards against the offset being plumbed
// through but silently dropped. Only params are compared (negated
// abs_diff_eq at 1e-3); the intercept is deliberately left unchecked.
1459+ #[ test]
1460+ fn offset_changes_model ( ) {
1461+ let x = array ! [ [ -1.0 ] , [ -0.01 ] , [ 0.01 ] , [ 1.0 ] ] ;
1462+ let y = array ! [ 0 , 0 , 1 , 1 ] ;
1463+
1464+ let model_none = LogisticRegression :: default ( )
1465+ . fit ( & Dataset :: new ( x. clone ( ) , y. clone ( ) ) )
1466+ . unwrap ( ) ;
1467+
1468+ let model_offset = LogisticRegression :: default ( )
1469+ . offset ( array ! [ 1.0 , 1.0 , -1.0 , -1.0 ] )
1470+ . fit ( & Dataset :: new ( x, y) )
1471+ . unwrap ( ) ;
1472+
1473+ assert ! (
1474+ !model_none. params( ) . abs_diff_eq( model_offset. params( ) , 1e-3 ) ,
1475+ "Offset should change the learned parameters"
1476+ ) ;
1477+ }
13931478}
0 commit comments