@@ -9,10 +9,7 @@ use linfa::{
99
1010#[ cfg( not( feature = "blas" ) ) ]
1111use linfa_linalg:: triangular:: { SolveTriangularInplace , UPLO } ;
12- use ndarray:: {
13- Array , Array1 , Array2 , ArrayBase , ArrayView , ArrayView1 , ArrayView2 , Axis , CowArray , Data ,
14- Dimension , Ix2 , NewAxis , RemoveAxis , s,
15- } ;
12+ use ndarray:: { Array1 , Array2 , ArrayBase , ArrayView1 , ArrayView2 , Axis , Data , Ix2 , NewAxis , s} ;
1613#[ cfg( feature = "blas" ) ]
1714use ndarray_linalg:: { Diag , Lapack , SolveTriangularInplace , UPLO , layout:: MatrixLayout } ;
1815use ndarray_stats:: QuantileExt ;
3128 /// The feature matrix `x` must have shape `(n_samples, n_features)`
3229 /// The target variable `y` must have shape `(n_samples)`
3330 ///
34- /// Returns a `FittedLARS ` object which contains the fitted
31+ /// Returns a `LARS ` object which contains the fitted
3532 /// parameters and can be used to `predict` values of the target variable
3633 /// for new feature values.
3734 fn fit (
5047 F :: zero ( ) ,
5148 ) ;
5249
53- let intercept = intercept. into_scalar ( ) ;
50+ // let intercept = intercept.into_scalar();
5451
5552 let hyperplane = coef_path. slice ( s ! [ .., -1 ] ) . to_owned ( ) ;
5653
@@ -67,33 +64,23 @@ where
6764
6865/// Compute the intercept as the mean of `y` along each column and center `y` if an intercept
6966/// should be used, use 0 as intercept and leave `y` unchanged otherwise.
70- /// If `y` is 2D, mean is 1D and center is 2D. If `y` is 1D, mean is a number and center is 1D.
71- fn compute_intercept < F : Float , I : RemoveAxis > (
72- with_intercept : bool ,
73- y : ArrayView < F , I > ,
74- ) -> ( Array < F , I :: Smaller > , CowArray < F , I > )
75- where
76- I :: Smaller : Dimension < Larger = I > ,
77- {
67+ fn compute_intercept < F : Float > ( with_intercept : bool , y : ArrayView1 < F > ) -> ( F , Array1 < F > ) {
7868 if with_intercept {
7969 let y_mean = y
80- // Take the mean of each column (1D array counts as 1 column)
81- . mean_axis ( Axis ( 0 ) )
70+ // Take the mean of y
71+ . mean ( )
8272 . expect ( "Axis 0 length of 0" ) ;
83- // Subtract y_mean from each "row" of y
84- let y_centered = & y - & y_mean. view ( ) . insert_axis ( Axis ( 0 ) ) ;
85- ( y_mean, y_centered. into ( ) )
73+ // Subtract y_mean from each element of y
74+ let y_centered = & y - y_mean;
75+ ( y_mean, y_centered)
8676 } else {
87- ( Array :: zeros ( y . raw_dim ( ) . remove_axis ( Axis ( 0 ) ) ) , y. into ( ) )
77+ ( F :: zero ( ) , y. to_owned ( ) )
8878 }
8979}
9080
91- /// Compute Least Angle Regression using LARS algorithm
92- ///
93- /// References
94- /// * ["Least Angle Regression", Efron et al.](http://statweb.stanford.edu/~tibs/ftp/lars.pdf)
95- /// * [Wikipedia entry on the Least-angle regression](https://en.wikipedia.org/wiki/Least-angle_regression)
96- /// * [Wikipedia entry on the Lasso](https://en.wikipedia.org/wiki/Lasso_(statistics))
81+ /// Compute Least Angle Regression using LARS algorithm.
82+ /// Based on scikit-learn’s `lars_path` algorithm.
83+ /// https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.lars_path.html#sklearn.linear_model.lars_path
9784///
9885/// returns alphas, active, coef_path, n_iter
9986fn lars_path < F : Float > (
@@ -400,6 +387,7 @@ fn lars_path<F: Float>(
400387 let coefs_t = coefs_trimmed. t ( ) . to_owned ( ) ;
401388 ( alphas_trimmed, active, coefs_t, n_iter)
402389}
390+
403391/// Solves a linear system `A * x = b` using a Cholesky factorization.
404392///
405393/// - When compiled with the `blas` feature:
0 commit comments