Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/api/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

We define a number of functions on [linear operators](./operators.md).

## Transformations

These functions transform an operator to a new one (e.g. representing its inverse or column-space projection).

::: lineax.invert

---

::: lineax.project

## Computational changes

These do not change the mathematical meaning of the operator; they simply change how it is stored computationally. (E.g. to materialise the whole operator.)
Expand Down
4 changes: 1 addition & 3 deletions docs/api/linear_solve.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,4 @@ This is the main entry point.

## invert

A convenience function for obtaining the inverse of an operator as a [`lineax.FunctionLinearOperator`][].

::: lineax.invert
A convenience function for obtaining the inverse of an operator as a [`lineax.FunctionLinearOperator`][]; see [`lineax.invert`][].
262 changes: 262 additions & 0 deletions docs/examples/variable_projection.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions lineax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@
tridiagonal as tridiagonal,
TridiagonalLinearOperator as TridiagonalLinearOperator,
)
from ._project import project as project
from ._solution import RESULTS as RESULTS, Solution as Solution
from ._solve import (
AbstractLinearSolver as AbstractLinearSolver,
AutoLinearSolver as AutoLinearSolver,
invert as invert,
linear_solve as linear_solve,
projection_mv as projection_mv,
)
from ._solver import (
BiCGStab as BiCGStab,
Expand All @@ -76,6 +78,8 @@
MaxRankTag as MaxRankTag,
negative_semidefinite_tag as negative_semidefinite_tag,
positive_semidefinite_tag as positive_semidefinite_tag,
project_tags as project_tags,
project_tags_rules as project_tags_rules,
symmetric_tag as symmetric_tag,
tags_from_checks as tags_from_checks,
transpose_tags as transpose_tags,
Expand Down
11 changes: 11 additions & 0 deletions lineax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import equinox as eqx
import jax
import jax.core
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, ArrayLike, Bool, PyTree # pyright:ignore
Expand All @@ -27,6 +28,16 @@ def tree_where(
return jtu.tree_map(keep, true, false)


def to_shapedarray(x):
"""Convert a `jax.ShapeDtypeStruct` leaf to a `jax.core.ShapedArray` (the abstract
value a primitive's `abstract_eval` rule must return); pass other leaves through.
"""
if isinstance(x, jax.ShapeDtypeStruct):
return jax.core.ShapedArray(x.shape, x.dtype)
else:
return x


def resolve_rcond(rcond, n, m, dtype):
if rcond is None:
# This `2 *` is a heuristic: I have seen very rare failures without it, in ways
Expand Down
Loading
Loading