Skip to content

Commit 57078d0

Browse files
committed
improve documentation
1 parent 85c28ca commit 57078d0

5 files changed

Lines changed: 37 additions & 41 deletions

File tree

algorithms/linfa-lars/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ ndarray = { version = "0.16", features = ["approx"] }
3636
ndarray-linalg = { version = "0.17", optional = true }
3737
ndarray-stats = "0.6"
3838

39-
thiserror = "1.0"
39+
thiserror = "2.0"
4040

4141
[dev-dependencies]
4242
linfa-datasets = { version = "0.8.0", path = "../../datasets", features = [
4343
"diabetes",
44+
"linnerud"
4445
] }
4546
approx = "0.5"
4647
ndarray-rand = "0.15"

algorithms/linfa-lars/examples/lars.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ fn main() {
66
let (train, valid) = linfa_datasets::diabetes().split_with_ratio(0.90);
77

88
let model = Lars::params()
9-
.fit_intercept(false)
9+
.fit_intercept(true)
1010
.verbose(2)
1111
.fit(&train)
1212
.unwrap();

algorithms/linfa-lars/src/algorithm.rs

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ use linfa::{
99

1010
#[cfg(not(feature = "blas"))]
1111
use linfa_linalg::triangular::{SolveTriangularInplace, UPLO};
12-
use ndarray::{
13-
Array, Array1, Array2, ArrayBase, ArrayView, ArrayView1, ArrayView2, Axis, CowArray, Data,
14-
Dimension, Ix2, NewAxis, RemoveAxis, s,
15-
};
12+
use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2, NewAxis, s};
1613
#[cfg(feature = "blas")]
1714
use ndarray_linalg::{Diag, Lapack, SolveTriangularInplace, UPLO, layout::MatrixLayout};
1815
use ndarray_stats::QuantileExt;
@@ -31,7 +28,7 @@ where
3128
/// The feature matrix `x` must have shape `(n_samples, n_features)`
3229
/// The target variable `y` must have shape `(n_samples)`
3330
///
34-
/// Returns a `FittedLARS` object which contains the fitted
31+
/// Returns a `LARS` object which contains the fitted
3532
/// parameters and can be used to `predict` values of the target variable
3633
/// for new feature values.
3734
fn fit(
@@ -50,7 +47,7 @@ where
5047
F::zero(),
5148
);
5249

53-
let intercept = intercept.into_scalar();
50+
// let intercept = intercept.into_scalar();
5451

5552
let hyperplane = coef_path.slice(s![.., -1]).to_owned();
5653

@@ -67,33 +64,23 @@ where
6764

6865
/// Compute the intercept as the mean of `y` along each column and center `y` if an intercept
6966
/// should be used, use 0 as intercept and leave `y` unchanged otherwise.
70-
/// If `y` is 2D, mean is 1D and center is 2D. If `y` is 1D, mean is a number and center is 1D.
71-
fn compute_intercept<F: Float, I: RemoveAxis>(
72-
with_intercept: bool,
73-
y: ArrayView<F, I>,
74-
) -> (Array<F, I::Smaller>, CowArray<F, I>)
75-
where
76-
I::Smaller: Dimension<Larger = I>,
77-
{
67+
fn compute_intercept<F: Float>(with_intercept: bool, y: ArrayView1<F>) -> (F, Array1<F>) {
7868
if with_intercept {
7969
let y_mean = y
80-
// Take the mean of each column (1D array counts as 1 column)
81-
.mean_axis(Axis(0))
70+
// Take the mean of y
71+
.mean()
8272
.expect("Axis 0 length of 0");
83-
// Subtract y_mean from each "row" of y
84-
let y_centered = &y - &y_mean.view().insert_axis(Axis(0));
85-
(y_mean, y_centered.into())
73+
// Subtract y_mean from each element of y
74+
let y_centered = &y - y_mean;
75+
(y_mean, y_centered)
8676
} else {
87-
(Array::zeros(y.raw_dim().remove_axis(Axis(0))), y.into())
77+
(F::zero(), y.to_owned())
8878
}
8979
}
9080

91-
/// Compute Least Angle Regression using LARS algorithm
92-
///
93-
/// References
94-
/// * ["Least Angle Regression", Efron et al.](http://statweb.stanford.edu/~tibs/ftp/lars.pdf)
95-
/// * [Wikipedia entry on the Least-angle regression](https://en.wikipedia.org/wiki/Least-angle_regression)
96-
/// * [Wikipedia entry on the Lasso](https://en.wikipedia.org/wiki/Lasso_(statistics))
81+
/// Compute Least Angle Regression using LARS algorithm.
82+
/// Based on scikit-learn’s `lars_path` algorithm.
83+
/// https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.lars_path.html#sklearn.linear_model.lars_path
9784
///
9885
/// returns alphas, active, coef_path, n_iter
9986
fn lars_path<F: Float>(
@@ -400,6 +387,7 @@ fn lars_path<F: Float>(
400387
let coefs_t = coefs_trimmed.t().to_owned();
401388
(alphas_trimmed, active, coefs_t, n_iter)
402389
}
390+
403391
/// Solves a linear system `A * x = b` using a Cholesky factorization.
404392
///
405393
/// - When compiled with the `blas` feature:

algorithms/linfa-lars/src/hyperparams.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub struct LarsValidParams<F> {
1313
eps: F,
1414
verbose: usize,
1515
}
16-
// #[derive(Default)]
16+
1717
pub struct LarsParams<F>(LarsValidParams<F>);
1818

1919
impl<F: Float> LarsValidParams<F> {

algorithms/linfa-lars/src/lib.rs

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
//! # Least angle regression a.k.a. LAR
2+
//!
3+
//! This struct contains the parameters of a fitted LARS model. This includes the seperating
4+
//! hyperplane, (optionally) intercept, alphas (Maximum of covariances (in absolute value) at each iteration),
5+
//! Indices of active variables at the end of the path,
6+
//!
7+
//!
8+
//! LARS is similar to forward stepwise regression.
9+
//! At each step, it finds the feature most correlated with the target.
10+
//! When there are multiple features having equal correlation, instead of continuing along the same feature,
11+
//! it proceeds in a direction equiangular between the features.
12+
//!
13+
//! ## References
14+
//!
15+
//! * ["Least Angle Regression", Efron et al.](https://web.stanford.edu/~hastie/Papers/LARS/LeastAngle_2002.pdf)
16+
//! * [Wikipedia entry on the Least-angle regression](https://en.wikipedia.org/wiki/Least-angle_regression)
17+
//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/linear_model.html#least-angle-regression)
18+
119
use linfa::{Float, traits::PredictInplace};
220
use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
321

@@ -13,17 +31,6 @@ mod hyperparams;
1331
derive(Serialize, Deserialize),
1432
serde(crate = "serde_crate")
1533
)]
16-
/// Least angle regression a.k.a. LAR
17-
///
18-
/// This struct contains the parameters of a fitted LARS model. This includes the seperating
19-
/// hyperplane, (optionally) intercept, alphas (Maximum of covariances (in absolute value) at each iteration),
20-
/// Indices of active variables at the end of the path,
21-
///
22-
///
23-
/// LARS is similar to forward stepwise regression.
24-
/// At each step, it finds the feature most correlated with the target.
25-
/// When there are multiple features having equal correlation, instead of continuing along the same feature,
26-
/// it proceeds in a direction equiangular between the features.
2734
#[derive(Debug, Clone)]
2835
pub struct Lars<F> {
2936
hyperplane: Array1<F>,
@@ -38,7 +45,7 @@ impl<F: Float> Lars<F> {
3845
/// Create default Lars hyper parameters
3946
///
4047
/// By default, an intercept will be fitted. To disable fitting an
41-
/// intercept, call `.with_intercept(false)` before calling `.fit()`.
48+
/// intercept, call `.fit_intercept(false)` before calling `.fit()`.
4249
///
4350
/// The feature matrix will not be normalized by default.
4451
pub fn params() -> LarsParams<F> {

0 commit comments

Comments
 (0)