From bdde59cb8f6e4da7fe0c9d8d6a0400b4b05c12de Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Thu, 18 Jun 2026 10:29:47 +0100 Subject: [PATCH 1/4] refactor: move AbstractLinearOperator to separate file to avoid circular imports --- lineax/_solve.py | 239 +++++----------------------------- lineax/_solver/bicgstab.py | 2 +- lineax/_solver/cg.py | 2 +- lineax/_solver/cholesky.py | 2 +- lineax/_solver/diagonal.py | 2 +- lineax/_solver/gmres.py | 4 +- lineax/_solver/lsmr.py | 2 +- lineax/_solver/lu.py | 2 +- lineax/_solver/normal.py | 10 +- lineax/_solver/qr.py | 2 +- lineax/_solver/svd.py | 2 +- lineax/_solver/triangular.py | 2 +- lineax/_solver/tridiagonal.py | 2 +- 13 files changed, 56 insertions(+), 217 deletions(-) diff --git a/lineax/_solve.py b/lineax/_solve.py index 817d710a..f7ed5c71 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import functools as ft -from typing import Any, Generic, TypeAlias, TypeVar +from typing import Any, TypeAlias import equinox as eqx import equinox.internal as eqxi @@ -46,6 +45,16 @@ TangentLinearOperator, ) from ._solution import RESULTS, Solution +from ._solver import ( + Cholesky, + Diagonal, + LU, + QR, + SVD, + Triangular, + Tridiagonal, +) +from ._solver.base import AbstractLinearSolver as AbstractLinearSolver from ._tags import ( invert_tags, tags_from_checks, @@ -331,149 +340,6 @@ def _linear_solve_transpose(inputs, cts_out): # -_SolverState = TypeVar("_SolverState") - - -class AbstractLinearSolver(eqx.Module, Generic[_SolverState]): - """Abstract base class for all linear solvers.""" - - @abc.abstractmethod - def init( - self, operator: AbstractLinearOperator, options: dict[str, Any] - ) -> _SolverState: - """Do any initial computation on just the `operator`. - - For example, an LU solver would compute the LU decomposition of the operator - (and this does not require knowing the vector yet). - - It is common to need to solve the linear system `Ax=b` multiple times in - succession, with the same operator `A` and multiple vectors `b`. This method - improves efficiency by making it possible to re-use the computation performed - on just the operator. - - !!! Example - - ```python - operator = lx.MatrixLinearOperator(...) - vector1 = ... - vector2 = ... - solver = lx.LU() - state = solver.init(operator, options={}) - solution1 = lx.linear_solve(operator, vector1, solver, state=state) - solution2 = lx.linear_solve(operator, vector2, solver, state=state) - ``` - - **Arguments:** - - - `operator`: a linear operator. - - `options`: a dictionary of any extra options that the solver may wish to - accept. - - **Returns:** - - A PyTree of arbitrary Python objects. - """ - - @abc.abstractmethod - def compute( - self, state: _SolverState, vector: PyTree[Array], options: dict[str, Any] - ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: - """Solves a linear system. - - **Arguments:** - - - `state`: as returned from [`lineax.AbstractLinearSolver.init`][]. - - `vector`: the vector to solve against. - - `options`: a dictionary of any extra options that the solver may wish to - accept. For example, [`lineax.CG`][] accepts a `preconditioner` option. - - **Returns:** - - A 3-tuple of: - - - The solution to the linear system. - - An integer indicating the success or failure of the solve. This is an integer - which may be converted to a human-readable error message via - `lx.RESULTS[...]`. - - A dictionary of an extra statistics about the solve, e.g. the number of steps - taken. - """ - - @abc.abstractmethod - def transpose( - self, state: _SolverState, options: dict[str, Any] - ) -> tuple[_SolverState, dict[str, Any]]: - """Transposes the result of [`lineax.AbstractLinearSolver.init`][]. - - That is, it should be the case that - ```python - state_transpose, _ = solver.transpose(solver.init(operator, options), options) - state_transpose2 = solver.init(operator.T, options) - ``` - must be identical to each other. - - It is relatively common (in particular when differentiating through a linear - solve) to need to solve both `Ax = b` and `A^T x = b`. This method makes it - possible to avoid computing both `solver.init(operator)` and - `solver.init(operator.T)` if one can be cheaply computed from the other. - - **Arguments:** - - - `state`: as returned from `solver.init`. - - `options`: any extra options that were passed to `solve.init`. - - **Returns:** - - A 2-tuple of: - - - The state of the transposed operator. - - The options for the transposed operator. - """ - - @abc.abstractmethod - def conj( - self, state: _SolverState, options: dict[str, Any] - ) -> tuple[_SolverState, dict[str, Any]]: - """Conjugate the result of [`lineax.AbstractLinearSolver.init`][]. - - That is, it should be the case that - ```python - state_conj, _ = solver.conj(solver.init(operator, options), options) - state_conj2 = solver.init(conj(operator), options) - ``` - must be identical to each other. - - **Arguments:** - - - `state`: as returned from `solver.init`. - - `options`: any extra options that were passed to `solve.init`. - - **Returns:** - - A 2-tuple of: - - - The state of the conjugated operator. - - The options for the conjugated operator. - """ - - @abc.abstractmethod - def assume_full_rank(self) -> bool: - """Does this solver assume that all operators are full rank? - - When `False`, a more expensive backward pass is needed to account for - the extra generality. In a custom linear solver, it is always safe to - return False. - - **Arguments:** - - Nothing. - - **Returns:** - - Either `True` or `False`. - """ - - def _check_rank_compat( solver: "AbstractLinearSolver", operator: AbstractLinearOperator ): @@ -490,39 +356,7 @@ def _check_rank_compat( ) -_qr_token = eqxi.str2jax("qr_token") -_diagonal_token = eqxi.str2jax("diagonal_token") -_well_posed_diagonal_token = eqxi.str2jax("well_posed_diagonal_token") -_tridiagonal_token = eqxi.str2jax("tridiagonal_token") -_triangular_token = eqxi.str2jax("triangular_token") -_cholesky_token = eqxi.str2jax("cholesky_token") -_lu_token = eqxi.str2jax("lu_token") -_svd_token = eqxi.str2jax("svd_token") - - -# Ugly delayed import because we have the dependency chain -# linear_solve -> AutoLinearSolver -> {Cholesky,...} -> AbstractLinearSolver -# but we want linear_solver and AbstractLinearSolver in the same file. -def _lookup(token) -> AbstractLinearSolver: - from . import _solver - - # pyright doesn't know that these keys are hashable - _lookup_dict = { - _qr_token: _solver.QR(), # pyright: ignore - _diagonal_token: _solver.Diagonal(), # pyright: ignore - _well_posed_diagonal_token: _solver.Diagonal( # pyright: ignore - well_posed=True - ), - _tridiagonal_token: _solver.Tridiagonal(), # pyright: ignore - _triangular_token: _solver.Triangular(), # pyright: ignore - _cholesky_token: _solver.Cholesky(), # pyright: ignore - _lu_token: _solver.LU(), # pyright: ignore - _svd_token: _solver.SVD(), # pyright: ignore - } - return _lookup_dict[token] - - -_AutoLinearSolverState: TypeAlias = tuple[Any, Any] +_AutoLinearSolverState: TypeAlias = tuple[AbstractLinearSolver, Any] class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): @@ -562,7 +396,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): well_posed: bool | None - def _select_solver(self, operator: AbstractLinearOperator): + def _select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: if self.well_posed is True: if operator.in_size() != operator.out_size(): raise ValueError( @@ -573,41 +407,41 @@ def _select_solver(self, operator: AbstractLinearOperator): "square and nonsingular." ) if is_diagonal(operator): - token = _well_posed_diagonal_token + solver = Diagonal(well_posed=True) elif is_tridiagonal(operator): - token = _tridiagonal_token + solver = Tridiagonal() elif is_lower_triangular(operator) or is_upper_triangular(operator): - token = _triangular_token + solver = Triangular() elif is_positive_semidefinite(operator) or is_negative_semidefinite( operator ): - token = _cholesky_token + solver = Cholesky() else: - token = _lu_token + solver = LU() elif self.well_posed is False: if is_diagonal(operator): - token = _diagonal_token + solver = Diagonal() else: # TODO: use rank-revealing QR instead. - token = _svd_token + solver = SVD() elif self.well_posed is None: if operator.in_size() != operator.out_size(): - token = _qr_token + solver = QR() elif is_diagonal(operator): - token = _diagonal_token + solver = Diagonal() elif is_tridiagonal(operator): - token = _tridiagonal_token + solver = Tridiagonal() elif is_lower_triangular(operator) or is_upper_triangular(operator): - token = _triangular_token + solver = Triangular() elif is_positive_semidefinite(operator) or is_negative_semidefinite( operator ): - token = _cholesky_token + solver = Cholesky() else: - token = _lu_token + solver = LU() else: raise ValueError(f"Invalid value `well_posed={self.well_posed}`.") - return token + return solver def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: """Check which solver that [`lineax.AutoLinearSolver`][] will dispatch to. @@ -620,11 +454,11 @@ def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolve The linear solver that will be used. """ - return _lookup(self._select_solver(operator)) + return self._select_solver(operator) def init(self, operator, options) -> _AutoLinearSolverState: - token = self._select_solver(operator) - return token, _lookup(token).init(operator, options) + solver = self._select_solver(operator) + return solver, solver.init(operator, options) def compute( self, @@ -632,23 +466,20 @@ def compute( vector: PyTree[Array], options: dict[str, Any], ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: - token, state = state - solver = _lookup(token) + solver, state = state solution, result, _ = solver.compute(state, vector, options) return solution, result, {} def transpose(self, state: _AutoLinearSolverState, options: dict[str, Any]): - token, state = state - solver = _lookup(token) + solver, state = state transpose_state, transpose_options = solver.transpose(state, options) - transpose_state = (token, transpose_state) + transpose_state = (solver, transpose_state) return transpose_state, transpose_options def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]): - token, state = state - solver = _lookup(token) + solver, state = state conj_state, conj_options = solver.conj(state, options) - conj_state = (token, conj_state) + conj_state = (solver, conj_state) return conj_state, conj_options def assume_full_rank(self): diff --git a/lineax/_solver/bicgstab.py b/lineax/_solver/bicgstab.py index 31cf241d..65ce7842 100644 --- a/lineax/_solver/bicgstab.py +++ b/lineax/_solver/bicgstab.py @@ -25,7 +25,7 @@ from .._norm import max_norm, tree_dot from .._operator import AbstractLinearOperator, conj, linearise from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import preconditioner_and_y0 diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index eee8c970..cd9655af 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -34,7 +34,7 @@ linearise, ) from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import preconditioner_and_y0 from .normal import Normal diff --git a/lineax/_solver/cholesky.py b/lineax/_solver/cholesky.py index 852ab709..6a352fe2 100644 --- a/lineax/_solver/cholesky.py +++ b/lineax/_solver/cholesky.py @@ -25,7 +25,7 @@ is_positive_semidefinite, ) from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver _CholeskyState: TypeAlias = tuple[Array, eqxi.Static] diff --git a/lineax/_solver/diagonal.py b/lineax/_solver/diagonal.py index 334e96e9..c2716392 100644 --- a/lineax/_solver/diagonal.py +++ b/lineax/_solver/diagonal.py @@ -20,7 +20,7 @@ from .._misc import resolve_rcond from .._operator import AbstractLinearOperator, diagonal, has_unit_diagonal, is_diagonal from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/gmres.py b/lineax/_solver/gmres.py index d5911a06..aabe4bf4 100644 --- a/lineax/_solver/gmres.py +++ b/lineax/_solver/gmres.py @@ -28,7 +28,7 @@ from .._norm import max_norm, two_norm from .._operator import AbstractLinearOperator, conj, linearise, MatrixLinearOperator from .._solution import RESULTS -from .._solve import AbstractLinearSolver, linear_solve +from .base import AbstractLinearSolver from .misc import preconditioner_and_y0 from .qr import QR @@ -300,6 +300,8 @@ def buffers(carry): ) coeff_op_transpose = MatrixLinearOperator(coeff_mat.T) # TODO(raderj): move to a Hessenberg-specific solver + from .._solve import linear_solve + z = linear_solve(coeff_op_transpose, beta_vec, QR(), throw=False).value diff = jtu.tree_map( lambda mat: jnp.tensordot( diff --git a/lineax/_solver/lsmr.py b/lineax/_solver/lsmr.py index 8491da0a..9cbf092b 100644 --- a/lineax/_solver/lsmr.py +++ b/lineax/_solver/lsmr.py @@ -46,7 +46,7 @@ from .._norm import two_norm from .._operator import AbstractLinearOperator, conj, linearise, max_rank from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver _LSMRState: TypeAlias = AbstractLinearOperator diff --git a/lineax/_solver/lu.py b/lineax/_solver/lu.py index 72836003..a7511863 100644 --- a/lineax/_solver/lu.py +++ b/lineax/_solver/lu.py @@ -21,7 +21,7 @@ from .._operator import AbstractLinearOperator, is_diagonal from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/normal.py b/lineax/_solver/normal.py index e46aabdd..021065c5 100644 --- a/lineax/_solver/normal.py +++ b/lineax/_solver/normal.py @@ -18,10 +18,16 @@ import equinox.internal as eqxi from jaxtyping import Array, PyTree -from .._operator import conj, linearise, materialise, TaggedLinearOperator +from .._operator import ( + AbstractLinearOperator, + conj, + linearise, + materialise, + TaggedLinearOperator, +) from .._solution import RESULTS -from .._solve import AbstractLinearOperator, AbstractLinearSolver from .._tags import positive_semidefinite_tag +from .base import AbstractLinearSolver from .cholesky import Cholesky diff --git a/lineax/_solver/qr.py b/lineax/_solver/qr.py index 69e4b42c..c7c57a0b 100644 --- a/lineax/_solver/qr.py +++ b/lineax/_solver/qr.py @@ -21,7 +21,7 @@ from jaxtyping import Array, PyTree from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/svd.py b/lineax/_solver/svd.py index 5b872091..3c4bfeb6 100644 --- a/lineax/_solver/svd.py +++ b/lineax/_solver/svd.py @@ -23,7 +23,7 @@ from .._misc import resolve_rcond from .._operator import AbstractLinearOperator, max_rank from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/triangular.py b/lineax/_solver/triangular.py index 304f21ae..e0d1674a 100644 --- a/lineax/_solver/triangular.py +++ b/lineax/_solver/triangular.py @@ -25,7 +25,7 @@ is_upper_triangular, ) from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, diff --git a/lineax/_solver/tridiagonal.py b/lineax/_solver/tridiagonal.py index 7d83eb7d..e1f72e74 100644 --- a/lineax/_solver/tridiagonal.py +++ b/lineax/_solver/tridiagonal.py @@ -20,7 +20,7 @@ from .._operator import AbstractLinearOperator, is_tridiagonal, tridiagonal from .._solution import RESULTS -from .._solve import AbstractLinearSolver +from .base import AbstractLinearSolver from .misc import ( pack_structures, PackedStructures, From 646a306e67ee4f7f3f84dc89cf7139493f88661c Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Thu, 18 Jun 2026 10:42:22 +0100 Subject: [PATCH 2/4] and auto too --- lineax/__init__.py | 2 +- lineax/_solve.py | 156 +-------------------------------- lineax/_solver/__init__.py | 1 + lineax/_solver/auto.py | 172 +++++++++++++++++++++++++++++++++++++ lineax/_solver/base.py | 165 +++++++++++++++++++++++++++++++++++ 5 files changed, 342 insertions(+), 154 deletions(-) create mode 100644 lineax/_solver/auto.py create mode 100644 lineax/_solver/base.py diff --git a/lineax/__init__.py b/lineax/__init__.py index 4f370dee..f7d0d603 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -49,11 +49,11 @@ from ._solution import RESULTS as RESULTS, Solution as Solution from ._solve import ( AbstractLinearSolver as AbstractLinearSolver, - AutoLinearSolver as AutoLinearSolver, invert as invert, linear_solve as linear_solve, ) from ._solver import ( + AutoLinearSolver as AutoLinearSolver, BiCGStab as BiCGStab, CG as CG, Cholesky as Cholesky, diff --git a/lineax/_solve.py b/lineax/_solve.py index f7ed5c71..cd940b00 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -13,7 +13,7 @@ # limitations under the License. import functools as ft -from typing import Any, TypeAlias +from typing import Any import equinox as eqx import equinox.internal as eqxi @@ -25,7 +25,7 @@ import jax.tree_util as jtu from equinox.internal import ω from jax._src.ad_util import stop_gradient_p -from jaxtyping import Array, ArrayLike, PyTree +from jaxtyping import ArrayLike, PyTree from ._custom_types import sentinel from ._misc import inexact_asarray, strip_weak_dtype @@ -34,26 +34,12 @@ conj, FunctionLinearOperator, IdentityLinearOperator, - is_diagonal, - is_lower_triangular, - is_negative_semidefinite, - is_positive_semidefinite, - is_tridiagonal, - is_upper_triangular, linearise, max_rank, TangentLinearOperator, ) from ._solution import RESULTS, Solution -from ._solver import ( - Cholesky, - Diagonal, - LU, - QR, - SVD, - Triangular, - Tridiagonal, -) +from ._solver import AutoLinearSolver from ._solver.base import AbstractLinearSolver as AbstractLinearSolver from ._tags import ( invert_tags, @@ -356,142 +342,6 @@ def _check_rank_compat( ) -_AutoLinearSolverState: TypeAlias = tuple[AbstractLinearSolver, Any] - - -class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): - """Automatically determines a good linear solver based on the structure of the - operator. - - - If `well_posed=True`: - - If the operator is diagonal, then use [`lineax.Diagonal`][]. - - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. - - If the operator is triangular, then use [`lineax.Triangular`][]. - - If the matrix is positive or negative (semi-)definite, then use - [`lineax.Cholesky`][]. - - Else use [`lineax.LU`][]. - - This is a good choice if you want to be certain that an error is raised for - ill-posed systems. - - - If `well_posed=False`: - - If the operator is diagonal, then use [`lineax.Diagonal`][]. - - Else use [`lineax.SVD`][]. - - This is a good choice if you want to be certain that you can handle ill-posed - systems. - - - If `well_posed=None`: - - If the operator is non-square, then use [`lineax.QR`][]. - - If the operator is diagonal, then use [`lineax.Diagonal`][]. - - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. - - If the operator is triangular, then use [`lineax.Triangular`][]. - - If the matrix is positive or negative (semi-)definite, then use - [`lineax.Cholesky`][]. - - Else, use [`lineax.LU`][]. - - This is a good choice if your primary concern is computational efficiency. It will - handle ill-posed systems as long as it is not computationally expensive to do so. - """ - - well_posed: bool | None - - def _select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: - if self.well_posed is True: - if operator.in_size() != operator.out_size(): - raise ValueError( - "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square " - "operator. If you are trying solve a least-squares problem then " - "you should pass `solver=AutoLinearSolver(well_posed=False)`. By " - "default `lineax.linear_solve` assumes that the operator is " - "square and nonsingular." - ) - if is_diagonal(operator): - solver = Diagonal(well_posed=True) - elif is_tridiagonal(operator): - solver = Tridiagonal() - elif is_lower_triangular(operator) or is_upper_triangular(operator): - solver = Triangular() - elif is_positive_semidefinite(operator) or is_negative_semidefinite( - operator - ): - solver = Cholesky() - else: - solver = LU() - elif self.well_posed is False: - if is_diagonal(operator): - solver = Diagonal() - else: - # TODO: use rank-revealing QR instead. - solver = SVD() - elif self.well_posed is None: - if operator.in_size() != operator.out_size(): - solver = QR() - elif is_diagonal(operator): - solver = Diagonal() - elif is_tridiagonal(operator): - solver = Tridiagonal() - elif is_lower_triangular(operator) or is_upper_triangular(operator): - solver = Triangular() - elif is_positive_semidefinite(operator) or is_negative_semidefinite( - operator - ): - solver = Cholesky() - else: - solver = LU() - else: - raise ValueError(f"Invalid value `well_posed={self.well_posed}`.") - return solver - - def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: - """Check which solver that [`lineax.AutoLinearSolver`][] will dispatch to. - - **Arguments:** - - - `operator`: a linear operator. - - **Returns:** - - The linear solver that will be used. - """ - return self._select_solver(operator) - - def init(self, operator, options) -> _AutoLinearSolverState: - solver = self._select_solver(operator) - return solver, solver.init(operator, options) - - def compute( - self, - state: _AutoLinearSolverState, - vector: PyTree[Array], - options: dict[str, Any], - ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: - solver, state = state - solution, result, _ = solver.compute(state, vector, options) - return solution, result, {} - - def transpose(self, state: _AutoLinearSolverState, options: dict[str, Any]): - solver, state = state - transpose_state, transpose_options = solver.transpose(state, options) - transpose_state = (solver, transpose_state) - return transpose_state, transpose_options - - def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]): - solver, state = state - conj_state, conj_options = solver.conj(state, options) - conj_state = (solver, conj_state) - return conj_state, conj_options - - def assume_full_rank(self): - return self.well_posed is not False - - -AutoLinearSolver.__init__.__doc__ = """**Arguments:** - -- `well_posed`: whether to only handle well-posed systems or not, as discussed above. -""" - - # TODO(kidger): gmres, bicgstab # TODO(kidger): support auxiliary outputs @eqx.filter_jit diff --git a/lineax/_solver/__init__.py b/lineax/_solver/__init__.py index 2cee02c5..cbde8261 100644 --- a/lineax/_solver/__init__.py +++ b/lineax/_solver/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .auto import AutoLinearSolver as AutoLinearSolver from .bicgstab import BiCGStab as BiCGStab from .cg import CG as CG, NormalCG as NormalCG from .cholesky import Cholesky as Cholesky diff --git a/lineax/_solver/auto.py b/lineax/_solver/auto.py new file mode 100644 index 00000000..9837db22 --- /dev/null +++ b/lineax/_solver/auto.py @@ -0,0 +1,172 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, TypeAlias + +from jaxtyping import Array, PyTree + +from .._operator import ( + AbstractLinearOperator, + is_diagonal, + is_lower_triangular, + is_negative_semidefinite, + is_positive_semidefinite, + is_tridiagonal, + is_upper_triangular, +) +from .._solution import RESULTS +from .base import AbstractLinearSolver +from .cholesky import Cholesky +from .diagonal import Diagonal +from .lu import LU +from .qr import QR +from .svd import SVD +from .triangular import Triangular +from .tridiagonal import Tridiagonal + + +_AutoLinearSolverState: TypeAlias = tuple[AbstractLinearSolver, Any] + + +class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): + """Automatically determines a good linear solver based on the structure of the + operator. + + - If `well_posed=True`: + - If the operator is diagonal, then use [`lineax.Diagonal`][]. + - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. + - If the operator is triangular, then use [`lineax.Triangular`][]. + - If the matrix is positive or negative (semi-)definite, then use + [`lineax.Cholesky`][]. + - Else use [`lineax.LU`][]. + + This is a good choice if you want to be certain that an error is raised for + ill-posed systems. + + - If `well_posed=False`: + - If the operator is diagonal, then use [`lineax.Diagonal`][]. + - Else use [`lineax.SVD`][]. + + This is a good choice if you want to be certain that you can handle ill-posed + systems. + + - If `well_posed=None`: + - If the operator is non-square, then use [`lineax.QR`][]. + - If the operator is diagonal, then use [`lineax.Diagonal`][]. + - If the operator is tridiagonal, then use [`lineax.Tridiagonal`][]. + - If the operator is triangular, then use [`lineax.Triangular`][]. + - If the matrix is positive or negative (semi-)definite, then use + [`lineax.Cholesky`][]. + - Else, use [`lineax.LU`][]. + + This is a good choice if your primary concern is computational efficiency. It will + handle ill-posed systems as long as it is not computationally expensive to do so. + """ + + well_posed: bool | None + + def _select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: + if self.well_posed is True: + if operator.in_size() != operator.out_size(): + raise ValueError( + "Cannot use `AutoLinearSolver(well_posed=True)` with a non-square " + "operator. If you are trying solve a least-squares problem then " + "you should pass `solver=AutoLinearSolver(well_posed=False)`. By " + "default `lineax.linear_solve` assumes that the operator is " + "square and nonsingular." + ) + if is_diagonal(operator): + solver = Diagonal(well_posed=True) + elif is_tridiagonal(operator): + solver = Tridiagonal() + elif is_lower_triangular(operator) or is_upper_triangular(operator): + solver = Triangular() + elif is_positive_semidefinite(operator) or is_negative_semidefinite( + operator + ): + solver = Cholesky() + else: + solver = LU() + elif self.well_posed is False: + if is_diagonal(operator): + solver = Diagonal() + else: + # TODO: use rank-revealing QR instead. + solver = SVD() + elif self.well_posed is None: + if operator.in_size() != operator.out_size(): + solver = QR() + elif is_diagonal(operator): + solver = Diagonal() + elif is_tridiagonal(operator): + solver = Tridiagonal() + elif is_lower_triangular(operator) or is_upper_triangular(operator): + solver = Triangular() + elif is_positive_semidefinite(operator) or is_negative_semidefinite( + operator + ): + solver = Cholesky() + else: + solver = LU() + else: + raise ValueError(f"Invalid value `well_posed={self.well_posed}`.") + return solver + + def select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolver: + """Check which solver that [`lineax.AutoLinearSolver`][] will dispatch to. + + **Arguments:** + + - `operator`: a linear operator. + + **Returns:** + + The linear solver that will be used. + """ + return self._select_solver(operator) + + def init(self, operator, options) -> _AutoLinearSolverState: + solver = self._select_solver(operator) + return solver, solver.init(operator, options) + + def compute( + self, + state: _AutoLinearSolverState, + vector: PyTree[Array], + options: dict[str, Any], + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + solver, state = state + solution, result, _ = solver.compute(state, vector, options) + return solution, result, {} + + def transpose(self, state: _AutoLinearSolverState, options: dict[str, Any]): + solver, state = state + transpose_state, transpose_options = solver.transpose(state, options) + transpose_state = (solver, transpose_state) + return transpose_state, transpose_options + + def conj(self, state: _AutoLinearSolverState, options: dict[str, Any]): + solver, state = state + conj_state, conj_options = solver.conj(state, options) + conj_state = (solver, conj_state) + return conj_state, conj_options + + def assume_full_rank(self): + return self.well_posed is not False + + +AutoLinearSolver.__init__.__doc__ = """**Arguments:** + +- `well_posed`: whether to only handle well-posed systems or not, as discussed above. +""" diff --git a/lineax/_solver/base.py b/lineax/_solver/base.py new file mode 100644 index 00000000..9984221d --- /dev/null +++ b/lineax/_solver/base.py @@ -0,0 +1,165 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Generic, TypeVar + +import equinox as eqx +from jaxtyping import Array, PyTree + +from .._operator import AbstractLinearOperator +from .._solution import RESULTS + + +_SolverState = TypeVar("_SolverState") + + +class AbstractLinearSolver(eqx.Module, Generic[_SolverState]): + """Abstract base class for all linear solvers.""" + + @abc.abstractmethod + def init( + self, operator: AbstractLinearOperator, options: dict[str, Any] + ) -> _SolverState: + """Do any initial computation on just the `operator`. + + For example, an LU solver would compute the LU decomposition of the operator + (and this does not require knowing the vector yet). + + It is common to need to solve the linear system `Ax=b` multiple times in + succession, with the same operator `A` and multiple vectors `b`. This method + improves efficiency by making it possible to re-use the computation performed + on just the operator. + + !!! Example + + ```python + operator = lx.MatrixLinearOperator(...) + vector1 = ... + vector2 = ... + solver = lx.LU() + state = solver.init(operator, options={}) + solution1 = lx.linear_solve(operator, vector1, solver, state=state) + solution2 = lx.linear_solve(operator, vector2, solver, state=state) + ``` + + **Arguments:** + + - `operator`: a linear operator. + - `options`: a dictionary of any extra options that the solver may wish to + accept. + + **Returns:** + + A PyTree of arbitrary Python objects. + """ + + @abc.abstractmethod + def compute( + self, state: _SolverState, vector: PyTree[Array], options: dict[str, Any] + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + """Solves a linear system. + + **Arguments:** + + - `state`: as returned from [`lineax.AbstractLinearSolver.init`][]. + - `vector`: the vector to solve against. + - `options`: a dictionary of any extra options that the solver may wish to + accept. For example, [`lineax.CG`][] accepts a `preconditioner` option. + + **Returns:** + + A 3-tuple of: + + - The solution to the linear system. + - An integer indicating the success or failure of the solve. This is an integer + which may be converted to a human-readable error message via + `lx.RESULTS[...]`. + - A dictionary of an extra statistics about the solve, e.g. the number of steps + taken. + """ + + @abc.abstractmethod + def transpose( + self, state: _SolverState, options: dict[str, Any] + ) -> tuple[_SolverState, dict[str, Any]]: + """Transposes the result of [`lineax.AbstractLinearSolver.init`][]. + + That is, it should be the case that + ```python + state_transpose, _ = solver.transpose(solver.init(operator, options), options) + state_transpose2 = solver.init(operator.T, options) + ``` + must be identical to each other. + + It is relatively common (in particular when differentiating through a linear + solve) to need to solve both `Ax = b` and `A^T x = b`. This method makes it + possible to avoid computing both `solver.init(operator)` and + `solver.init(operator.T)` if one can be cheaply computed from the other. + + **Arguments:** + + - `state`: as returned from `solver.init`. + - `options`: any extra options that were passed to `solve.init`. + + **Returns:** + + A 2-tuple of: + + - The state of the transposed operator. + - The options for the transposed operator. + """ + + @abc.abstractmethod + def conj( + self, state: _SolverState, options: dict[str, Any] + ) -> tuple[_SolverState, dict[str, Any]]: + """Conjugate the result of [`lineax.AbstractLinearSolver.init`][]. + + That is, it should be the case that + ```python + state_conj, _ = solver.conj(solver.init(operator, options), options) + state_conj2 = solver.init(conj(operator), options) + ``` + must be identical to each other. + + **Arguments:** + + - `state`: as returned from `solver.init`. + - `options`: any extra options that were passed to `solve.init`. + + **Returns:** + + A 2-tuple of: + + - The state of the conjugated operator. + - The options for the conjugated operator. + """ + + @abc.abstractmethod + def assume_full_rank(self) -> bool: + """Does this solver assume that all operators are full rank? + + When `False`, a more expensive backward pass is needed to account for + the extra generality. In a custom linear solver, it is always safe to + return False. + + **Arguments:** + + Nothing. + + **Returns:** + + Either `True` or `False`. + """ From f1b50027eb6ec0ad8ec64c70520c7929f0ecdd44 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Sun, 21 Jun 2026 03:00:51 +0100 Subject: [PATCH 3/4] Add `HEVD` solver, `hermitian_tag`, `is_hermitian` check and `.H` operator property` --- docs/api/functions.md | 4 + docs/api/operators.md | 1 + docs/api/solvers.md | 7 ++ docs/api/tags.md | 8 +- lineax/__init__.py | 3 + lineax/_operator/__init__.py | 1 + lineax/_operator/base.py | 54 +++++++++++ lineax/_operator/binary.py | 14 +++ lineax/_operator/core.py | 38 ++++---- lineax/_operator/structured.py | 9 ++ lineax/_operator/wrapper.py | 46 +++++++++ lineax/_solve.py | 167 +++++++++++++++++++++++++++++---- lineax/_solver/__init__.py | 1 + lineax/_solver/auto.py | 7 ++ lineax/_solver/hevd.py | 155 ++++++++++++++++++++++++++++++ lineax/_tags.py | 5 + tests/helpers.py | 20 +++- tests/test_max_rank.py | 84 +++++++++++++++++ tests/test_operator.py | 49 ++++++++++ tests/test_transpose.py | 35 +++++++ tests/test_well_posed.py | 2 +- 21 files changed, 672 insertions(+), 38 deletions(-) create mode 100644 lineax/_solver/hevd.py diff --git a/docs/api/functions.md b/docs/api/functions.md index 82c8f489..fbbfe046 100644 --- a/docs/api/functions.md +++ b/docs/api/functions.md @@ -56,4 +56,8 @@ Note that these do *not* inspect the values of the operator -- instead, they use --- +::: lineax.is_hermitian + +--- + ::: lineax.max_rank diff --git a/docs/api/operators.md b/docs/api/operators.md index 3180df18..a9245212 100644 --- a/docs/api/operators.md +++ b/docs/api/operators.md @@ -16,6 +16,7 @@ Or, perhaps we only have a function $F : \mathbb{R}^m \to \mathbb{R}^n$ such tha - mv - as_matrix - transpose + - H - in_structure - out_structure - in_size diff --git a/docs/api/solvers.md b/docs/api/solvers.md index b3159d94..d57ec6aa 100644 --- a/docs/api/solvers.md +++ b/docs/api/solvers.md @@ -44,6 +44,13 @@ These are capable of solving ill-posed linear problems. --- +::: lineax.HEVD + options: + members: + - __init__ + +--- + ::: lineax.Normal options: members: diff --git a/docs/api/tags.md b/docs/api/tags.md index 18f582e6..c26a8d4c 100644 --- a/docs/api/tags.md +++ b/docs/api/tags.md @@ -68,7 +68,13 @@ lx.is_positive_semidefinite(lx.IdentityLinearOperator(...)) # True ::: lineax.symmetric_tag -Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.) +Marks that an operator is symmetric. (As a matrix, $A = A^\intercal$.) For real operators this coincides with `lineax.hermitian_tag`. + +--- + +::: lineax.hermitian_tag + +Marks that an operator is Hermitian (self-adjoint). (As a matrix, $A = A^*$, the conjugate transpose.) For real operators this coincides with `lineax.symmetric_tag`. --- diff --git a/lineax/__init__.py b/lineax/__init__.py index f7d0d603..10834001 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -27,6 +27,7 @@ has_unit_diagonal as has_unit_diagonal, IdentityLinearOperator as IdentityLinearOperator, is_diagonal as is_diagonal, + is_hermitian as is_hermitian, is_lower_triangular as is_lower_triangular, is_negative_semidefinite as is_negative_semidefinite, is_positive_semidefinite as is_positive_semidefinite, @@ -59,6 +60,7 @@ Cholesky as Cholesky, Diagonal as Diagonal, GMRES as GMRES, + HEVD as HEVD, LSMR as LSMR, LU as LU, Normal as Normal, @@ -70,6 +72,7 @@ ) from ._tags import ( diagonal_tag as diagonal_tag, + hermitian_tag as hermitian_tag, invert_tags as invert_tags, invert_tags_rules as invert_tags_rules, lower_triangular_tag as lower_triangular_tag, diff --git a/lineax/_operator/__init__.py b/lineax/_operator/__init__.py index 5a613809..b4493e61 100644 --- a/lineax/_operator/__init__.py +++ b/lineax/_operator/__init__.py @@ -18,6 +18,7 @@ diagonal as diagonal, has_unit_diagonal as has_unit_diagonal, is_diagonal as is_diagonal, + is_hermitian as is_hermitian, is_lower_triangular as is_lower_triangular, is_negative_semidefinite as is_negative_semidefinite, is_positive_semidefinite as is_positive_semidefinite, diff --git a/lineax/_operator/base.py b/lineax/_operator/base.py index f2d8593d..470aebf1 100644 --- a/lineax/_operator/base.py +++ b/lineax/_operator/base.py @@ -87,6 +87,7 @@ class AbstractLinearOperator(eqx.Module): def __check_init__(self): if ( is_symmetric(self) + or is_hermitian(self) or is_positive_semidefinite(self) or is_negative_semidefinite(self) ): @@ -197,6 +198,19 @@ def T(self) -> "AbstractLinearOperator": """Equivalent to [`lineax.AbstractLinearOperator.transpose`][]""" return self.transpose() + @property + def H(self) -> "AbstractLinearOperator": + """The conjugate transpose (Hermitian adjoint) `Aᴴ` of this operator. + + Equivalent to `lineax.conj(operator).transpose()`, except that for a Hermitian + operator -- for which `Aᴴ = A` -- this is a no-op and returns `self` unchanged. + (As with [`lineax.is_hermitian`][], only the tag is checked, not the actual + values of the operator.) + """ + if is_hermitian(self): + return self + return conj(self).transpose() + def __add__(self, other) -> "AbstractLinearOperator": # Local imports to avoid a circular dependency: `binary`/`wrapper` import # `base`, so the operators built here are imported lazily at call time. @@ -399,6 +413,20 @@ def tridiagonal( _default_not_implemented("tridiagonal", operator) +def has_real_dtype(operator) -> bool: + """Check if all dtypes in an operator's structure are real (not complex).""" + leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure())) + dtype = jnp.result_type(*leaves) + if jnp.issubdtype(dtype, jnp.complexfloating): + return False + elif jnp.issubdtype(dtype, jnp.floating): + return True + else: + assert False, ( + "Only `jnp.floating` and `jnp.complexfloating` dtypes are understood." + ) + + @ft.singledispatch def is_symmetric(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as symmetric. @@ -417,6 +445,32 @@ def is_symmetric(operator: AbstractLinearOperator) -> bool: _default_not_implemented("is_symmetric", operator) +@ft.singledispatch +def is_hermitian(operator: AbstractLinearOperator) -> bool: + """Returns whether an operator is marked as Hermitian (self-adjoint). + + See [the documentation on linear operator tags](../api/tags.md) for more + information. + + **Arguments:** + + - `operator`: a linear operator. + + **Returns:** + + Either `True` or `False.` + """ + # Default for operators that don't register `is_hermitian` explicitly (e.g. custom + # `AbstractLinearOperator`s written before it existed): derive it from the other + # property checks, the same way the built-in operators do minus the + # `hermitian_tag` check, which needs operator-specific `.tags`. + if is_positive_semidefinite(operator) or is_negative_semidefinite(operator): + return True + if has_real_dtype(operator) and is_symmetric(operator): + return True + return False + + @ft.singledispatch def is_diagonal(operator: AbstractLinearOperator) -> bool: """Returns whether an operator is marked as diagonal. diff --git a/lineax/_operator/binary.py b/lineax/_operator/binary.py index 52ad764a..898c0742 100644 --- a/lineax/_operator/binary.py +++ b/lineax/_operator/binary.py @@ -23,8 +23,10 @@ AbstractLinearOperator, conj, diagonal, + has_real_dtype, has_unit_diagonal, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -203,6 +205,7 @@ def _(operator): for check in ( is_symmetric, + is_hermitian, is_diagonal, is_lower_triangular, is_upper_triangular, @@ -247,6 +250,17 @@ def _(operator): return is_diagonal(operator.operator1) and is_diagonal(operator.operator2) +# is_hermitian: as above, diagonal matrices commute. A product of diagonals is itself +# diagonal, which is Hermitian only when its (complex) entries are real-valued. +@is_hermitian.register(ComposedLinearOperator) +def _(operator): + return ( + is_diagonal(operator.operator1) + and is_diagonal(operator.operator2) + and has_real_dtype(operator) + ) + + # is_tridiagonal: tridiagonal @ tridiagonal = pentadiagonal, but # tridiagonal @ diagonal = tridiagonal and diagonal @ tridiagonal = tridiagonal @is_tridiagonal.register(ComposedLinearOperator) diff --git a/lineax/_operator/core.py b/lineax/_operator/core.py index 20673a28..df9e931c 100644 --- a/lineax/_operator/core.py +++ b/lineax/_operator/core.py @@ -38,6 +38,7 @@ ) from .._tags import ( diagonal_tag, + hermitian_tag, lower_triangular_tag, negative_semidefinite_tag, positive_semidefinite_tag, @@ -53,9 +54,11 @@ conj, diagonal, FlatPyTree, + has_real_dtype, has_unit_diagonal, inexact_structure, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -724,20 +727,6 @@ def _(operator): # checks -def _has_real_dtype(operator) -> bool: - """Check if all dtypes in an operator's structure are real (not complex).""" - leaves = jtu.tree_leaves((operator.in_structure(), operator.out_structure())) - dtype = jnp.result_type(*leaves) - if jnp.issubdtype(dtype, jnp.complexfloating): - return False - elif jnp.issubdtype(dtype, jnp.floating): - return True - else: - assert False, ( - "Only `jnp.floating` and `jnp.complexfloating` dtypes are understood." - ) - - @is_symmetric.register(MatrixLinearOperator) @is_symmetric.register(PyTreeLinearOperator) @is_symmetric.register(JacobianLinearOperator) @@ -746,12 +735,29 @@ def _(operator): # Symmetric (A = A^T) if explicitly tagged symmetric or diagonal if symmetric_tag in operator.tags or diagonal_tag in operator.tags: return True - # PSD/NSD implies symmetric only for real dtypes; for complex, it's Hermitian + # PSD/NSD/Hermitian imply A = A^T only for real dtypes if ( positive_semidefinite_tag in operator.tags or negative_semidefinite_tag in operator.tags + or hermitian_tag in operator.tags + ): + return has_real_dtype(operator) + return False + + +@is_hermitian.register(MatrixLinearOperator) +@is_hermitian.register(PyTreeLinearOperator) +@is_hermitian.register(JacobianLinearOperator) +@is_hermitian.register(FunctionLinearOperator) +def _(operator): + if ( + hermitian_tag in operator.tags + or positive_semidefinite_tag in operator.tags + or negative_semidefinite_tag in operator.tags ): - return _has_real_dtype(operator) + return True + if symmetric_tag in operator.tags or diagonal_tag in operator.tags: + return has_real_dtype(operator) return False diff --git a/lineax/_operator/structured.py b/lineax/_operator/structured.py index f0195c63..9a0f0c20 100644 --- a/lineax/_operator/structured.py +++ b/lineax/_operator/structured.py @@ -40,9 +40,11 @@ conj, diagonal, FlatPyTree, + has_real_dtype, has_unit_diagonal, inexact_structure, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -292,6 +294,7 @@ def _(operator): @is_symmetric.register(IdentityLinearOperator) +@is_hermitian.register(IdentityLinearOperator) def _(operator): return eqx.tree_equal(operator.in_structure(), operator.out_structure()) is True @@ -301,7 +304,13 @@ def _(operator): return True +@is_hermitian.register(DiagonalLinearOperator) +def _(operator): + return has_real_dtype(operator) + + @is_symmetric.register(TridiagonalLinearOperator) +@is_hermitian.register(TridiagonalLinearOperator) def _(operator): return False diff --git a/lineax/_operator/wrapper.py b/lineax/_operator/wrapper.py index e9d7576e..fd28052e 100644 --- a/lineax/_operator/wrapper.py +++ b/lineax/_operator/wrapper.py @@ -28,6 +28,7 @@ from .._tags import ( diagonal_tag, + hermitian_tag, lower_triangular_tag, MaxRankTag, negative_semidefinite_tag, @@ -43,8 +44,10 @@ as_frozenset, conj, diagonal, + has_real_dtype, has_unit_diagonal, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -340,6 +343,7 @@ def _(operator): for check in ( is_symmetric, + is_hermitian, is_diagonal, has_unit_diagonal, is_lower_triangular, @@ -371,6 +375,28 @@ def _(operator, check=check): return check(operator.operator) +def _scalar_is_real(scalar) -> bool: + """Whether a scalar is statically known to be real-valued. + + A real dtype guarantees a real value, so this is known even for JAX tracers (whose + runtime value is unknown at trace time): only the dtype matters, not the value. + Returns `False` only for genuinely complex-typed scalars. + """ + return not jnp.issubdtype(jnp.result_type(scalar), jnp.complexfloating) + + +# Hermitian-ness preserved by negation and scaling by any real scalar +@is_hermitian.register(NegLinearOperator) +def _(operator): + return is_hermitian(operator.operator) + + +@is_hermitian.register(MulLinearOperator) +@is_hermitian.register(DivLinearOperator) +def _(operator): + return _scalar_is_real(operator.scalar) and is_hermitian(operator.operator) + + # has_unit_diagonal is NOT preserved by negation @has_unit_diagonal.register(NegLinearOperator) def _(operator): @@ -498,6 +524,26 @@ def _(operator, check=check, tag=tag): return (tag in operator.tags) or check(operator.operator) +# `is_hermitian` is special-cased rather than handled by the loop above: a tag other +# than `hermitian_tag` can still imply Hermitian-ness. PSD/NSD operators are Hermitian +# (real or complex), and real symmetric/diagonal operators are Hermitian too. This +# mirrors the cross-implications encoded for the core operators. +@is_hermitian.register(TaggedLinearOperator) +def _(operator): + tags = operator.tags + if is_hermitian(operator.operator): + return True + if ( + hermitian_tag in tags + or positive_semidefinite_tag in tags + or negative_semidefinite_tag in tags + ): + return True + if symmetric_tag in tags or diagonal_tag in tags: + return has_real_dtype(operator) + return False + + @max_rank.register(TaggedLinearOperator) def _(operator): inner = max_rank(operator.operator) diff --git a/lineax/_solve.py b/lineax/_solve.py index cd940b00..a33e18e1 100644 --- a/lineax/_solve.py +++ b/lineax/_solve.py @@ -13,7 +13,7 @@ # limitations under the License. import functools as ft -from typing import Any +from typing import Any, TypeAlias import equinox as eqx import equinox.internal as eqxi @@ -31,18 +31,28 @@ from ._misc import inexact_asarray, strip_weak_dtype from ._operator import ( AbstractLinearOperator, - conj, FunctionLinearOperator, IdentityLinearOperator, + is_hermitian, linearise, max_rank, + TaggedLinearOperator, TangentLinearOperator, ) from ._solution import RESULTS, Solution -from ._solver import AutoLinearSolver +from ._solver import ( + AutoLinearSolver as AutoLinearSolver, + Cholesky, + HEVD, + Normal, + QR, + SVD, +) from ._solver.base import AbstractLinearSolver as AbstractLinearSolver +from ._solver.misc import pack_structures from ._tags import ( invert_tags, + positive_semidefinite_tag, tags_from_checks, ) @@ -137,6 +147,95 @@ def _linear_solve_abstract_eval(operator, state, vector, options, solver, throw) return out +def _squared_rcond(rcond: float | None, n: int, m: int, dtype) -> float: + """`resolve_rcond(rcond, n, m, dtype) ** 2`, as a Python `float`. + + A pure-Python mirror of [`resolve_rcond`][] (no `jnp.where`), so the result stays a + Python scalar -- HEVD's `rcond` is `float | None`, and under the recursive gram + solve `resolve_rcond` of a non-`None` rcond would otherwise produce a traced array. + + Squaring reproduces the original solver's rank cutoff on the gram's eigenvalues + `σ²`, since `σ² > rcond²·σ²ₘₐₓ <=> σ > rcond·σₘₐₓ`. + """ + eps = float(jnp.finfo(dtype).eps) + if rcond is None: + rcond = 2 * eps * max(n, m) + elif rcond < 0: + rcond = eps + return float(rcond) ** 2 + + +# Solver types whose factorisation *may* cheaply yield the gram (pseudo)inverse +# `(AᴴA)⁺`. Whether one actually does can depend on the state (e.g. `Normal` only when +# tall), so `_has_gram_partner` is the definitive runtime check; see `_gram_partner`. +_MaybeHasGramPartner: TypeAlias = QR | SVD | HEVD | Normal + + +def _has_gram_partner(solver: AbstractLinearSolver, state: Any) -> bool: + """Can the (pseudo)inverse for A^H A actually be inferred from solver's state?""" + if isinstance(solver, Normal): + _, tall, _, _ = state + return tall.value # inner operator is `AᴴA`; when wide it is `AAᴴ` + return isinstance(solver, _MaybeHasGramPartner) + + +def _gram_partner( + solver: _MaybeHasGramPartner, + gram_operator: AbstractLinearOperator, + state: Any, +) -> tuple[AbstractLinearSolver, Any]: + """Return a `(gram_solver, gram_state)` pair such that + `linear_solve_p(gram_operator, gram_state, v, gram_solver)` computes `(AᴴA)⁺ v` + (where `gram_operator` is `AᴴA`). Requires `_has_gram_partner(solver, state)`. + + Each candidate solver's gram partner -- a solver representing the (pseudo)inverse of + the gram matrix `AᴴA` -- is obtained from the existing factorisation with no further + decomposition: + + QR `A = QR` -> `Cholesky`, since `AᴴA = RᴴR` (`R` is the factor) + SVD `A = UΣVᴴ` -> `HEVD` with eigenvectors `V`, eigenvalues `σ²` + HEVD `A = VWVᴴ` -> `HEVD` with eigenvectors `V`, eigenvalues `w²` + Normal (tall) -> its inner solver, which already factorises `AᴴA` + + The JVP uses this to collapse the two nested solves against `Aᴴ` (the inner adjoint + solve and the outer `A⁺`) into one gram solve. Routing it back through + `linear_solve_p` -- rather than applying the factors directly -- keeps it correct + under higher-order autodiff, since the gram solve then uses lineax's + pseudoinverse-aware adjoint rather than differentiating through the factorisation. + """ + if isinstance(solver, Normal): + inner_state, tall, _, _ = state + if not tall.value: + # Wide: the inner solver factorises `AAᴴ`, not `AᴴA`. `_has_gram_partner` + # excludes this, so reaching here is a caller bug. + raise ValueError("`Normal` has a gram partner only for tall operators") + # Tall: the inner solver already factorises `AᴴA`, so its state *is* the gram + # state. This holds for any inner solver (Cholesky, CG, HEVD, ...). + return solver.inner_solver, inner_state + if isinstance(solver, QR): + (a, _), transpose, _ = state + if transpose.value: + # QR is full rank, so the JVP reaches the gram path only when + # `rows > columns` (tall), where the stored factorisation is of `A`. + raise ValueError("`QR` has a gram partner only for tall operators") + # Tall `A = QR` => `AᴴA = RᴴR`: the QR factor `R` is the upper Cholesky factor. + r = a[: a.shape[1]] + return Cholesky(), (r, eqxi.Static(False)) + packed = pack_structures(gram_operator) + if isinstance(solver, SVD): + (u, s, vt), _ = state + # `(AᴴA)⁺ = V Σ⁻² Vᴴ`. + eigenvalues, eigenvectors = s**2, vt.conj().T + rcond = _squared_rcond(solver.rcond, vt.shape[1], u.shape[0], s.dtype) + else: + (w, eigenvectors), _ = state + # `(AᴴA)⁺ = (A²)⁺ = V W⁻² Vᴴ`. + eigenvalues = w**2 + m = eigenvectors.shape[0] + rcond = _squared_rcond(solver.rcond, m, m, w.dtype) + return HEVD(rcond=rcond), ((eigenvalues, eigenvectors), packed) + + @eqxi.filter_primitive_jvp def _linear_solve_jvp(primals, tangents): operator, state, vector, options, solver, throw = primals @@ -203,25 +302,55 @@ def _linear_solve_jvp(primals, tangents): assume_independent_rows = solver.assume_full_rank() and rows <= columns assume_independent_columns = solver.assume_full_rank() and columns <= rows if not assume_independent_rows or not assume_independent_columns: - operator_conj_transpose = conj(operator).transpose() - t_operator_conj_transpose = conj(t_operator).transpose() - state_conj, options_conj = solver.conj(state, options) - state_conj_transpose, options_conj_transpose = solver.transpose( - state_conj, options_conj - ) + operator_conj_transpose = operator.H + t_operator_conj_transpose = t_operator.H + if is_hermitian(operator): + # `Aᴴ = A`, so `init(Aᴴ) == init(A)`: the existing state already serves + # as the adjoint state. This holds for any solver, so the fast path is + # keyed on the operator rather than on the solver. + state_conj_transpose, options_conj_transpose = state, options + else: + state_conj, options_conj = solver.conj(state, options) + state_conj_transpose, options_conj_transpose = solver.transpose( + state_conj, options_conj + ) if not assume_independent_rows: lst_sqr_diff = (vector**ω - operator.mv(solution) ** ω).ω tmp = t_operator_conj_transpose.mv(lst_sqr_diff) # pyright: ignore - tmp, _, _ = eqxi.filter_primitive_bind( - linear_solve_p, - operator_conj_transpose, # pyright: ignore - state_conj_transpose, # pyright: ignore - tmp, - options_conj_transpose, # pyright: ignore - solver, - True, - ) - vecs.append(tmp) + # This term is `A⁺ (Aᴴ)⁺ w = (AᴴA)⁺ w`. If the solver has a gram partner, + # compute `(AᴴA)⁺ w` in a single gram solve against `AᴴA`; otherwise fall + # back to the generic nested adjoint solve (whose result is later + # left-multiplied by `A⁺` along with the other `vecs`). The gram operator + # is never materialised -- the gram solve reads only `gram_state` -- but it + # carries the right structure and (for higher-order autodiff) tangent. + if _has_gram_partner(solver, state): + gram_operator = TaggedLinearOperator( + operator.H @ operator, positive_semidefinite_tag + ) + gram_solver, gram_state = _gram_partner(solver, gram_operator, state) + gram_inv, _, _ = eqxi.filter_primitive_bind( + linear_solve_p, + gram_operator, + gram_state, + tmp, + {}, + gram_solver, + True, + ) + # `(AᴴA)⁺ w` already lives in the input space, so it bypasses the + # outer `A⁺`: append directly to the already-solved `sols`. + sols.append(gram_inv) + else: + tmp, _, _ = eqxi.filter_primitive_bind( + linear_solve_p, + operator_conj_transpose, # pyright: ignore + state_conj_transpose, # pyright: ignore + tmp, + options_conj_transpose, # pyright: ignore + solver, + True, + ) + vecs.append(tmp) if not assume_independent_columns: tmp1, _, _ = eqxi.filter_primitive_bind( diff --git a/lineax/_solver/__init__.py b/lineax/_solver/__init__.py index cbde8261..04683269 100644 --- a/lineax/_solver/__init__.py +++ b/lineax/_solver/__init__.py @@ -18,6 +18,7 @@ from .cholesky import Cholesky as Cholesky from .diagonal import Diagonal as Diagonal from .gmres import GMRES as GMRES +from .hevd import HEVD as HEVD from .lsmr import LSMR as LSMR from .lu import LU as LU from .normal import Normal as Normal diff --git a/lineax/_solver/auto.py b/lineax/_solver/auto.py index 9837db22..4e6f4fce 100644 --- a/lineax/_solver/auto.py +++ b/lineax/_solver/auto.py @@ -19,6 +19,7 @@ from .._operator import ( AbstractLinearOperator, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -29,6 +30,7 @@ from .base import AbstractLinearSolver from .cholesky import Cholesky from .diagonal import Diagonal +from .hevd import HEVD from .lu import LU from .qr import QR from .svd import SVD @@ -56,6 +58,7 @@ class AutoLinearSolver(AbstractLinearSolver[_AutoLinearSolverState]): - If `well_posed=False`: - If the operator is diagonal, then use [`lineax.Diagonal`][]. + - If the operator is Hermitian, then use [`lineax.HEVD`][]. - Else use [`lineax.SVD`][]. This is a good choice if you want to be certain that you can handle ill-posed @@ -101,6 +104,10 @@ def _select_solver(self, operator: AbstractLinearOperator) -> AbstractLinearSolv elif self.well_posed is False: if is_diagonal(operator): solver = Diagonal() + elif is_hermitian(operator): + # A Hermitian eigendecomposition is cheaper than a general SVD, and + # handles ill-posed Hermitian systems via the same pseudoinverse. + solver = HEVD() else: # TODO: use rank-revealing QR instead. solver = SVD() diff --git a/lineax/_solver/hevd.py b/lineax/_solver/hevd.py new file mode 100644 index 00000000..7a5492ba --- /dev/null +++ b/lineax/_solver/hevd.py @@ -0,0 +1,155 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, TypeAlias + +import equinox as eqx +import jax.lax as lax +import jax.numpy as jnp +from jaxtyping import Array, PyTree + +from .._misc import resolve_rcond +from .._operator import ( + AbstractLinearOperator, + is_hermitian, + is_negative_semidefinite, + is_positive_semidefinite, + max_rank, +) +from .._solution import RESULTS +from .base import AbstractLinearSolver +from .misc import ( + pack_structures, + PackedStructures, + ravel_vector, + unravel_solution, +) + + +_HEVDState: TypeAlias = tuple[tuple[Array, Array], PackedStructures] + + +class HEVD(AbstractLinearSolver[_HEVDState]): + """Eigenvalue decomposition solver for Hermitian linear systems. + + The operator must be square and Hermitian (self-adjoint), but need not be + definite or even nonsingular: in the singular case this solver returns + the pseudoinverse solution. This is the optimised Hermitian analogue of + [`lineax.SVD`][]. + """ + + rcond: float | None = None + + def init(self, operator: AbstractLinearOperator, options: dict[str, Any]): + del options + if not is_hermitian(operator): + raise ValueError( + "`HEVD` may only be used for linear solves with Hermitian matrices." + ) + # `jnp.linalg.eigh` returns eigenvalues in ascending (signed) order + w, v = jnp.linalg.eigh(operator.as_matrix()) + r = max_rank(operator) + if r < w.shape[0]: + # The operator is declared to have rank at most `r`, so all but the `r` + # largest-magnitude eigenvalues are mathematically zero. Statically drop + # them to shrink the matmuls (and storage) in `compute`. + m = v.shape[0] + rcond = resolve_rcond(self.rcond, m, m, w.dtype) * jnp.max(jnp.abs(w)) + if is_positive_semidefinite(operator): + # Eigenvalues are >= 0, so in ascending order the `r` largest are a + # contiguous trailing slice (cheaper than a reordering gather). + w, v, dropped = w[m - r :], v[:, m - r :], w[: m - r] + elif is_negative_semidefinite(operator): + # Eigenvalues are <= 0, so the `r` largest in magnitude are a + # contiguous leading slice. + w, v, dropped = w[:r], v[:, :r], w[r:] + else: + # Indefinite: the small-magnitude eigenvalues sit in the interior of + # the spectrum, so no contiguous slice works. Reorder by descending + # magnitude (an O(n^2) gather, dominated by the O(n^3) eigensolve) + # and take the leading `r`. + order = jnp.argsort(jnp.abs(w))[::-1] + w, v = w[order], v[:, order] + w, v, dropped = w[:r], v[:, :r], w[r:] + # `compute` masks out `|w_i| <= rcond * max|w|`, so dropping these is + # lossless iff they all sit below that floor. Otherwise the `max_rank` + # claim is false (truncation would change the solution), so error out. + # Checking the largest discarded magnitude also catches a mistagged + # PSD/NSD operator whose true large eigenvalues sit on the dropped side. + w = eqx.error_if( + w, + jnp.max(jnp.abs(dropped)) > rcond, + "lineax.HEVD: the operator was declared (via a `MaxRankTag`, or by " + f"composition rules) to have rank at most {r}, but it has an " + "eigenvalue above the rcond threshold beyond that rank. Truncating to " + "the declared rank would change the solution, so the rank claim " + "appears to be incorrect. Remove/loosen the rank tag, increase " + "`rcond` if you intend a low-rank approximation, or set " + "`EQX_ON_ERROR=off` to skip this check.", + ) + packed_structures = pack_structures(operator) + return (w, v), packed_structures + + def compute( + self, + state: _HEVDState, + vector: PyTree[Array], + options: dict[str, Any], + ) -> tuple[PyTree[Array], RESULTS, dict[str, Any]]: + del options + (w, v), packed_structures = state + vector = ravel_vector(vector, packed_structures) + m = v.shape[0] + rcond = resolve_rcond(self.rcond, m, m, w.dtype) + rcond = jnp.array(rcond, dtype=w.dtype) + abs_w = jnp.abs(w) + if w.size > 0: + # Eigenvalues are signed and not magnitude-sorted; scale by max |w|. + rcond = rcond * jnp.max(abs_w) + # Not >=, or this fails with a matrix of all-zeros. + mask = abs_w > rcond + rank = mask.sum() + safe_w = jnp.where(mask, w, 1) + w_inv = jnp.where(mask, jnp.array(1.0) / safe_w, 0).astype(v.dtype) + vHb = jnp.matmul(v.conj().T, vector, precision=lax.Precision.HIGHEST) + solution = jnp.matmul(v, w_inv * vHb, precision=lax.Precision.HIGHEST) + solution = unravel_solution(solution, packed_structures) + return solution, RESULTS.successful, {"rank": rank} + + def transpose(self, state: _HEVDState, options: dict[str, Any]): + del options + (w, v), packed_structures = state + # `A` is Hermitian, so `A^T = conj(A)` (with `w` real). The structure is + # square symmetric, so the packed structures are unchanged. + transpose_state = (w, v.conj()), packed_structures + transpose_options = {} + return transpose_state, transpose_options + + def conj(self, state: _HEVDState, options: dict[str, Any]): + del options + (w, v), packed_structures = state + # `A` is Hermitian, so `conj(A) = conj(V) diag(w) conj(V)^H` (with `w` real). + conj_state = (w, v.conj()), packed_structures + conj_options = {} + return conj_state, conj_options + + def assume_full_rank(self): + return False + + +HEVD.__init__.__doc__ = """**Arguments**: + +- `rcond`: the cutoff for handling zero entries on the diagonal. Defaults to machine + precision times `N`, where `(N, N)` is the shape of the operator. +""" diff --git a/lineax/_tags.py b/lineax/_tags.py index 83051a8a..b8d6aa8e 100644 --- a/lineax/_tags.py +++ b/lineax/_tags.py @@ -29,6 +29,7 @@ def __repr__(self): symmetric_tag = _HasRepr("symmetric_tag") +hermitian_tag = _HasRepr("hermitian_tag") diagonal_tag = _HasRepr("diagonal_tag") tridiagonal_tag = _HasRepr("tridiagonal_tag") unit_diagonal_tag = _HasRepr("unit_diagonal_tag") @@ -127,6 +128,7 @@ def tags_from_checks(operator: "AbstractLinearOperator") -> frozenset[object]: from ._operator import ( has_unit_diagonal, is_diagonal, + is_hermitian, is_lower_triangular, is_negative_semidefinite, is_positive_semidefinite, @@ -140,6 +142,7 @@ def tags_from_checks(operator: "AbstractLinearOperator") -> frozenset[object]: tag for check, tag in [ (is_symmetric, symmetric_tag), + (is_hermitian, hermitian_tag), (is_diagonal, diagonal_tag), (is_lower_triangular, lower_triangular_tag), (is_upper_triangular, upper_triangular_tag), @@ -163,6 +166,7 @@ def tags_from_checks(operator: "AbstractLinearOperator") -> frozenset[object]: for tag in ( symmetric_tag, + hermitian_tag, unit_diagonal_tag, diagonal_tag, positive_semidefinite_tag, @@ -227,6 +231,7 @@ def transpose_tags(tags: frozenset[object]): for tag in ( symmetric_tag, + hermitian_tag, diagonal_tag, lower_triangular_tag, upper_triangular_tag, diff --git a/tests/helpers.py b/tests/helpers.py index 3194d11a..8b4a8d7d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -49,6 +49,8 @@ def _construct_matrix_impl( matrix = jnp.diag(jnp.diag(matrix)) if has_tag(tags, lx.symmetric_tag): matrix = matrix + matrix.T + if has_tag(tags, lx.hermitian_tag): + matrix = matrix + matrix.conj().T if has_tag(tags, lx.lower_triangular_tag): matrix = jnp.tril(matrix) if has_tag(tags, lx.upper_triangular_tag): @@ -64,6 +66,15 @@ def _construct_matrix_impl( matrix = matrix @ matrix.T.conj() if has_tag(tags, lx.negative_semidefinite_tag): matrix = -matrix @ matrix.T.conj() + if cond_or_singular == "zero" and ( + has_tag(tags, lx.symmetric_tag) or has_tag(tags, lx.hermitian_tag) + ): + # The symmetric/Hermitian construction refills the leading row that + # `zero` cleared, so re-zero the leading row *and* column of the result. + # This makes `e_0` a null vector -- a genuinely rank-deficient, and still + # indefinite, Hermitian operator -- mirroring how `zero` yields a + # rank-deficient matrix for the PSD/NSD constructions. + matrix = matrix.at[0, :].set(0).at[:, 0].set(0) if isinstance(cond_or_singular, str): break else: @@ -84,7 +95,10 @@ def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): def construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.float64): - if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)): + if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES, lx.HEVD)): + # `trim_row`/`trim_col` produce non-square matrices, which are incompatible + # with the (square) structure these solvers require. Use `zero` instead, + # which keeps the matrix square and (for PSD/NSD/Hermitian tags) Hermitian. singular_method = "zero" else: # Use `getkey()` rather than the stdlib `random.choice` for reproducibility @@ -133,6 +147,10 @@ def construct_poisson_matrix(size, dtype=jnp.float64): (lx.Cholesky(), lx.positive_semidefinite_tag, False), (lx.Cholesky(), lx.negative_semidefinite_tag, False), (lx.Normal(lx.Cholesky()), (), False), + (lx.HEVD(), lx.positive_semidefinite_tag, True), + (lx.HEVD(), lx.negative_semidefinite_tag, True), + (lx.HEVD(), lx.hermitian_tag, True), + (lx.Normal(lx.HEVD()), (), True), ] solvers_tags = [(a, b) for a, b, _ in solvers_tags_pseudoinverse] solvers = [a for a, _, _ in solvers_tags_pseudoinverse] diff --git a/tests/test_max_rank.py b/tests/test_max_rank.py index a867547a..8c0eb9ee 100644 --- a/tests/test_max_rank.py +++ b/tests/test_max_rank.py @@ -256,3 +256,87 @@ def test_svd_raises_when_max_rank_too_small(): vector = jnp.arange(10.0) + 0.5 with pytest.raises(eqx.EquinoxRuntimeError): lx.linear_solve(operator, vector, lx.SVD()) + + +# --------------------------------------------------------------------------- +# HEVD truncation +# --------------------------------------------------------------------------- + + +def _hermitian_with_spectrum(key, eigvals): + # `Q diag(eigvals) Q^T` with `Q` orthogonal: a (real-)Hermitian matrix with the + # given eigenvalues. Zeros placed between nonzero eigenvalues of both signs land + # in the interior of eigh's ascending order, exercising the fact that HEVD must + # truncate by *magnitude* (unlike SVD, where the small values are a contiguous + # tail). + size = len(eigvals) + q, _ = jnp.linalg.qr(jax.random.normal(key, (size, size))) + d = jnp.asarray(eigvals, dtype=q.dtype) + return (q * d[None, :]) @ q.T + + +def test_hevd_truncates_state_to_max_rank(): + # A genuinely rank-2 indefinite Hermitian matrix, declared rank 2: HEVD's state is + # truncated to 2 (eigenvalue, eigenvector) pairs and the solution is unchanged. + matrix = _hermitian_with_spectrum(jax.random.PRNGKey(0), [3.0, -2.0, 0.0, 0.0, 0.0]) + solver = lx.HEVD() + + plain = lx.MatrixLinearOperator(matrix, lx.hermitian_tag) + (w_full, v_full), _ = solver.init(plain, {}) + assert w_full.shape == (5,) + assert v_full.shape == (5, 5) + + tagged = lx.MatrixLinearOperator(matrix, (lx.hermitian_tag, lx.MaxRankTag(2))) + (w_t, v_t), _ = solver.init(tagged, {}) + assert w_t.shape == (2,) + assert v_t.shape == (5, 2) + + vector = jnp.arange(5.0) + 0.5 + x_plain = lx.linear_solve(plain, vector, solver).value + x_tagged = lx.linear_solve(tagged, vector, solver).value + assert jnp.allclose(x_plain, x_tagged) + + +# PSD (eigenvalues >= 0) and NSD (<= 0) operators have their near-zero eigenvalues at +# a contiguous *end* of eigh's ascending order, so truncation is a slice rather than a +# reorder. Indefinite operators need the reordering gather. Cover all three branches. +@pytest.mark.parametrize( + "tag, eigvals", + ( + (lx.hermitian_tag, [3.0, -2.0, 0.0, 0.0, 0.0]), # indefinite -> reorder + (lx.positive_semidefinite_tag, [3.0, 2.0, 0.0, 0.0, 0.0]), # PSD -> slice tail + (lx.negative_semidefinite_tag, [-3.0, -2.0, 0.0, 0.0, 0.0]), # NSD -> head + ), +) +def test_hevd_truncation_branches(tag, eigvals): + matrix = _hermitian_with_spectrum(jax.random.PRNGKey(0), eigvals) + solver = lx.HEVD() + plain = lx.MatrixLinearOperator(matrix, tag) + tagged = lx.MatrixLinearOperator(matrix, (tag, lx.MaxRankTag(2))) + + (w_t, v_t), _ = solver.init(tagged, {}) + assert w_t.shape == (2,) + assert v_t.shape == (5, 2) + + vector = jnp.arange(5.0) + 0.5 + x_plain = lx.linear_solve(plain, vector, solver).value + x_tagged = lx.linear_solve(tagged, vector, solver).value + assert jnp.allclose(x_plain, x_tagged) + + +@pytest.mark.parametrize( + "tag, eigvals", + ( + (lx.hermitian_tag, [3.0, -2.0, 1.5, 0.0, 0.0]), + (lx.positive_semidefinite_tag, [3.0, 2.0, 1.5, 0.0, 0.0]), + (lx.negative_semidefinite_tag, [-3.0, -2.0, -1.5, 0.0, 0.0]), + ), +) +def test_hevd_raises_when_max_rank_too_small(tag, eigvals): + # A genuinely rank-3 Hermitian matrix declared rank 2: truncation would discard an + # eigenvalue above the rcond threshold, so the solve must raise. + matrix = _hermitian_with_spectrum(jax.random.PRNGKey(0), eigvals) + operator = lx.MatrixLinearOperator(matrix, (tag, lx.MaxRankTag(2))) + vector = jnp.arange(5.0) + 0.5 + with pytest.raises(eqx.EquinoxRuntimeError): + lx.linear_solve(operator, vector, lx.HEVD()) diff --git a/tests/test_operator.py b/tests/test_operator.py index d3c5eff1..a4dc81db 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -249,6 +249,55 @@ def test_is_symmetric(dtype, getkey): _assert_except_diag(lx.is_symmetric, not_symmetric_operators, flip_cond=True) +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_is_hermitian(dtype, getkey): + matrix = jr.normal(getkey(), (3, 3), dtype=dtype) + hermitian_operators = _setup(getkey, matrix + matrix.conj().T, lx.hermitian_tag) + for operator in hermitian_operators: + assert lx.is_hermitian(operator) + + not_hermitian_operators = _setup(getkey, matrix) + _assert_except_diag(lx.is_hermitian, not_hermitian_operators, flip_cond=True) + + +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_is_hermitian_implications(dtype, getkey): + matrix = jr.normal(getkey(), (3, 3), dtype=dtype) + real = jnp.issubdtype(dtype, jnp.floating) + + # PSD/NSD are Hermitian for both real and complex dtypes. + psd = matrix @ matrix.conj().T + assert lx.is_hermitian(lx.MatrixLinearOperator(psd, lx.positive_semidefinite_tag)) + assert lx.is_hermitian(lx.MatrixLinearOperator(-psd, lx.negative_semidefinite_tag)) + + # Symmetric (A = Aᵀ) and diagonal operators are Hermitian iff real-valued. + sym = lx.MatrixLinearOperator(matrix + matrix.T, lx.symmetric_tag) + assert lx.is_hermitian(sym) == real + assert lx.is_hermitian(lx.DiagonalLinearOperator(jnp.diag(matrix))) == real + + # Conversely a Hermitian operator is symmetric iff real-valued. + herm = lx.MatrixLinearOperator(matrix + matrix.conj().T, lx.hermitian_tag) + assert lx.is_hermitian(herm) + assert lx.is_symmetric(herm) == real + + +def test_hermitian_tag_propagation(getkey): + # Hermitian-ness is preserved through transpose and inversion, addition, and real + # (but not complex) scaling. + assert lx.hermitian_tag in lx.transpose_tags(frozenset({lx.hermitian_tag})) + assert lx.hermitian_tag in lx.invert_tags(frozenset({lx.hermitian_tag})) + + def herm_op(): + m = jr.normal(getkey(), (3, 3), dtype=jnp.complex128) + return lx.MatrixLinearOperator(m + m.conj().T, lx.hermitian_tag) + + op = herm_op() + assert lx.is_hermitian(op + herm_op()) # sum of Hermitian is Hermitian + assert lx.is_hermitian(-op) # negation preserves + assert lx.is_hermitian(op * 2.0) # real scaling preserves + assert not lx.is_hermitian(op * (1.0 + 1j)) # complex scaling does not + + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_is_diagonal(dtype, getkey): matrix = jr.normal(getkey(), (3, 3), dtype=dtype) diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 51b5821a..c858762c 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -60,3 +60,38 @@ def test_pytree_transpose(_, assert_transpose_fixture): # pyright: ignore in_vec = [a(1.0), 2.0, 3.0] solver = lx.AutoLinearSolver(well_posed=False) assert_transpose_fixture(operator, out_vec, in_vec, solver) + + +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_H_property(getkey, dtype): + # `.H` is the conjugate transpose `conj(A)ᵀ`, but a no-op for Hermitian operators. + m = jr.normal(getkey(), (4, 3), dtype=dtype) + op = lx.MatrixLinearOperator(m) + assert op.H is not op + assert tree_allclose(op.H.as_matrix(), m.conj().T) + + herm = jr.normal(getkey(), (3, 3), dtype=dtype) + herm = herm + herm.conj().T + hop = lx.MatrixLinearOperator(herm, lx.hermitian_tag) + assert hop.H is hop + + +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +@pytest.mark.parametrize("solver_cls", (lx.HEVD, lx.Cholesky, lx.LU)) +def test_hermitian_adjoint_state_reuses_init(getkey, dtype, solver_cls): + # For a Hermitian operator (`Aᴴ = A`), `init(A.H) == init(A)`, so the linear-solve + # JVP reuses the original state as the adjoint state -- this holds for any solver. + m = jr.normal(getkey(), (4, 4), dtype=dtype) + m = m + m.conj().T + if solver_cls is lx.Cholesky: + m = m @ m.conj().T # PSD, still Hermitian + op = lx.MatrixLinearOperator(m, lx.positive_semidefinite_tag) + else: + op = lx.MatrixLinearOperator(m, lx.hermitian_tag) + solver = solver_cls() + assert op.H is op + state = solver.init(op, {}) + adjoint_state = solver.init(op.H, {}) + assert tree_allclose( + eqx.filter(state, eqx.is_array), eqx.filter(adjoint_state, eqx.is_array) + ) diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index 4686ce86..5a5e7008 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -57,7 +57,7 @@ def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): def test_pytree_wellposed(solver, getkey, dtype): if not isinstance( solver, - (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.CG), + (lx.Diagonal, lx.Triangular, lx.Tridiagonal, lx.Cholesky, lx.HEVD, lx.CG), ): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 From f866b1747b20b806b7813b0c8ac6b5ca439512a5 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Mon, 22 Jun 2026 20:02:39 +0100 Subject: [PATCH 4/4] materialise operator for Normal(HEVD) --- lineax/_solver/normal.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lineax/_solver/normal.py b/lineax/_solver/normal.py index 021065c5..a1d205e6 100644 --- a/lineax/_solver/normal.py +++ b/lineax/_solver/normal.py @@ -29,6 +29,7 @@ from .._tags import positive_semidefinite_tag from .base import AbstractLinearSolver from .cholesky import Cholesky +from .hevd import HEVD _InnerSolverState = TypeVar("_InnerSolverState") @@ -113,8 +114,8 @@ def init(self, operator, options): # Cholesky materialises op twice when computing (op^H @ op).as_matrix() # Cheaper to materialise first and then conjugate-transpose. # For iterative solvers we only linearise to avoid eager materialisation. - is_cholesky = isinstance(self.inner_solver, Cholesky) - lin_op = materialise(operator) if is_cholesky else linearise(operator) + is_direct = isinstance(self.inner_solver, Cholesky | HEVD) + lin_op = materialise(operator) if is_direct else linearise(operator) if tall: inner_operator = conj(lin_op.transpose()) @ lin_op else: